In [1]:
import torch
from torch_geometric.data import Dataset
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch.nn import Linear, CrossEntropyLoss
import torch.nn.functional as F
from torch_geometric.nn import MLP, GINConv, global_add_pool
import matplotlib.pyplot as plt
import time

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MNIST_PATH = "../datasets/MNISTSuperpixel"



In [2]:
class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = MLP([in_channels, hidden_channels, hidden_channels])
            self.convs.append(GINConv(nn=mlp, train_eps=False))
            in_channels = hidden_channels

        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
                       norm=None, dropout=0.5)
        
    def forward(self, x, edge_index, batch, batch_size):
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        # Pass the batch size to avoid CPU communication/graph breaks:
        x = global_add_pool(x, batch, size=batch_size)
        return self.mlp(x)

In [3]:
dataset: Dataset = MNISTSuperpixels(root=MNIST_PATH).shuffle()
train_loader = DataLoader(dataset[:0.9], batch_size=128)
test_loader = DataLoader(dataset[0.9:], batch_size=128)

In [4]:
model = GIN(
    in_channels=dataset.num_features,
    hidden_channels=32,
    out_channels=dataset.num_classes,
    num_layers=5,
).to(device)

# Compile the model into an optimized version:
model = torch.compile(model, dynamic=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [5]:
def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch, data.batch_size)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    model.eval()

    total_correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch, data.batch_size)
        pred = out.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
    return total_correct / len(loader.dataset)

In [6]:
times = []
for epoch in range(1, 101):
    start = time.time()
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    times.append(time.time() - start)
    print(f'\rEpoch: {epoch:03d}, Loss: {loss:.4f}, Train_acc: {train_acc:.4f}, '
          f'Test_acc: {test_acc:.4f}', end="")
print(f'\nMedian time per epoch: {torch.tensor(times).median():.4f}s')

Epoch: 100, Loss: 1.2007, Train_acc: 0.6257, Test_acc: 0.6130
Median time per epoch: 10.3899s


In [11]:
test_dataset: Dataset = MNISTSuperpixels(root=MNIST_PATH, train=False)
loader_test = DataLoader(dataset[:500], batch_size=1)

In [13]:
correct = 0
for test in loader_test:
    test.to(device)
    pred = model(test.x, test.edge_index, test.batch, test.batch_size).softmax(1).argmax(1)
    correct += int((pred == test.y).sum())
print(f"Accuracy: {correct/500 * 100:.2f}%")

Accuracy: 63.20%
