In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import vjp
from torch_geometric.datasets import Planetoid
import time

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

The GDE network makes use of the generic GCN layer to compute the 'head' and 'tail' layers, since the dimensions each hidden state cannot change within the ODE network. 

In [2]:
# Generic GCN for head and tail layers
class GCN(nn.Module):
    def __init__(self, A, in_features, out_features, bias=False):
        super(GCN, self).__init__()
        
        self.A = A
        self.in_features = in_features
        self.out_features = out_features
        
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ',' \
               + str(self.out_features) + ')'
    
    def reset_parameters(self):
        stdv = 1. / self.weight.size(1) ** 1/2
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
    
    # H, feature matrix
    # A, precomputed adj matrix
    def forward(self, H):
        n = torch.mm(self.A, torch.mm(H, self.weight))
        if self.bias is not None:
            return n + self.bias
        else:
            return n

This is a generic GCN block with a dyanmic number of layers. Note that when using, we wish to keep this number small (probably 1), since the ODE solver will effectively increase the number of layers here.

In [3]:
# GCN Block for body layers
class Block(nn.Module):
    def __init__(self, A, features, activation, num_layers):
        super(Block, self).__init__()
        self.features = features
        self.activation = activation
        self.num_layers = num_layers
        self.A = A

    def forward(self, x, t, net_params):
        weights = net_params.view(self.num_layers, self.features, self.features)
        
        x = x.view(self.A.size(1), self.features)
        for i in range(self.num_layers):
            x = self.A.mm(x).mm(weights[i,:,:])
            x = self.activation(x)

        return x

    def num_params(self):
        return self.features * self.features * self.num_layers

In [4]:
RK4 = ((  0,),
       (1/2, 1/2,),
       (1/2,   0,  1/2,),
       (  1,   0,    0,   1,),
       (1/6, 1/3, 1/3, 1/6,))

EF = ((0,),
      (1,))

"""
General Runge-Kutta Solver.
https://en.wikipedia.org/wiki/Runge–Kutta_methods

b_tableau, nested tuple, contains weights of integration. 
f, function, is the function to iterate. Should only be a function of x, t. 
x0, torch.Tensor, is the intial condition.
t0, torch.Tensor, is the start time of integration.
t1, torch.Tensor, is the end time of integration.
N, int, is the desired number of timesteps.

returns x, torch.Tensor, estimated solution of dy/dx = f(x,t) at time t1. 
"""

def explicit_RK(b_tableau, f, x0, t0, t1, N):
    h = (t1 - t0) / float(N) # calculate step size
    x = x0 # initialize saved dynamics
    mesh = (t0 + h * i for i in range(N)) # generator of time values
    for time in mesh:
        
        k = [f(x, time + h*b_tableau[0][0])] # Covers the first row of the Butcher tableau. 
        for i, row in enumerate(b_tableau[1:-1]): # Covers the middle rows of the Butcher tableau.
            k.append(f(x + sum(w * k[idx] * h for idx, w in enumerate(row[1:])), time + row[0] * h)) # calculate k's. 
        x = x + sum(w * k_i * h for k_i, w in zip(k, b_tableau[-1])) # calculate timestep 
    return x

In [5]:
# Convenience tuple -> tensor function
def flatten(*args):
    return(torch.cat(tuple(torch.flatten(arg) for arg in args), dim=0).view(1,-1))

# Convenience tensor -> tuple function
def unflatten(x, n_e, sizes):
    return (x[0, 0:n_e[0]].view(sizes[0]),
            x[0, n_e[0]:n_e[0] + n_e[1]].view(sizes[1]),
            x[0, (n_e[0] + n_e[1]):(n_e[0] + n_e[1] + n_e[2])].view(sizes[2]),
            x[0, (n_e[0] + n_e[1] + n_e[2]):].view(sizes[3]),
            )

class Integrate(torch.autograd.Function):
    def __deepcopy__(self, memo):
        return Integrate(copy.deepcopy(memo))
    
    @staticmethod
    def forward(ctx, Integrator, f, x0, t0, t1, N, net_params):
        solution = Integrator(EF, lambda x, t: f(x, t, net_params), x0, t0, t1, N)

        # Save for jacobian calculations in backward()
        ctx.save_for_backward(x0,t0,t1,net_params)
        ctx.solution = solution
        ctx.Integrator = Integrator
        ctx.N = N
        ctx.f = f
        
        return solution
    
    @staticmethod
    def backward(ctx, dL_dz1):
        # Get all saved context
        z0, t0, t1, net_params = ctx.saved_tensors
        z1 = ctx.solution
        N = ctx.N
        f = ctx.f
        
        # Convenience sizes
        batch_size = z0.size()[0]
        img_len = int(z0.numel() / batch_size)

        # Compute derivative w.r.t. to end time of integration
        dL_dt1 = dL_dz1.view(batch_size,1,-1).bmm(f(z1, t1, net_params).view(batch_size,-1,1))  # Derivative of loss w.r.t t1
        
        #print("dL_dt1", dL_dt1)
        
        # Initial Condition
        num_elements = (z1.numel(), dL_dz1.numel(), batch_size * net_params.numel(), dL_dt1.numel())
        sizes = (z1.size(), dL_dz1.size(), (batch_size, net_params.numel()), dL_dt1.size())
        s0 = flatten(z1, dL_dz1, torch.zeros((batch_size, net_params.numel()), dtype=torch.float32, device=z1.device), -dL_dt1) # initial augmented state
    
        # augmented dynamics function
        def aug_dynamics(s, t, theta):
            s = unflatten(s, num_elements, sizes)
            
            with torch.enable_grad(): 
#                 gradients = [vjp(f, 
#                                  (s[0][i].unsqueeze(0), t, theta), 
#                                   v=-s[1][i].unsqueeze(0),
#                                  )[1] for i in range(batch_size)]
                gradients = vjp(f, 
                                 (s[0], t, theta), 
                                  v=-s[1],
                                 )[1]
                
            return flatten(f(s[0],t,theta),
                    torch.cat([gradients[0]], dim=0), 
                    torch.cat([gradients[2].unsqueeze(0) for i in range(batch_size)], dim=0), 
                    torch.cat([gradients[1].reshape(1,1) for i in range(batch_size)], dim=0),
                   )#.unsqueeze(2)
#             return flatten(f(s[0],t,theta),
#                     torch.cat([gradient[0] for gradient in gradients], dim=0), 
#                     torch.cat([gradient[2].unsqueeze(0) for gradient in gradients], dim=0), 
#                     torch.cat([gradient[1].reshape(1,1) for gradient in gradients], dim=0),
#                    )#.unsqueeze(2)

        # Integrate backwards
        with torch.enable_grad():
            s = ctx.Integrator(EF, lambda x, t: aug_dynamics(x, t, net_params), s0, t1, t0, N)
        
        # Extract derivatives
        _, dL_dz0, dL_dtheta, dL_dt0 = unflatten(s, num_elements, sizes)
                
        # must return something for every input to forward, None for non-tensors
        return None, None, dL_dz0, dL_dt0, dL_dt1, None, dL_dtheta 

The ODENet module instantiates the ODE network. You can think of it as replacing the GCNNet in the traditional non-ODE version of the network.

In [6]:
class ODENet(nn.Module):
    def __init__(self, solver, f, in_channels, body_channels, out_channels, hidden_layers, A, solver_params):
        super(ODENet, self).__init__()
        
        # Graph Laplacian
        self.A = A
        
        # Controls amount of parameters
        self.body_channels = body_channels
        self.f = f(A, body_channels, F.relu, hidden_layers)
        
        # Head
        self.head = GCN(A, in_channels, body_channels)
        
        # Body
        self.int_f = solver
        self.Integrate = Integrate
        self.solver_params = solver_params
        self.N = solver_params["N"]
        self.h = (solver_params["t1"] - solver_params["t0"]) / solver_params["N"]
        self.t0 = torch.tensor(float(solver_params["t0"]), requires_grad=True)
        self.t1 = torch.tensor(float(solver_params["t1"]), requires_grad=True)
        
        self.net_params = torch.nn.parameter.Parameter(torch.Tensor(self.f.num_params()).normal_(mean=0, std=0.1,generator=None), requires_grad=True)

        # Tail
        self.tail = GCN(A, body_channels, out_channels)
    
    def _apply(self, fn):
        super(ODENet, self)._apply(fn)
        self.t0 = fn(self.t0)
        self.t1 = fn(self.t1)
        return self
    
    def forward(self, x):
        x = F.relu(self.head(x))
        x = self.Integrate.apply(self.int_f, self.f, x, self.t0, self.t1, self.N, self.net_params) # Vanilla RK4
        x = self.tail(x)
        return x

In [7]:
def create_A(data):
    adj = torch.eye(data.num_nodes, data.num_nodes)
    adj[data.edge_index[0,:], data.edge_index[1,:]] += 1
    deg = adj.sum(dim=1) ** (-1/2)
    D = torch.diag(deg)
    return D.mm(adj).mm(D)

In [8]:
def masked_accuracy(pred, labels, mask):
    return (pred.argmax(dim=1) == labels)[mask].sum().item() / mask.sum().item()

In [9]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

A = create_A(data)

train_labels = torch.where(data.train_mask, data.y, torch.tensor(-100))
val_labels = torch.where(data.val_mask, data.y, torch.tensor(-100))
test_labels = torch.where(data.test_mask, data.y, torch.tensor(-100))

print('training samples: ', data.train_mask.sum().item())
print('validation samples: ', data.val_mask.sum().item())
print('test samples: ', data.test_mask.sum().item())

A = A.to(device)
input_features = data.x.to(device)
train_labels = train_labels.to(device)
val_labels = val_labels.to(device)
test_labels = test_labels.to(device)
train_mask = data.train_mask.to(device)
val_mask = data.val_mask.to(device)
test_mask = data.test_mask.to(device)

training samples:  140
validation samples:  500
test samples:  1000


In [10]:
parameters = {
    'features': dataset.num_features,
    'body': 64,
    'classes': dataset.num_classes,
    'num_epochs': 100,
    'learning_rate': 1e-2,
    'weight_decay': 5e-3
}

solver_params = {
    "t0": 0,
    "t1": 1,
    "N": 5,
}

In [15]:
torch.manual_seed(0)

num_epochs = parameters['num_epochs']
model = ODENet(explicit_RK, Block, parameters['features'], parameters['body'], parameters['classes'], 2, A, solver_params).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=parameters['learning_rate'], weight_decay=parameters['weight_decay'])
criterion = nn.CrossEntropyLoss()


for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    train_pred = model(input_features)
    train_loss = criterion(train_pred, train_labels)
    train_loss.backward()
    optimizer.step()
    train_acc = masked_accuracy(train_pred, train_labels, train_mask)
    
    model.eval()
    val_pred = model(input_features)
    val_loss = criterion(val_pred, val_labels)
    val_acc = masked_accuracy(val_pred, val_labels, val_mask)
    print("{}: \ttrain loss {}\tacc {:2f}\tval loss {}\tacc {:2f}".format(epoch, train_loss, train_acc, val_loss, val_acc))

TypeError: 'int' object is not iterable

In [12]:
model.eval()
test_pred = model(input_features)
test_acc = masked_accuracy(test_pred, test_labels, test_mask)
print(test_acc)

0.784


This is where the hyperparameters for training are controlled, and the solver parameters. 