In [1]:
import torch
import torchvision 
import torchvision.datasets as datasets
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader

In [2]:
# PyG Dataset 
class GraphMNIST(InMemoryDataset):
    """MNIST digits represented as graphs of image patches.

    Every image is cut into non-overlapping patches of ``patch_size`` pixels.
    Each patch becomes one node whose feature vector is the flattened pixel
    values of that patch, and every patch is connected to the (up to) eight
    patches surrounding it on the patch grid. All graphs share the same
    ``edge_index``; the graph label ``y`` is the digit class.

    Args:
        patch_size: ``(height, width)`` of one patch, in pixels.
        root: forwarded unchanged to ``InMemoryDataset``.
        train: forwarded to ``torchvision.datasets.MNIST`` to pick the split.
        transform: forwarded unchanged to ``InMemoryDataset``.
        pre_transform: forwarded unchanged to ``InMemoryDataset``.
        pre_filter: forwarded unchanged to ``InMemoryDataset``.
    """

    def __init__(self, patch_size, root = None, train = True, transform = None, pre_transform = None, pre_filter = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        X, Y, edge_index = self.loadMNIST(patch_size, train)
        # One Data object per image; labels are reshaped so each graph gets a y of shape [1].
        data = [Data(x=x, edge_index=edge_index, y=y) for x, y in zip(X, Y.view(Y.shape[0], 1))]
        self.data, self.slices = self.collate(data)

    def loadMNIST(self, patch_size, train = True):
        """Load MNIST and convert it to node features, labels and a shared edge index.

        Args:
            patch_size: ``(height, width)`` of one patch, in pixels.
            train: forwarded to ``torchvision.datasets.MNIST`` to pick the split.

        Returns:
            Tuple ``(X, Y, edge_index)``: ``X`` is a float tensor of shape
            ``[num_images, num_patches, patch_height * patch_width]``, ``Y`` is
            the MNIST target tensor and ``edge_index`` is a long tensor of
            shape ``[2, num_edges]``.
        """
        # loading MNIST (downloaded into ./data on first use)
        mnist = datasets.MNIST(root='./data', train=train, download=True, transform=None)

        # each image is divided into patches of shape patch_size
        patch_h, patch_w = patch_size[0], patch_size[1]
        patches = mnist.data.unfold(1, patch_h, patch_h).unfold(2, patch_w, patch_w)
        # Take the grid size from the unfolded tensor rather than assuming
        # 28x28 images, so it always matches the number of nodes in X.
        dim0, dim1 = patches.shape[1], patches.shape[2]
        X = patches.flatten(start_dim=1, end_dim=2).flatten(start_dim=2).to(torch.float)

        # labels
        Y = mnist.targets

        # adjacency will always be the same for every image
        edge_index = self._grid_edge_index(dim0, dim1)
        return X, Y, edge_index

    @staticmethod
    def _grid_edge_index(dim0, dim1):
        """Build the 8-neighbourhood edge index of a ``dim0`` x ``dim1`` patch grid.

        Nodes are numbered row-major, i.e. the patch at ``(i, j)`` has id
        ``i * dim1 + j``, which is the order produced by flattening the grid
        dimensions of the patch tensor.
        """
        row = []
        col = []
        for i in range(dim0):
            for j in range(dim1):
                for i_adj in range(i - 1, i + 2):
                    for j_adj in range(j - 1, j + 2):
                        inside = 0 <= i_adj < dim0 and 0 <= j_adj < dim1
                        if inside and not (i == i_adj and j == j_adj):
                            # The row stride is the number of columns (dim1);
                            # using dim0 here is only correct for square grids.
                            row.append(i * dim1 + j)          # patch at (i,j)
                            col.append(i_adj * dim1 + j_adj)  # is adjacent to patch at (i_adj, j_adj)
        row = torch.tensor(row, dtype=torch.long)
        col = torch.tensor(col, dtype=torch.long)
        return torch.stack([row, col])

In [3]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from sklearn.metrics import f1_score, accuracy_score

In [4]:
class GCN(nn.Module):
    """Graph-level classifier: three GCN layers followed by mean+max pooling.

    The architecture is inspired by the MNISTSuperpixels example
    (https://medium.com/@rtsrumi07/understanding-graph-neural-network-with-hands-on-example-part-2-139a691ebeac).

    Args:
        data: object exposing ``num_features`` (input width of the first
            convolution) and ``num_classes`` (width of the output layer).
    """

    def __init__(self, data):
        super().__init__()
        width = 64  # hidden channels shared by all convolutions
        self.initial_conv = GCNConv(data.num_features, width)
        self.conv1 = GCNConv(width, width)
        self.conv2 = GCNConv(width, width)
        # The readout concatenates mean- and max-pooled embeddings, hence 2 * width inputs.
        self.out = nn.Linear(2 * width, data.num_classes)

    def forward(self, x, edge_index, batch_index):
        """Return one row of class scores per graph in the batch."""
        h = x
        # Each convolution is followed by a ReLU.
        for conv in (self.initial_conv, self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index))
        # Collapse node embeddings to one vector per graph.
        mean_part = global_mean_pool(h, batch_index)
        max_part = global_max_pool(h, batch_index)
        graph_embedding = torch.cat([mean_part, max_part], dim=1)
        return self.out(graph_embedding)


In [8]:
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one optimisation pass (one epoch) over ``dataloader``.

    Args:
        dataloader: iterable of batches; each batch must provide ``.to(device)``
            and the attributes ``x``, ``edge_index``, ``batch`` and ``y``.
        model: module called as ``model(x, edge_index, batch)``.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer holding the parameters of ``model``.
        device: device every batch is moved to before the forward pass.

    Returns:
        None. The parameters of ``model`` are updated in place.
    """
    # Make sure train-time behaviour is active even if the model was
    # previously switched to eval mode.
    model.train()
    for batch in dataloader:
        # Use the object returned by .to() instead of relying on it mutating
        # the batch in place.
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(pred, batch.y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` and print accuracy, F-measures and loss.

    Args:
        dataloader: iterable of batches; each batch must provide ``.to(device)``
            and the attributes ``x``, ``edge_index``, ``batch`` and ``y``.
        model: ``torch.nn.Module`` called as ``model(x, edge_index, batch)``
            and returning one row of class scores per graph.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: device every batch is moved to before the forward pass.

    Returns:
        None. The metrics are printed on a single line.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    test_loss = 0
    # Collect per-batch results and concatenate once after the loop; this
    # avoids re-copying the growing tensors and keeps the labels as integers.
    label_chunks, score_chunks = [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for d in dataloader:
                d = d.to(device)
                pred = model(d.x, d.edge_index, d.batch)
                test_loss += loss_fn(pred, d.y).item()
                label_chunks.append(d.y.cpu())
                score_chunks.append(pred.cpu())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    num_batches = len(score_chunks)
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    # Average of the per-batch losses.
    test_loss /= num_batches
    scores = torch.cat(score_chunks)
    Y = torch.cat(label_chunks).numpy()
    Y_pred = torch.argmax(scores, dim=1).numpy()
    # One label per output column, so absent classes still count in the
    # macro average (0..9 for the 10-class models in this notebook).
    class_labels = list(range(scores.shape[1]))
    accuracy = accuracy_score(Y, Y_pred)
    f1_micro = f1_score(Y, Y_pred, average='micro', labels=class_labels)
    f1_macro = f1_score(Y, Y_pred, average='macro', labels=class_labels)
    print(f"Accuracy: {accuracy:>8f}, F-measure (micro): {f1_micro:>8f}, F-measure (macro): {f1_macro:>8f}, Avg loss: {test_loss:>8f}")


With patch size 4x4

In [9]:
train_ds_4x4 = GraphMNIST((4,4), train=True)
test_ds_4x4  = GraphMNIST((4,4), train=False)
# Reshuffle the training graphs every epoch, as the 2x2 and 1x1 experiments do,
# so the three runs are comparable. The evaluation loader keeps a fixed order.
train_loader_4x4 = DataLoader(train_ds_4x4, batch_size=64, shuffle=True)
test_loader_4x4  = DataLoader(test_ds_4x4, batch_size=64)

print(f"Train dataset {train_ds_4x4}: {train_ds_4x4[0]}")
print(f"Test dataset  {test_ds_4x4}: {test_ds_4x4[0]}")

Train dataset GraphMNIST(60000): Data(x=[49, 16], edge_index=[2, 312], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[49, 16], edge_index=[2, 312], y=[1])


In [11]:
# Seed PyTorch's random number generators so that weight initialisation (and
# any shuffling done by the loaders) is repeatable when this cell is re-run.
# NOTE: some GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model_4x4 = GCN(train_ds_4x4).to(device)
optimizer = torch.optim.Adam(model_4x4.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# Train for a fixed number of epochs, evaluating on the test split after each one.
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}: ", end='')
    train(train_loader_4x4, model_4x4, loss_fn, optimizer, device)
    test(test_loader_4x4, model_4x4, loss_fn, device)
print("Done!")


cuda
Epoch 1: Accuracy: 0.730900, F-measure (micro): 0.730900, F-measure (macro): 0.719400, Avg loss: 0.819007
Epoch 2: Accuracy: 0.776500, F-measure (micro): 0.776500, F-measure (macro): 0.765687, Avg loss: 0.678900
Epoch 3: Accuracy: 0.819400, F-measure (micro): 0.819400, F-measure (macro): 0.812981, Avg loss: 0.547509
Epoch 4: Accuracy: 0.851900, F-measure (micro): 0.851900, F-measure (macro): 0.847612, Avg loss: 0.463992
Epoch 5: Accuracy: 0.869400, F-measure (micro): 0.869400, F-measure (macro): 0.867051, Avg loss: 0.407739
Epoch 6: Accuracy: 0.874900, F-measure (micro): 0.874900, F-measure (macro): 0.872602, Avg loss: 0.392629
Epoch 7: Accuracy: 0.881000, F-measure (micro): 0.881000, F-measure (macro): 0.879160, Avg loss: 0.367200
Epoch 8: Accuracy: 0.893000, F-measure (micro): 0.893000, F-measure (macro): 0.891458, Avg loss: 0.341463
Epoch 9: Accuracy: 0.891600, F-measure (micro): 0.891600, F-measure (macro): 0.890114, Avg loss: 0.336933
Epoch 10: Accuracy: 0.892400, F-measure (

With patch size 2x2

In [12]:
train_ds_2x2 = GraphMNIST((2,2), train=True)
test_ds_2x2  = GraphMNIST((2,2), train=False)
# Only the training graphs are reshuffled each epoch. Evaluation gains nothing
# from shuffling, and a fixed order keeps the average of per-batch losses
# reproducible between runs.
train_loader_2x2 = DataLoader(train_ds_2x2, batch_size=64, shuffle=True)
test_loader_2x2  = DataLoader(test_ds_2x2, batch_size=64)

print(f"Train dataset {train_ds_2x2}: {train_ds_2x2[0]}")
print(f"Test dataset  {test_ds_2x2}: {test_ds_2x2[0]}")

Train dataset GraphMNIST(60000): Data(x=[196, 4], edge_index=[2, 1404], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[196, 4], edge_index=[2, 1404], y=[1])


In [14]:
# Seed PyTorch's random number generators so that weight initialisation (and
# any shuffling done by the loaders) is repeatable when this cell is re-run.
# NOTE: some GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_2x2 = GCN(train_ds_2x2).to(device)
optimizer = torch.optim.Adam(model_2x2.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# Train for a fixed number of epochs, evaluating on the test split after each one.
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}: ",  end='')
    train(train_loader_2x2, model_2x2, loss_fn, optimizer, device)
    test(test_loader_2x2, model_2x2, loss_fn, device)
print("Done!")

Epoch 1: Accuracy: 0.605300, F-measure (micro): 0.605300, F-measure (macro): 0.576309, Avg loss: 1.166810
Epoch 2: Accuracy: 0.775700, F-measure (micro): 0.775700, F-measure (macro): 0.774489, Avg loss: 0.726260
Epoch 3: Accuracy: 0.792300, F-measure (micro): 0.792300, F-measure (macro): 0.785783, Avg loss: 0.642303
Epoch 4: Accuracy: 0.809500, F-measure (micro): 0.809500, F-measure (macro): 0.806453, Avg loss: 0.594961
Epoch 5: Accuracy: 0.837700, F-measure (micro): 0.837700, F-measure (macro): 0.834815, Avg loss: 0.511148
Epoch 6: Accuracy: 0.843900, F-measure (micro): 0.843900, F-measure (macro): 0.839978, Avg loss: 0.474454
Epoch 7: Accuracy: 0.862800, F-measure (micro): 0.862800, F-measure (macro): 0.860560, Avg loss: 0.434052
Epoch 8: Accuracy: 0.868100, F-measure (micro): 0.868100, F-measure (macro): 0.863894, Avg loss: 0.414250
Epoch 9: Accuracy: 0.880700, F-measure (micro): 0.880700, F-measure (macro): 0.878820, Avg loss: 0.368340
Epoch 10: Accuracy: 0.882100, F-measure (micro

With patch size 1x1

In [15]:
train_ds_1x1 = GraphMNIST((1,1), train=True)
test_ds_1x1  = GraphMNIST((1,1), train=False)
# Only the training graphs are reshuffled each epoch. Evaluation gains nothing
# from shuffling, and a fixed order keeps the average of per-batch losses
# reproducible between runs.
train_loader_1x1 = DataLoader(train_ds_1x1, batch_size=64, shuffle=True)
test_loader_1x1  = DataLoader(test_ds_1x1, batch_size=64)

print(f"Train dataset {train_ds_1x1}: {train_ds_1x1[0]}")
print(f"Test dataset  {test_ds_1x1}: {test_ds_1x1[0]}")

Train dataset GraphMNIST(60000): Data(x=[784, 1], edge_index=[2, 5940], y=[1])
Test dataset  GraphMNIST(10000): Data(x=[784, 1], edge_index=[2, 5940], y=[1])


In [16]:
# Seed PyTorch's random number generators so that weight initialisation (and
# any shuffling done by the loaders) is repeatable when this cell is re-run.
# NOTE: some GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_1x1 = GCN(train_ds_1x1).to(device)
optimizer = torch.optim.Adam(model_1x1.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# Train for a fixed number of epochs, evaluating on the test split after each one.
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}: ", end='')
    train(train_loader_1x1, model_1x1, loss_fn, optimizer, device)
    test(test_loader_1x1, model_1x1, loss_fn, device)
print("Done!")

Epoch 1: Accuracy: 0.326600, F-measure (micro): 0.326600, F-measure (macro): 0.273499, Avg loss: 1.936106
Epoch 2: Accuracy: 0.322600, F-measure (micro): 0.322600, F-measure (macro): 0.239144, Avg loss: 1.888491
Epoch 3: Accuracy: 0.349000, F-measure (micro): 0.349000, F-measure (macro): 0.290990, Avg loss: 1.801743
Epoch 4: Accuracy: 0.309800, F-measure (micro): 0.309800, F-measure (macro): 0.257716, Avg loss: 1.810245
Epoch 5: Accuracy: 0.348300, F-measure (micro): 0.348300, F-measure (macro): 0.290393, Avg loss: 1.734549
Epoch 6: Accuracy: 0.385900, F-measure (micro): 0.385900, F-measure (macro): 0.349604, Avg loss: 1.697276
Epoch 7: Accuracy: 0.401400, F-measure (micro): 0.401400, F-measure (macro): 0.360407, Avg loss: 1.642673
Epoch 8: Accuracy: 0.388600, F-measure (micro): 0.388600, F-measure (macro): 0.341335, Avg loss: 1.655311
Epoch 9: Accuracy: 0.400700, F-measure (micro): 0.400700, F-measure (macro): 0.377584, Avg loss: 1.626137
Epoch 10: Accuracy: 0.430300, F-measure (micro

KeyboardInterrupt: 