In [None]:
import torch_geometric.datasets
import torch
import numml.sparse as sp
import torch.nn as tNN
import numml.nn as nNN
import matplotlib.pyplot as plt

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# A few helper functions to convert the graph data into the formats used below
def tg_data_to_spmatrix(data):
    """Convert a PyG-style graph object into a square numml sparse matrix.

    Args:
        data: graph object exposing ``edge_index`` (row indices in
            ``edge_index[0]``, column indices in ``edge_index[1]``, one column
            per edge), an optional ``edge_attr`` holding edge values, and ``x``
            (only ``x.shape[0]`` is used, as the number of nodes).

    Returns:
        ``sp.SparseCSRTensor`` of shape ``(num_nodes, num_nodes)``. It holds
        ``edge_attr`` as the entry values, or 1 for every edge when the graph
        has no edge attributes.
    """
    edge_index = data.edge_index
    if data.edge_attr is not None:
        # NOTE(review): assumes edge_attr holds one scalar value per edge — confirm
        mat_data = data.edge_attr
    else:
        # Unweighted graph: every edge gets value 1. Allocate on the same device
        # as the indices so values and indices never live on different devices.
        mat_data = torch.ones(edge_index.shape[1], device=edge_index.device)
    num_nodes = data.x.shape[0]
    return sp.SparseCSRTensor((mat_data, (edge_index[0], edge_index[1])), (num_nodes, num_nodes))

def class_to_onehot(y, num_classes=None):
    """One-hot encode integer class labels.

    Column ``c`` corresponds to label value ``c``, so
    ``torch.argmax(class_to_onehot(y), dim=1) == y``.

    Args:
        y: 1-D tensor of non-negative integer class labels.
        num_classes: number of output columns. Defaults to ``max(y) + 1``
            (0 for an empty ``y``).

    Returns:
        Tensor of the default float dtype with shape ``(len(y), num_classes)``,
        on ``y.device``.
    """
    n = len(y)
    if num_classes is None:
        # Labels index columns directly, so the width must be max(y) + 1.
        # Using max - min + 1 under-counts (and indexes out of bounds)
        # whenever the smallest label present is not 0.
        num_classes = int(torch.max(y).item()) + 1 if n > 0 else 0
    z = torch.zeros(n, num_classes, device=y.device)
    if n > 0:
        z[torch.arange(n, device=y.device), y] = 1.
    return z

def restriction_mat(b):
    """Build the sparse restriction (row-selection) operator for a boolean mask.

    Row ``k`` of the result has a single 1 in the column of the ``k``-th True
    entry of ``b``. Hence ``R @ v`` picks out the masked entries of ``v``, and
    ``R @ A @ R.T`` is the submatrix of ``A`` on the masked indices.

    Args:
        b: 1-D boolean mask of length N.

    Returns:
        numml ``SparseCSRTensor`` of shape ``(N_r, N)`` on the module-level
        ``device``, where ``N_r`` is the number of True entries in ``b``.
    """
    N = len(b)
    # Plain Python int, so sizes and the sparse shape tuple are not 0-d tensors.
    N_r = int(torch.sum(b).item())

    rows = torch.arange(N_r, device=device)
    cols = torch.where(b)[0].to(device)
    vals = torch.ones(N_r, device=device)

    return sp.SparseCSRTensor((vals, (rows, cols)), shape=(N_r, N)).to(device)

In [None]:
# Load CiteSeer and convert it into the tensors used below.
# Planetoid holds a single graph; index it instead of reaching into the
# dataset's internal `.data` storage attribute.
dataset = torch_geometric.datasets.Planetoid(root='/tmp/citeseer', name='CiteSeer')[0]
G = tg_data_to_spmatrix(dataset).to(device)
y = dataset.y.to(device)
y_oh = class_to_onehot(y)
x = dataset.x.to(device)

# Create a random train/test split: each node lands in the training set with
# probability `train_p`. Seeding the global RNG makes the split (and the network
# initialisation in a later cell) reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

N = dataset.x.shape[0]
train_p = 0.8
train_mask = torch.bernoulli(torch.full((N,), train_p)).bool()
test_mask = torch.logical_not(train_mask)
assert train_mask.any() and test_mask.any(), "train/test split produced an empty set"

y_tr = y[train_mask]
y_te = y[test_mask]
y_oh_tr = y_oh[train_mask]
y_oh_te = y_oh[test_mask]
x_tr = x[train_mask]
x_te = x[test_mask]

# Sparse selectors mapping the full node set onto the train / test nodes.
R_tr = restriction_mat(train_mask)
R_te = restriction_mat(test_mask)

In [None]:
class Network(tNN.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        in_layers: input feature dimension per node (``x.shape[1]`` at the call site).
        out_layers: number of output scores per node (the number of classes).
        H: width of the hidden layer.
    """

    def __init__(self, in_layers, out_layers, H):
        super().__init__()

        self.conv1 = nNN.GCNConv(in_layers, H, normalize=True)
        self.conv2 = nNN.GCNConv(H, out_layers, normalize=True)

    def forward(self, A, X):
        """Return unnormalised class scores (logits), one row per node.

        ``torch.nn.CrossEntropyLoss`` applies log-softmax internally, so no
        output activation is applied here. A sigmoid in front of the loss would
        squash every logit into (0, 1), which caps how confident the softmax
        can become and weakens the gradient. Dropping that sigmoid leaves
        ``argmax`` predictions unchanged, because sigmoid is monotonic.

        Args:
            A: sparse adjacency matrix of the graph.
            X: node feature matrix, one row per node.
        """
        X = torch.relu(self.conv1(A, X))
        return self.conv2(A, X)

In [None]:
# Model, loss and optimiser: the network's outputs are scored with
# cross-entropy against the one-hot class targets.
HIDDEN_DIM = 100
LEARNING_RATE = 0.1

net = Network(in_layers=x.shape[1], out_layers=y_oh.shape[1], H=HIDDEN_DIM).to(device)
loss = tNN.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for N_e epochs on the training subgraph, recording the training and
# testing loss after every epoch.
N_e = 100

# Restrict the full adjacency to the train / test node sets; edges that cross
# between the two sets are not part of either submatrix.
G_tr = R_tr @ G @ R_tr.T
G_te = R_te @ G @ R_te.T

lh_tr, lh_te = [], []

for epoch in range(N_e):
    opt.zero_grad()
    yhat = net(G_tr, x_tr)
    l = loss(yhat, y_oh_tr)
    l.backward()
    opt.step()

    train_value = l.item()
    lh_tr.append(train_value)

    # Evaluation only: no gradients needed for the held-out loss.
    with torch.no_grad():
        test_value = loss(net(G_te, x_te), y_oh_te).item()
    lh_te.append(test_value)

    is_report_epoch = (epoch % 10 == 0) or (epoch == N_e - 1)
    if is_report_epoch:
        print(epoch, train_value)

In [None]:
# Training vs. testing loss per epoch, on a log scale.
fig, ax = plt.subplots()
ax.semilogy(lh_tr, label='Training loss')
ax.semilogy(lh_te, label='Testing loss')
ax.set(xlabel='Epoch', ylabel='Cross-entropy loss', title='GCN on CiteSeer: loss curves')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Classification accuracy on the full graph, and on the train / test subgraphs
# (the same restricted adjacencies G_tr / G_te the training loop used).
def accuracy(A, X, labels):
    """Fraction of nodes whose argmax prediction from ``net`` equals ``labels``.

    Runs under ``torch.no_grad()``, since this is inference only and no
    autograd graph is needed.
    """
    with torch.no_grad():
        pred = torch.argmax(net(A, X), dim=1)
    return torch.mean((pred == labels).float()).item()


print('accuracy (overall)', accuracy(G, x, y))
print('accuracy (train)', accuracy(G_tr, x_tr, y_tr))
print('accuracy (test)', accuracy(G_te, x_te, y_te))