# Graph Classification with Batching

**Goal:** This first tutorial will go through an example of graph classification using mini-batching. 

**Concepts:** `Mini-batching`, `Readout`, `GNN training`, `MLX syntax`

In [77]:
from collections import defaultdict

import mlx.core as mx
from mlx_graphs.datasets import TUDataset

## Dataset

For this first tutorial, we will use the [TUDatasets](https://chrsmrrs.github.io/datasets/docs/datasets/) collection, which comprises more than 120 datasets for graph classification and graph regression tasks. 

The datasets proposed in this collection can be easily accessed via the `TUDataset` class.

We will use here the `MUTAG` dataset, where input graphs represent chemical compounds, with vertices symbolizing atoms identified by their atom type through one-hot encoding. Edges between vertices denote the bonds connecting the atoms. The dataset comprises 188 samples of chemical compounds, featuring 7 distinct node labels.

In [78]:
dataset = TUDataset("MUTAG")
dataset

MUTAG(num_graphs=188)

Dataset properties can directly accessed from the `dataset` object, and we can also compute some statistics to better understand the dataset.

In [79]:
# Some useful properties
print("Dataset attributes")
print("-" * 20)
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of edge features: {dataset.num_edge_features}")
print(f"Number of graph features: {dataset.num_graph_features}")
print(f"Number of graph classes to predict: {dataset.num_graph_classes}\n")

# Statistics of the dataset
stats = defaultdict(list)
for g in dataset:
    stats["Mean node degree"].append(g.num_edges / g.num_nodes)
    stats["Mean num of nodes"].append(g.num_nodes)
    stats["Mean num of edges"].append(g.num_edges)

print("Dataset stats")
print("-" * 20)
for k, v in stats.items():
    mean = mx.mean(mx.array(v)).item()
    print(f"{k}: {mean:.2f}")

Dataset attributes
--------------------
Number of graphs: 188
Number of node features: 7
Number of edge features: 4
Number of graph features: 0
Number of graph classes to predict: 2

Dataset stats
--------------------
Mean node degree: 2.19
Mean num of nodes: 17.93
Mean num of edges: 39.59


A `Dataset` is nothing more than a wrapper around a list of `GraphData` objects. In **mlx-graphs**, a `GraphData` object contains the structure along with features of a graph, similarly as [DGLGraph](https://docs.dgl.ai/en/2.0.x/api/python/dgl.DGLGraph.html#dgl.DGLGraph) in DGL or [Data](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html) in PyG.

We can directly access these graphs from the dataset using indexing.

In [80]:
dataset[0]

GraphData(
	edge_index(shape=(2, 38), int32)
	node_features(shape=(17, 7), float32)
	edge_features(shape=(38, 4), float32)
	graph_labels(shape=(1,), int32))

The first graph of this dataset comprises 38 edges with 4 edge features and 17 nodes with 7 node features.

When indexing a dataset with sequences or slices, we end up with another `Dataset` object containing the graphs associated with this sequence. Using this indexing strategy, the dataset can be divided into train and test sets.

In [81]:
train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f"Training dataset: {train_dataset}")
print(f"Testing dataset: {test_dataset}")

Training dataset: MUTAG(num_graphs=150)
Testing dataset: MUTAG(num_graphs=38)


In [82]:
from mlx_graphs.loaders import Dataloader

BATCH_SIZE = 64

train_loader = Dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = Dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
GraphDataBatch(
	edge_index(shape=(2, 2590), int32)
	node_features(shape=(1168, 7), float32)
	edge_features(shape=(2590, 4), float32)
	graph_labels(shape=(64,), int32))

Step 2:
Number of graphs in the current batch: 64
GraphDataBatch(
	edge_index(shape=(2, 2620), int32)
	node_features(shape=(1179, 7), float32)
	edge_features(shape=(2620, 4), float32)
	graph_labels(shape=(64,), int32))

Step 3:
Number of graphs in the current batch: 22
GraphDataBatch(
	edge_index(shape=(2, 720), int32)
	node_features(shape=(337, 7), float32)
	edge_features(shape=(720, 4), float32)
	graph_labels(shape=(22,), int32))



In [83]:
import mlx.nn as nn
from mlx_graphs.nn import GCNConv, global_mean_pool


class GCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super(GCN, self).__init__()

        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, out_dim)

        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, edge_index, node_features, batch_indices):
        h = nn.relu(self.conv1(edge_index, node_features))
        h = nn.relu(self.conv2(edge_index, h))
        h = self.conv3(edge_index, h)

        h = global_mean_pool(h, batch_indices)

        h = self.dropout(h)
        h = self.lin(h)
        
        return h

In [84]:
mx.random.seed(42)

def loss_fn(y_hat, y, parameters=None):
    return mx.mean(nn.losses.cross_entropy(y_hat, y))

def eval_fn(y_hat, y):
    return mx.mean(mx.argmax(y_hat, axis=1) == y)

def forward_fn(model, graph, labels):
    y_hat = model(graph.edge_index, graph.node_features, graph.batch_indices)
    loss = loss_fn(y_hat, labels, model.parameters())
    return loss, y_hat

In [85]:
import mlx.optimizers as optim

model = GCN(
    in_dim=dataset.num_node_features,
    hidden_dim=64,
    out_dim=dataset.num_graph_classes,
)
print(model)

mx.eval(model.parameters())

optimizer = optim.Adam(learning_rate=0.01)
loss_and_grad_fn = nn.value_and_grad(model, forward_fn)

def train(train_loader):
    loss_sum = 0.0
    for graph in train_loader:
        (loss, y_hat), grads = loss_and_grad_fn(
            model=model,
            graph=graph,
            labels=graph.graph_labels,
        )
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        loss_sum += loss.item()
    return loss_sum / len(train_loader.dataset)

def test(loader):
    acc = 0.0
    for graph in loader:
        y_hat = model(graph.edge_index, graph.node_features, graph.batch_indices)
        y_hat = y_hat.argmax(axis=1)
        acc += (y_hat == graph.graph_labels).sum().item()
    
    return acc / len(loader.dataset)


for epoch in range(170):

    loss = train(train_loader)
    train_acc = test(train_loader)
    test_acc = test(test_loader)

    print(
        " | ".join(
            [
                f"Epoch: {epoch:3d}",
                f"Train loss: {loss:.3f}",
                f"Train acc: {train_acc:.3f}",
                f"Test acc: {test_acc:.3f}",
            ]
        )
    )

GCN(
  (conv1): GCNConv(
    (linear): Linear(input_dims=7, output_dims=64, bias=True)
  )
  (conv2): GCNConv(
    (linear): Linear(input_dims=64, output_dims=64, bias=True)
  )
  (conv3): GCNConv(
    (linear): Linear(input_dims=64, output_dims=64, bias=True)
  )
  (lin): Linear(input_dims=64, output_dims=2, bias=True)
  (dropout): Dropout(p=0.5)
)
Epoch:   0 | Train loss: 0.015 | Train acc: 0.580 | Test acc: 0.526
Epoch:   1 | Train loss: 0.013 | Train acc: 0.660 | Test acc: 0.684
Epoch:   2 | Train loss: 0.012 | Train acc: 0.660 | Test acc: 0.684
Epoch:   3 | Train loss: 0.013 | Train acc: 0.647 | Test acc: 0.684
Epoch:   4 | Train loss: 0.013 | Train acc: 0.660 | Test acc: 0.684
Epoch:   5 | Train loss: 0.013 | Train acc: 0.653 | Test acc: 0.658
Epoch:   6 | Train loss: 0.014 | Train acc: 0.620 | Test acc: 0.684
Epoch:   7 | Train loss: 0.013 | Train acc: 0.660 | Test acc: 0.684
Epoch:   8 | Train loss: 0.014 | Train acc: 0.667 | Test acc: 0.684
Epoch:   9 | Train loss: 0.014 | Tra