In [24]:
# Adapted from https://github.com/bknyaz/examples/blob/master/fc_vs_graph_train.py

In [25]:
# import section
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from scipy.spatial.distance import cdist

### Models (original code by bknyaz: https://github.com/bknyaz/examples/blob/master/fc_vs_graph_train.py)

In [20]:

class BorisNet(nn.Module):
    """Fully connected baseline: one bias-free linear map from 784 inputs to 10 logits.

    Each input sample is flattened (batch dimension kept) before the linear
    layer, so every sample must contain exactly 28 * 28 values.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10, bias=False)

    def forward(self, x):
        batch_size = x.shape[0]
        flat = x.view(batch_size, -1)
        return self.fc(flat)


class BorisConvNet(nn.Module):
    """Convolutional baseline: one large conv layer, 7x7 max-pooling, bias-free classifier.

    For a 28x28 single-channel input, the 28x28 kernel with padding 14 gives a
    29x29 map per channel; 7x7 max-pooling reduces it to 4x4 (floor(29 / 7)),
    hence 10 * 4 * 4 = 160 features feed the final linear layer.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=28,
                              stride=1, padding=14)
        self.fc = nn.Linear(10 * 4 * 4, 10, bias=False)

    def forward(self, x):
        pooled = F.max_pool2d(F.relu(self.conv(x)), 7)
        return self.fc(pooled.view(pooled.shape[0], -1))

class BorisGraphNet(nn.Module):
    """Single-layer graph network over image pixels.

    Every pixel is a graph node. The flattened input is first aggregated over
    neighbours with an ``N x N`` adjacency matrix ``A`` (``N = img_size ** 2``),
    and the aggregated features are mapped to 10 logits by a bias-free linear
    layer.

    Args:
        img_size: side length of the square input image.
        pred_edge: if False, ``A`` is a fixed buffer built by
            ``precompute_adjacency_images``. If True, ``A`` is predicted on every
            forward pass by a small MLP from the standardized coordinates of
            each node pair.
    """

    def __init__(self, img_size=28, pred_edge=False):
        super(BorisGraphNet, self).__init__()
        self.pred_edge = pred_edge
        N = img_size ** 2
        self.fc = nn.Linear(N, 10, bias=False)
        if pred_edge:
            col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
            coord = np.stack((col, row), axis=2).reshape(-1, 2)
            # Standardize each axis; the epsilon avoids division by zero when img_size == 1.
            coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
            coord = torch.from_numpy(coord).float()  # (N, 2)
            # Pairwise edge features: coord[i, j] = (coord_j, coord_i), shape (N, N, 4).
            coord = torch.cat((coord.unsqueeze(0).repeat(N, 1, 1),
                               coord.unsqueeze(1).repeat(1, N, 1)), dim=2)
            # Maps the 4 pairwise coordinates of an edge to one weight in (-1, 1).
            self.pred_edge_fc = nn.Sequential(nn.Linear(4, 64),
                                              nn.ReLU(),
                                              nn.Linear(64, 1),
                                              nn.Tanh())
            self.register_buffer('coord', coord)
        else:
            # Precompute the adjacency matrix once, before training.
            A = self.precompute_adjacency_images(img_size)
            self.register_buffer('A', A)

    @staticmethod
    def precompute_adjacency_images(img_size):
        """Build the fixed, normalized pixel adjacency matrix.

        Prints a warning about the non-squared distance and the top-left
        10x10 corner of the result.

        Args:
            img_size: side length of the square image.

        Returns:
            float32 tensor of shape (img_size ** 2, img_size ** 2).
        """
        col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
        # Pixel coordinates scaled by the image size, shape (N, 2).
        coord = np.stack((col, row), axis=2).reshape(-1, 2) / img_size
        dist = cdist(coord, coord)  # pairwise Euclidean distances, (N, N)
        sigma = 0.05 * np.pi

        # NOTE(upstream): dist is not squared, so this is exp(-d / sigma^2)
        # rather than a Gaussian kernel. Kept as-is to reproduce the original
        # results; the effect of squaring on accuracy is unknown.
        A = np.exp(- dist / sigma ** 2)
        print('WARNING: try squaring the dist to make it a Gaussian')

        # Sparsify: drop very weak edges.
        A[A < 0.01] = 0
        A = torch.from_numpy(A).float()

        # Symmetric normalization D^-1/2 A D^-1/2 (Kipf & Welling, ICLR 2017).
        D = A.sum(1)  # node degrees, (N,)
        D_hat = (D + 1e-5) ** (-0.5)
        A_hat = D_hat.view(-1, 1) * A * D_hat.view(1, -1)  # (N, N)

        # Empirical trick from the original author: shift every remaining
        # (non-negligible) edge weight down by 0.2.
        edge_mask = A_hat > 0.0001
        A_hat[edge_mask] -= 0.2

        print(A_hat[:10, :10])
        return A_hat

    def forward(self, x):
        """Aggregate pixel features over the graph, then classify.

        Args:
            x: input batch whose samples each flatten to N = img_size ** 2
               values (e.g. shape (B, 1, img_size, img_size)).

        Returns:
            logits of shape (B, 10).
        """
        B = x.size(0)
        if self.pred_edge:
            # (N, N, 4) -> (N, N, 1) -> (N, N). squeeze(-1) rather than squeeze()
            # keeps A two-dimensional even when N == 1.
            A = self.pred_edge_fc(self.coord).squeeze(-1)
            # Keep exposing the latest predicted adjacency as `self.A`, but
            # detached: storing the live tensor would keep the whole edge-MLP
            # autograd graph (incl. its (N, N, 64) activations) alive on the module.
            self.A = A.detach()
        else:
            A = self.A
        # out[b, i] = sum_j A[i, j] * x[b, j]: one matrix product instead of
        # a batched mat-vec over B expanded copies of A.
        avg_neighbor_features = x.view(B, -1) @ A.t()
        return self.fc(avg_neighbor_features)


def train(model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch with cross-entropy loss.

    Args:
        model: module mapping a batch of inputs to class logits.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of (inputs, labels) batches exposing
            ``len()`` and a ``.dataset`` with ``len()``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epoch: epoch number, used only in the progress log.

    Prints the current batch loss every 200 batches (including the first).
    """
    log_every = 200
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if step % log_every != 0:
            continue
        seen = step * len(inputs)
        progress = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), progress, loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

### Training FCN, CNN, GCN on MNIST

In [21]:
# Global settings. Use the GPU only when one is actually available, so the
# notebook also runs on CPU-only machines (a hard-coded True crashes there).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Extra loader worker and pinned host memory only pay off when batches are copied to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
torch.manual_seed(1)

# Shared preprocessing for both splits; the Normalize constants are presumably
# the MNIST mean/std used by the upstream example.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Train and test loaders. download=True on both splits so neither depends on
# the other having been created first.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=64, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=1000, shuffle=False, **kwargs)

In [23]:
# Fully connected baseline: SGD with lr=1e-3, weight_decay=1e-4, trained for 5 epochs.
model_FCN = BorisNet().to(device)
print(model_FCN)
optimizer = optim.SGD(model_FCN.parameters(), lr=1e-3, weight_decay=1e-4)
n_trainable = sum(p.numel() for p in model_FCN.parameters() if p.requires_grad)
print('number of trainable parameters: %d' % n_trainable)

for epoch in range(1, 5 + 1):  # epochs 1..5
    train(model_FCN, device, train_loader, optimizer, epoch)
    test(model_FCN, device, test_loader)

BorisNet(
  (fc): Linear(in_features=784, out_features=10, bias=False)
)
number of trainable parameters: 7840

Test set: Average loss: 0.5874, Accuracy: 8653/10000 (87%)


Test set: Average loss: 0.4666, Accuracy: 8847/10000 (88%)


Test set: Average loss: 0.4191, Accuracy: 8905/10000 (89%)


Test set: Average loss: 0.3914, Accuracy: 8948/10000 (89%)


Test set: Average loss: 0.3732, Accuracy: 8980/10000 (90%)



In [26]:
# Convolutional model: SGD with lr=1e-3, weight_decay=1e-1, trained for 5 epochs.
model_CNN = BorisConvNet().to(device)
print(model_CNN)
optimizer = optim.SGD(model_CNN.parameters(), lr=1e-3, weight_decay=1e-1)
n_trainable = sum(p.numel() for p in model_CNN.parameters() if p.requires_grad)
print('number of trainable parameters: %d' % n_trainable)

for epoch in range(1, 5 + 1):  # epochs 1..5
    train(model_CNN, device, train_loader, optimizer, epoch)
    test(model_CNN, device, test_loader)

BorisConvNet(
  (conv): Conv2d(1, 10, kernel_size=(28, 28), stride=(1, 1), padding=(14, 14))
  (fc): Linear(in_features=160, out_features=10, bias=False)
)
number of trainable parameters: 9450

Test set: Average loss: 1.0589, Accuracy: 8547/10000 (85%)


Test set: Average loss: 0.5599, Accuracy: 9036/10000 (90%)


Test set: Average loss: 0.4085, Accuracy: 9238/10000 (92%)


Test set: Average loss: 0.3432, Accuracy: 9344/10000 (93%)


Test set: Average loss: 0.3083, Accuracy: 9407/10000 (94%)



In [27]:
# Graph model with a fixed, precomputed adjacency matrix:
# SGD with lr=1e-3, weight_decay=1e-4, trained for 5 epochs.
model_GCN = BorisGraphNet(pred_edge=False).to(device)
print(model_GCN)
optimizer = optim.SGD(model_GCN.parameters(), lr=1e-3, weight_decay=1e-4)
n_trainable = sum(p.numel() for p in model_GCN.parameters() if p.requires_grad)
print('number of trainable parameters: %d' % n_trainable)

for epoch in range(1, 5 + 1):  # epochs 1..5
    train(model_GCN, device, train_loader, optimizer, epoch)
    test(model_GCN, device, test_loader)

tensor([[ 0.3400, -0.0852, -0.1736, -0.1938,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.0852,  0.2413, -0.0987, -0.1763, -0.1944,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1736, -0.0987,  0.2207, -0.1015, -0.1768, -0.1946,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1938, -0.1763, -0.1015,  0.2166, -0.1020, -0.1770, -0.1946,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, -0.1944, -0.1768, -0.1020,  0.2166, -0.1020, -0.1770, -0.1946,
          0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1946, -0.1770, -0.1020,  0.2166, -0.1020, -0.1770,
         -0.1946,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.1946, -0.1770, -0.1020,  0.2166, -0.1020,
         -0.1770, -0.1946],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.1946, -0.1770, -0.1020,  0.2166,
         -0.1020, -0.1770],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1946, -0.1770, -0.1020,
          0.2166, -0.1020],
        [ 0.0000,  