In [3]:
import torch
import torchvision 
import torchvision.datasets as datasets
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader

In [4]:
# PyG Dataset 
class GraphMNIST(InMemoryDataset):
    def __init__(self, patch_size, root = None, train = True, transform = None, pre_transform = None, pre_filter = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        X, Y, edge_index = self.loadMNIST(patch_size, train)
        data = [Data(x=x, edge_index=edge_index, y=y) for x, y in zip(X, Y.view(Y.shape[0], 1))]
        self.data, self.slices = self.collate(data)       

    def loadMNIST(self, patch_size, train = True):
        # loading MNIST
        mnist =  datasets.MNIST(root='./data', train=train, download=True, transform=None)
 
        # each image is divided into patches of shape patch_size
        X = mnist.data.unfold(1,patch_size[0],patch_size[0]).unfold(2,patch_size[1],patch_size[1])
        X = X.flatten(start_dim=1,end_dim=2).flatten(start_dim=2).to(torch.float)

        # labels
        Y = mnist.targets

        # adjacency matrix will always be the same:
        # each patch is adjacent to the 8 surrounding patches
        row = []
        col = []
        dim0, dim1 = int(28/patch_size[0]), int(28/patch_size[1])   
        for i in range(dim0):
            for j in range(dim1):
                for i_adj in range(i-1, i+2):
                    for j_adj in range(j-1, j+2):
                        if i_adj >= 0 and i_adj < dim0 and j_adj >=0 and j_adj < dim1 and not (i == i_adj and j == j_adj):
                            row.append(i * dim0 + j)         # patch at (i,j)
                            col.append(i_adj * dim0 + j_adj) # is adjacent to patch at (i_adj, j_adj)
        row = torch.tensor(row, dtype=torch.long)
        col = torch.tensor(col, dtype=torch.long)
        edge_index = torch.stack([row, col])
        return X, Y, edge_index

In [5]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool

In [11]:
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        # using architecture inspired by MNISTSuperpixels example 
        # (https://medium.com/@rtsrumi07/understanding-graph-neural-network-with-hands-on-example-part-2-139a691ebeac)
        hidden_channel_size = 32
        self.initial_conv = GCNConv(data.num_features, hidden_channel_size)
        self.conv1 = GCNConv(hidden_channel_size, hidden_channel_size)
        self.conv2 = GCNConv(hidden_channel_size, hidden_channel_size)
        # self.conv3 = GCNConv(hidden_channel_size, data.num_features)
        self.out = nn.Linear(hidden_channel_size*2, data.num_classes)

    def forward(self, x, edge_index, batch_index):
        hidden = self.initial_conv(x, edge_index)
        hidden = F.relu(hidden)
        hidden = self.conv1(hidden, edge_index)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden, edge_index)
        hidden = F.relu(hidden)
        # hidden = self.conv3(hidden, edge_index)
        # hidden = F.relu(hidden)
        hidden = torch.cat([global_mean_pool(hidden, batch_index),
                            global_max_pool(hidden, batch_index)], dim=1)
        out = self.out(hidden)
        return out 


In [7]:
def train(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    for batch, b in enumerate(dataloader):
        b.to(device)
        pred = model(b.x, b.edge_index, b.batch)
        loss = loss_fn(pred, b.y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch
            print(f"loss: {loss:>7f}  [{(current*64):>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for d in dataloader:
            d.to(device)
            pred = model(d.x, d.edge_index, d.batch)
            test_loss += loss_fn(pred, d.y).item()
            correct += (pred.argmax(1) == d.y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


With patch size 4x4

In [12]:
train_ds_4x4 = GraphMNIST((4,4), train=True)
test_ds_4x4  = GraphMNIST((4,4), train=False)
train_loader_4x4 = DataLoader(train_ds_4x4, batch_size=64)
test_loader_4x4  = DataLoader(test_ds_4x4, batch_size=64)

print(f"Train dataset {train_ds_4x4}: {train_ds_4x4[0]}")
print(f"Test dataset  {test_ds_4x4}: {test_ds_4x4[0]}")

Train dataset GraphMNIST(60000): Data(x=[49, 16], edge_index=[2, 312], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[49, 16], edge_index=[2, 312], y=[1])


In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model_4x4 = GCN(train_ds_4x4).to(device)
optimizer = torch.optim.Adam(model_4x4.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader_4x4, model_4x4, loss_fn, optimizer, device)
    test(test_loader_4x4, model_4x4, loss_fn, device)
print("Done!")


cuda
Epoch 1
-------------------------------
loss: 24.197182  [    0/60000]
loss: 1.684861  [ 6400/60000]
loss: 1.466675  [12800/60000]
loss: 1.299117  [19200/60000]
loss: 1.015022  [25600/60000]
loss: 1.056981  [32000/60000]
loss: 0.650261  [38400/60000]
loss: 1.039791  [44800/60000]
loss: 1.014659  [51200/60000]
loss: 0.880535  [57600/60000]
Test Error: 
 Accuracy: 70.1%, Avg loss: 0.870777 

Epoch 2
-------------------------------
loss: 0.900181  [    0/60000]
loss: 0.765241  [ 6400/60000]
loss: 0.562938  [12800/60000]
loss: 0.706905  [19200/60000]
loss: 0.759695  [25600/60000]
loss: 1.050992  [32000/60000]
loss: 0.504764  [38400/60000]
loss: 0.758036  [44800/60000]
loss: 0.904577  [51200/60000]
loss: 0.697807  [57600/60000]
Test Error: 
 Accuracy: 77.2%, Avg loss: 0.684833 

Epoch 3
-------------------------------
loss: 0.757734  [    0/60000]
loss: 0.576604  [ 6400/60000]
loss: 0.413162  [12800/60000]
loss: 0.501683  [19200/60000]
loss: 0.538255  [25600/60000]
loss: 0.857475  [320

With patch size 2x2

In [16]:
train_ds_2x2 = GraphMNIST((2,2), train=True)
test_ds_2x2  = GraphMNIST((2,2), train=False)
train_loader_2x2 = DataLoader(train_ds_2x2, batch_size=64, shuffle=True)
test_loader_2x2  = DataLoader(test_ds_2x2, batch_size=64, shuffle=True)

print(f"Train dataset {train_ds_2x2}: {train_ds_2x2[0]}")
print(f"Test dataset  {test_ds_2x2}: {test_ds_2x2[0]}")

Train dataset GraphMNIST(60000): Data(x=[196, 4], edge_index=[2, 1404], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[196, 4], edge_index=[2, 1404], y=[1])


In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_2x2 = GCN(train_ds_2x2).to(device)
optimizer = torch.optim.Adam(model_2x2.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader_2x2, model_2x2, loss_fn, optimizer, device)
    test(test_loader_2x2, model_2x2, loss_fn, device)
print("Done!")

Epoch 1
-------------------------------
loss: 14.343238  [    0/60000]
loss: 1.755637  [ 6400/60000]
loss: 1.762703  [12800/60000]
loss: 1.504525  [19200/60000]
loss: 1.312520  [25600/60000]
loss: 1.044519  [32000/60000]
loss: 0.962149  [38400/60000]
loss: 1.148144  [44800/60000]
loss: 1.024849  [51200/60000]
loss: 1.033727  [57600/60000]
Test Error: 
 Accuracy: 69.0%, Avg loss: 0.945309 

Epoch 2
-------------------------------
loss: 1.209273  [    0/60000]
loss: 1.018829  [ 6400/60000]
loss: 0.907649  [12800/60000]
loss: 0.828608  [19200/60000]
loss: 0.988542  [25600/60000]
loss: 0.626101  [32000/60000]
loss: 0.755118  [38400/60000]
loss: 0.879188  [44800/60000]
loss: 0.692013  [51200/60000]
loss: 0.667389  [57600/60000]
Test Error: 
 Accuracy: 75.8%, Avg loss: 0.762505 

Epoch 3
-------------------------------
loss: 0.881806  [    0/60000]
loss: 0.854684  [ 6400/60000]
loss: 0.682217  [12800/60000]
loss: 0.908915  [19200/60000]
loss: 0.615945  [25600/60000]
loss: 0.689455  [32000/60

Patch size 1x1

In [18]:
train_ds_1x1 = GraphMNIST((1,1), train=True)
test_ds_1x1  = GraphMNIST((1,1), train=False)
train_loader_1x1 = DataLoader(train_ds_1x1, batch_size=64, shuffle=True)
test_loader_1x1  = DataLoader(test_ds_1x1, batch_size=64, shuffle=True)

print(f"Train dataset {train_ds_1x1}: {train_ds_1x1[0]}")
print(f"Test dataset  {test_ds_1x1}: {test_ds_1x1[0]}")

Train dataset GraphMNIST(60000): Data(x=[784, 1], edge_index=[2, 5940], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[784, 1], edge_index=[2, 5940], y=[1])


In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_1x1 = GCN(train_ds_1x1).to(device)
optimizer = torch.optim.Adam(model_1x1.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader_1x1, model_1x1, loss_fn, optimizer, device)
    test(test_loader_1x1, model_1x1, loss_fn, device)
print("Done!")

Epoch 1
-------------------------------
loss: 11.158808  [    0/60000]
loss: 2.016950  [ 6400/60000]
loss: 2.094996  [12800/60000]
loss: 1.945999  [19200/60000]
loss: 1.926160  [25600/60000]
loss: 1.991243  [32000/60000]
loss: 1.903397  [38400/60000]
loss: 2.172642  [44800/60000]
loss: 1.845034  [51200/60000]
loss: 1.930171  [57600/60000]
Test Error: 
 Accuracy: 32.3%, Avg loss: 1.817338 

Epoch 2
-------------------------------
loss: 1.714838  [    0/60000]
loss: 1.823523  [ 6400/60000]
loss: 1.898827  [12800/60000]
loss: 1.984338  [19200/60000]
loss: 1.736590  [25600/60000]
loss: 1.762772  [32000/60000]
loss: 1.893613  [38400/60000]
loss: 1.686836  [44800/60000]
loss: 1.810842  [51200/60000]
loss: 1.630694  [57600/60000]
Test Error: 
 Accuracy: 33.7%, Avg loss: 1.810140 

Epoch 3
-------------------------------
loss: 2.032885  [    0/60000]
loss: 1.966486  [ 6400/60000]
loss: 1.875223  [12800/60000]
loss: 1.834963  [19200/60000]
loss: 1.908222  [25600/60000]
loss: 1.824498  [32000/60