In [221]:
import numpy as np
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [222]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data
import torch.nn.functional as F
import torch_sparse
from torchdiffeq import odeint
# from torchdiffeq import odeint_adjoint as odeint # Might be more stable according to docs of torchdiffeq
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

import time

from torch_geometric.datasets import Planetoid
from pathlib import Path


In [223]:
# Set torch device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

## Get Dataset

In [224]:
dataset_dir = Path('data')
dataset_dir = dataset_dir.absolute()
if not dataset_dir.exists():
    dataset_dir.mkdir(parents=True)

dataset = Planetoid(dataset_dir, 'Cora')

# Some info
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset.data
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Number of validation nodes: {data.val_mask.sum()}')
print(f'Number of test nodes: {data.test_mask.sum()}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Number of graphs: 1
Number of features: 1433
Number of classes: 7
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


In [225]:
# # Plot the graph
# G = nx.Graph()
#
# G.add_nodes_from(list(range(data.num_nodes)))
# G.add_edges_from([tuple(x) for x in data.edge_index.T.tolist()])
# nx.draw(G)
# plt.gca().set_facecolor('white')

In [226]:
# Train the given model on the given graph for num_epochs
def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    x = data.x
    y = data.y.squeeze()

    # Set up the loss and the optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    out = model(x)

    loss = loss_fn(out[data.train_mask], y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [227]:
class ODELayer(MessagePassing):
    def __init__(self, in_features, out_features, device, opt):
        super(ODELayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.opt = opt
        self.device = device
        self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
        self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
        self.alpha_train = nn.Parameter(torch.tensor(0.0))
        self.beta_train = nn.Parameter(torch.tensor(0.0))
        self.x0 = None
        self.nfe = 0
        self.alpha_sc = nn.Parameter(torch.ones(1))
        self.beta_sc = nn.Parameter(torch.ones(1))

    def forward(self, t, x):  # the t param is needed by the ODE solver.
        ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
        alpha = torch.sigmoid(self.alpha_train)
        f = alpha * (ax - x) # What is happening here?
        f = f + self.beta_train * self.x0
        return f

In [228]:
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter_add
def get_rw_adj(edge_index, edge_weight=None, norm_dim=1, fill_value=0., num_nodes=None, dtype=None):
  num_nodes = maybe_num_nodes(edge_index, num_nodes)

  if edge_weight is None:
    edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                             device=edge_index.device)

  if not fill_value == 0:
    edge_index, tmp_edge_weight = add_remaining_self_loops(
      edge_index, edge_weight, fill_value, num_nodes)
    assert tmp_edge_weight is not None
    edge_weight = tmp_edge_weight

  row, col = edge_index[0], edge_index[1]
  indices = row if norm_dim == 0 else col
  deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes)
  deg_inv_sqrt = deg.pow_(-1)
  edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]
  return edge_index, edge_weight

In [229]:
class ConstantODEblock(nn.Module):
  def __init__(self, odefunc, opt, data, device, t):
    super(ConstantODEblock, self).__init__()

    self.opt = opt
    self.t = t
    self.odefunc = odefunc(opt['hidden_dim'], opt['hidden_dim'], device, opt)
    # edge_index = data.edge_index # Note these are normalized in the original code (why?)
    # edge_weight = torch.ones((edge_index.size(1),), dtype=data.x.dtype, device=device) # Note these are normalized in the original code (why?)

    edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                                           fill_value=1,
                                                           num_nodes=data.num_nodes,
                                                           dtype=data.x.dtype)

    self.odefunc.edge_index = edge_index.to(device)
    self.odefunc.edge_weight = edge_weight.to(device)

    self.train_integrator = odeint
    self.test_integrator = odeint
    self.atol = 1e-7 # Not necessary for euler
    self.rtol = 1e-9 # Not necessary for euler

  def forward(self, x):
    t = self.t.type_as(x)

    integrator = self.train_integrator if self.training else self.test_integrator

    func = self.odefunc
    state = x
    state_dt = integrator(
        func, state, t,
        method='dopri5',
        options=dict(step_size=1, max_iters=100),
        atol=self.atol,
        rtol=self.rtol)

    z = state_dt[1]
    return z

  def set_x0(self, x0):
    self.odefunc.x0 = x0.clone().detach()

  def set_time(self, time):
    self.t = torch.tensor([0, time]).to(self.device)

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"

In [230]:
class GNN(MessagePassing):
  def __init__(self, opt, dataset, device=torch.device('cpu')):
    super(GNN, self).__init__()
    self.opt = opt
    self.T = opt['time']
    time_tensor = torch.tensor([0, self.T]).to(device)
    self.num_classes = dataset.num_classes
    self.num_features = dataset.data.num_features
    self.num_nodes = dataset.data.num_nodes
    self.device = device

    self.f = ODELayer
    self.odeblock = ConstantODEblock(self.f, opt, dataset.data, device, t=time_tensor).to(device)

    self.m1 = nn.Linear(self.num_features, opt['hidden_dim'])
    self.hidden_dim = opt['hidden_dim']
    self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes)

  def getNFE(self):
    return self.odeblock.odefunc.nfe

  def resetNFE(self):
    self.odeblock.odefunc.nfe = 0

  def reset(self):
    self.m1.reset_parameters()
    self.m2.reset_parameters()

  def forward(self, x):
    # Encode each node based on its feature.
    # x = F.dropout(x, self.opt['input_dropout'], training=self.training)
    x = self.m1(x)

    self.odeblock.set_x0(x)
    z = self.odeblock(x)

    # Activation.
    z = F.relu(z)

    # We can make it deep
    # for i in range(5):
    #     self.odeblock.set_x0(z)
    #     z = self.odeblock(z)
    #     z = F.relu(z)

    # Dropout.
    # z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    z = self.m2(z)
    return z

  def __repr__(self):
    return self.__class__.__name__


In [231]:
def print_model_params(model):
  print(model)
  for name, param in model.named_parameters():
    if param.requires_grad:
      print(name)
      print(param.data.shape)

In [232]:
# def set_train_val_test_split(
#         seed: int,
#         data: Data,
#         num_development: int = 1500,
#         num_per_class: int = 20) -> Data:
#   rnd_state = np.random.RandomState(seed)
#   num_nodes = data.y.shape[0]
#   development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
#   test_idx = [i for i in np.arange(num_nodes) if i not in development_idx]
#
#   train_idx = []
#   rnd_state = np.random.RandomState(seed)
#   for c in range(data.y.max() + 1):
#     class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
#     train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))
#
#   val_idx = [i for i in development_idx if i not in train_idx]
#
#   def get_mask(idx):
#     mask = torch.zeros(num_nodes, dtype=torch.bool)
#     mask[idx] = 1
#     return mask
#
#   data.train_mask = get_mask(train_idx)
#   data.val_mask = get_mask(val_idx)
#   data.test_mask = get_mask(test_idx)
#
#   return data

In [233]:
@torch.no_grad()
def test(model, data, opt=None):  # opt required for runtime polymorphism
  model.eval()
  logits, accs = model(data.x), []
  for _, mask in data('train_mask', 'val_mask', 'test_mask'):
    pred = logits[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    accs.append(acc)
  return accs

In [234]:
opt = {
    'hidden_dim': 80,
    'time': 18.294,
    'epoch': 100,
    'input_dropout': 0,
    'droput': 0,
}


model = GNN(opt, dataset, device).to(device)
# dataset.data = set_train_val_test_split(np.random.randint(0, 1000), dataset.data, 1500)
data = dataset.data.to(device)
parameters = [p for p in model.parameters() if p.requires_grad]
# print_model_params(model)
optimizer = optim.Adam(model.parameters(), lr=0.01)
best_time = best_epoch = train_acc = val_acc = test_acc = 0

for epoch in range(1, opt['epoch']):
    start_time = time.time()

    loss = train(model, optimizer, data)
    tmp_train_acc, tmp_val_acc, tmp_test_acc = test(model, data, opt)

    best_time = opt['time']
    if tmp_val_acc > val_acc:
      best_epoch = epoch
      train_acc = tmp_train_acc
      val_acc = tmp_val_acc
      test_acc = tmp_test_acc
      best_time = opt['time']

    log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Best time: {:.4f}'
    print(log.format(epoch, time.time() - start_time, loss, train_acc, val_acc, test_acc, best_time))
    print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d} and best time {:03f}'.format(val_acc, test_acc,
                                                                                                     best_epoch,
                                                                                                     best_time))
# train_acc, val_acc, test_acc



Epoch: 001, Runtime 1.281130, Loss 1.949133, Train: 0.5000, Val: 0.3480, Test: 0.3140, Best time: 18.2940
best val accuracy 0.348000 with test accuracy 0.314000 at epoch 1 and best time 18.294000
Epoch: 002, Runtime 1.368799, Loss 1.876029, Train: 0.8357, Val: 0.6780, Test: 0.6790, Best time: 18.2940
best val accuracy 0.678000 with test accuracy 0.679000 at epoch 2 and best time 18.294000
Epoch: 003, Runtime 1.048199, Loss 1.706633, Train: 0.8857, Val: 0.7140, Test: 0.7190, Best time: 18.2940
best val accuracy 0.714000 with test accuracy 0.719000 at epoch 3 and best time 18.294000
Epoch: 004, Runtime 0.853470, Loss 1.434530, Train: 0.9000, Val: 0.7840, Test: 0.7760, Best time: 18.2940
best val accuracy 0.784000 with test accuracy 0.776000 at epoch 4 and best time 18.294000
Epoch: 005, Runtime 0.882504, Loss 1.118515, Train: 0.9000, Val: 0.7840, Test: 0.7760, Best time: 18.2940
best val accuracy 0.784000 with test accuracy 0.776000 at epoch 4 and best time 18.294000
Epoch: 006, Runtime 