In [1]:
import copy
from time import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from progressbar import progressbar

# Install the torchdiffeq package; it is imported further down in this notebook
# (odeint / odeint_adjoint) for the adjoint-based ODE block.
!pip install torchdiffeq



### Load MNIST dataset

In [0]:
# Every image goes through ToTensor before it reaches a model.
transform = transforms.Compose([transforms.ToTensor()])

# Training split first, then test split; both are stored under ./data and
# downloaded on first use.
mnist_train, mnist_test = [
    torchvision.datasets.MNIST(root='data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
]

In [0]:
# Mini-batches of 32 images; only the training stream is reshuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_train,
    batch_size=32,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=mnist_test,
    batch_size=32,
    shuffle=False,
)

### Training Function

In [0]:
loss_fn = nn.CrossEntropyLoss()

def train(model, criterion = loss_fn, epochs=10, device=torch.device('cuda'),
              train_loader=train_loader, test_loader=test_loader, flatten_input=False):
    '''Train `model` with Adam and evaluate it on the test set after every epoch.

        Args: - model is the nn.Module to optimise (moved to `device`)
              - criterion is called as criterion(out, y); defaults to cross-entropy
              - epochs is the number of passes over train_loader
              - device is the device (or device string) used for model and batches
              - train_loader / test_loader yield (X, y) batches; the defaults are
                the loader objects that exist when this function is defined
              - flatten_input reshapes each batch to (-1, 1, 784) before the
                forward pass (used for the feed-forward network)

        Returns: (accs, times), two lists with one entry per epoch: the test
                 accuracy as a Python float, and the wall-clock seconds spent on
                 that epoch (training plus evaluation).'''
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())

    accs = []
    times = []
    for epoch in range(epochs):
        t1 = time()

        model.train() # Training
        # Wrap the loader itself rather than enumerate(loader): enumerate has no
        # len(), which left the progress bar without a known total.
        for X, y in progressbar(train_loader):
            if flatten_input: # Use for feed-forward networks
                X = X.view(-1, 1, 784) # flatten spatial dim
            X = X.to(device)
            y = y.to(device)

            optimizer.zero_grad() # zero gradients
            out = model(X) # forward pass
            loss = criterion(out, y) # compute loss
            loss.backward() # backpropagate
            optimizer.step() # update model

        model.eval() # Evaluation
        n_correct = 0
        n = 0
        with torch.no_grad():
            for X, y in test_loader:
                if flatten_input: # Use for feed-forward networks
                    X = X.view(-1, 1, 784) # flatten spatial dim
                X = X.to(device)
                y = y.to(device)

                out = model(X) # forward pass

                # .item() keeps the running count a plain int, so the recorded
                # accuracy is a float instead of a 0-dim tensor living on `device`.
                n_correct += (torch.argmax(out, dim=-1).view(-1) == y).sum().item()
                n += len(y)

        # An empty test loader gives NaN rather than a ZeroDivisionError.
        acc = n_correct / n if n else float('nan') # record accuracy
        t2 = time() - t1 # record time

        accs.append(acc)
        times.append(t2)

        print('Epoch {} error: {:.4f}'.format(epoch+1, 1-acc))
        print('Epoch {} time : {:.4f}'.format(epoch+1, t2))
        print()

    return(accs, times)

#### 1-layer MLP (feed-forward neural net)

In [0]:
class MLP(nn.Module):
    '''Feed-forward classifier with one hidden layer:
        Linear -> BatchNorm1d -> ReLU -> Linear.'''

    def __init__(self, input_dim=28**2, hidden_dim=64, output_dim=10):
        super().__init__()

        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        # NOTE(review): a single-channel batch norm. With the (-1, 1, 784) inputs
        # produced by train(flatten_input=True) the size-1 axis is the channel
        # axis, so one mean/variance and one scale/shift are shared by all
        # hidden units.
        self.bn = nn.BatchNorm1d(num_features=1)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.bn(hidden)
        hidden = F.relu(hidden)
        logits = self.fc2(hidden)
        # Drop the size-1 axis at position 1 so the result is (batch, output_dim).
        return logits.squeeze(1)

#### ResNet

In [0]:
# Adapted from https://github.com/rtqichen/torchdiffeq/blob/master/examples/odenet_mnist.py
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 2-D convolution with a 3x3 kernel (no padding is requested)."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
    )
    return layer

def conv1x1(in_planes, out_planes, stride=1):
    """Build a 2-D convolution with a 1x1 kernel (no padding is requested)."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
    )
    return layer

def norm(dim):
    '''Group normalization over `dim` channels, using at most 32 groups.'''
    # Fewer than 32 channels: fall back to one group per channel.
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups=num_groups, num_channels=dim)

class Flatten(nn.Module):
    '''Collapses every dimension after the first into one feature axis.'''

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        '''Return `x` viewed as (-1, F), where F is the product of x.shape[1:].'''
        # Multiply the trailing sizes as plain Python ints. This avoids building
        # a tensor and calling .item() on every forward pass, and it yields the
        # int 1 (not the float 1.0) when there are no trailing dimensions, so a
        # 1-D input becomes (N, 1) instead of making view() raise.
        n_features = 1
        for size in x.shape[1:]:
            n_features *= size
        return x.view(-1, n_features)

class ResBlock(nn.Module):
    '''Implements the residual block described in 'Identity Mappings in Deep
        Residual Networks, by Kaiming He et al.'

        Layer order: norm -> ReLU -> 3x3 conv -> norm -> ReLU -> 1x1 conv, plus a
        shortcut connection.

        Args: - in_planes is the number of input channels
              - out_planes is the number of output channels
              - stride is passed to both convolutions
              - downsample is an optional module applied to the pre-activated
                input to produce the shortcut; needed whenever the main branch
                changes the shape (in_planes != out_planes or stride != 1)'''
    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()

        self.norm1 = norm(in_planes)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride)
        self.norm2 = norm(out_planes)
        # conv2 consumes conv1's output, so its input width is out_planes. It was
        # previously declared with in_planes, which only worked when both were equal.
        # With stride 1 the unpadded 3x3 conv shrinks each spatial side by 2 and
        # this padded 1x1 conv grows it back, so the block keeps the input size.
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=stride, padding=1)
        # Previously accepted but silently dropped.
        self.downsample = downsample

    def forward(self, x):
        shortcut = x
        out = F.relu(self.norm1(x))
        if self.downsample is not None:
            # Project the shortcut so it can be added to the main branch.
            shortcut = self.downsample(out)
        out = self.conv1(out)
        out = F.relu(self.norm2(out))
        out = self.conv2(out)

        return out + shortcut

def make_downsampling_layers():
    '''Fresh layers that downsample the input before the residual blocks or ODE block.'''
    return [
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1),
            norm(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2),
            norm(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2),
        ]

def make_fc_layers():
    '''Fresh classifier head applied after the residual blocks or ODE block.'''
    return [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

def make_res_layers(n_blocks=6):
    '''Fresh stack of `n_blocks` 64-channel residual blocks.'''
    return [ResBlock(64,64) for _ in range(n_blocks)]

# Module-level instances, kept because other cells refer to these names.
# NOTE: splatting these lists into nn.Sequential reuses the very same layer
# objects, so every model built that way shares (and keeps training) one set of
# weights. Use the make_* factories above to get independent layers.
downsampling_layers = make_downsampling_layers()
fc_layers = make_fc_layers()
res_layers = make_res_layers()

def make_resnet():
    '''Build a ResNet with newly initialised layers on every call, so repeated
        trials and the other architectures never share weights with it.'''
    return nn.Sequential(*make_downsampling_layers(), *make_res_layers(), *make_fc_layers())

#### Neural ODE (custom)

In [0]:
def rk(fun, t_span, y0, h):
    '''Runge-Kutta 4th order method
        Args: - fun is derivative, called as fun(t, y)
              - t_span is tuple of start and end time
              - y0 is initial condition (tensor of any shape)
              - h is step size (fixed in this case)
        Returns: the state after int((t_span[1]-t_span[0])/h) steps of size h
        Raises: ValueError if h is not positive'''
    if h <= 0:
        raise ValueError('step size h must be positive')

    # The small offset keeps round-off from dropping the last step: for example
    # 0.3/0.1 evaluates to 2.9999999999999996, which int() alone truncates to 2.
    span = float(t_span[1] - t_span[0])
    n_steps = int(span / h + 1e-6)

    y = y0 # set state to initial state
    t = t_span[0]
    for _ in range(n_steps):
        k1 = h * fun(t     , y)
        k2 = h * fun(t+.5*h, y+.5*k1)
        k3 = h * fun(t+.5*h, y+.5*k2)
        k4 = h * fun(t+   h, y+   k3)
        t = t + h
        # Classic RK4 weights (1/6, 1/3, 1/3, 1/6). Combining the stages directly
        # works for a state of any rank; the previous stacked weight tensor was
        # shaped for 4-D states only and mis-broadcast for anything else.
        y = y + (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return y

class ODE_block(ResBlock):
    '''Residual-block layers reused as the right-hand side of an ODE that is
        integrated with the fixed-step `rk` solver.'''

    def __init__(self, planes=64):
        # Same number of channels in and out.
        super().__init__(planes, planes)

        # Integration interval [start, end]. Stored as a Parameter, so it is
        # listed in model.parameters() and follows the module across devices.
        interval = torch.tensor([0., 2.], requires_grad=True)
        self.t_span = nn.Parameter(interval)

    def ode(self, t, x):
        '''Derivative function handed to the solver: the residual block's
            norm/ReLU/conv stack without the shortcut connection. `t` is
            accepted for the solver's calling convention but not used.'''
        hidden = self.conv1(F.relu(self.norm1(x)))
        return self.conv2(F.relu(self.norm2(hidden)))

    def forward(self, y):
        '''Integrate `ode` over t_span with step size .5, starting from y.'''
        return rk(self.ode, self.t_span, y, .5)

def make_odenet():
    '''Build downsampling layers -> custom RK4 ODE block -> classifier head.

        `downsampling_layers` and `fc_layers` are module-level lists of layer
        instances. Splatting them directly into nn.Sequential would reuse those
        same objects, so every network built from them would share and keep
        training one set of weights. Each call therefore works on re-initialised
        copies instead.'''
    def fresh(layers):
        # Copy the template layers, then re-draw their weights so nothing
        # learned by a previously trained model carries over.
        clones = copy.deepcopy(layers)
        for layer in clones:
            for module in layer.modules():
                if hasattr(module, 'reset_parameters'):
                    module.reset_parameters()
        return clones

    return nn.Sequential(*fresh(downsampling_layers), ODE_block(), *fresh(fc_layers))

#### Neural ODE (from torchdiffeq package)
Used to compare the custom RK4 ODE block above with the adjoint-based solver from the `torchdiffeq` authors.

In [0]:
from torchdiffeq import odeint, odeint_adjoint

class ODE_Func(nn.Module):
    '''Dynamics function f(t, x) handed to the torchdiffeq solver.

        Applies norm -> ReLU -> 3x3 conv -> norm -> ReLU -> 1x1 conv with `dim`
        channels throughout and no shortcut connection. `stride` is passed to
        both convolutions; `downsample` is accepted but not used.'''
    def __init__(self, dim, stride=1, downsample=None):
        super().__init__()

        self.norm1 = norm(dim)
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, stride=stride)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=stride, padding=1)

    def forward(self, t, x):
        # `t` is part of the solver's calling convention; the layers ignore it.
        hidden = self.conv1(F.relu(self.norm1(x)))
        return self.conv2(F.relu(self.norm2(hidden)))

class ODE_block_adjoint(ODE_block):
    '''ODE block solved with torchdiffeq's adjoint method (rk4 on the grid
        0, .5, 1, 1.5, 2), for comparison with the custom `rk` solver.

        NOTE(review): the conv/norm layers inherited from ODE_block are still
        constructed but are not used by forward(); only `self.fun` is.'''
    def __init__(self, planes=64):
        # The parent constructor already registers t_span = [0, 2].
        super(ODE_block_adjoint, self).__init__(planes)
        # Time points at which the solver reports the state (step size .5).
        self.t = torch.tensor([0.,.5,1.,1.5,2.])
        self.fun = ODE_Func(planes)

    def forward(self, x):
        # self.t is a plain attribute, so model.to(device) does not move it;
        # put it on the input's device explicitly.
        t = self.t.to(x.device)
        out = odeint_adjoint(self.fun, x, t, method='rk4')
        # `out` holds one state per entry of t. out[1] is the state at t=0.5;
        # the state at the end of the interval (t=2) is the last entry.
        return out[-1]

def make_odenet_adjoint():
    '''Build downsampling layers -> adjoint ODE block -> classifier head.

        `downsampling_layers` and `fc_layers` are module-level lists of layer
        instances. Splatting them directly into nn.Sequential would reuse those
        same objects, so every network built from them would share and keep
        training one set of weights. Each call therefore works on re-initialised
        copies instead.'''
    def fresh(layers):
        # Copy the template layers, then re-draw their weights so nothing
        # learned by a previously trained model carries over.
        clones = copy.deepcopy(layers)
        for layer in clones:
            for module in layer.modules():
                if hasattr(module, 'reset_parameters'):
                    module.reset_parameters()
        return clones

    return nn.Sequential(*fresh(downsampling_layers), ODE_block_adjoint(), *fresh(fc_layers))

### Run trials

In [0]:
epochs_per_trial = 20
n_trials = 3

# One entry per architecture: (label stored in the results, callable that builds
# a new model, extra keyword arguments for train()).
model_specs = [
    ('mlp', MLP, {'device': 'cuda', 'flatten_input': True}),
    ('resnet', make_resnet, {}),
    ('rk-net', make_odenet, {}),
    ('ode-net', make_odenet_adjoint, {}),
]

res = []

for trial in range(n_trials):
    for label, build_model, train_kwargs in model_specs:
        # Rebinding `model` on every iteration drops the previous network
        # instead of keeping one trained model per architecture alive.
        model = build_model()
        accs, times = train(model, epochs=epochs_per_trial, **train_kwargs)
        res.append({'model': label, 'accuracy': torch.tensor(accs).numpy(), 'time': times})

| |                        #                       | 1874 Elapsed Time: 0:01:43
| | #                                                 | 3 Elapsed Time: 0:00:00

Epoch 1 error: 0.0082
Epoch 1 time : 107.0992



| |                                #               | 1874 Elapsed Time: 0:01:42
\ | #                                                 | 3 Elapsed Time: 0:00:00

Epoch 2 error: 0.0073
Epoch 2 time : 106.2570



| |                              #                 | 1874 Elapsed Time: 0:01:39
\ | #                                                 | 3 Elapsed Time: 0:00:00

Epoch 3 error: 0.0095
Epoch 3 time : 102.8515



| |                                       #        | 1874 Elapsed Time: 0:01:39
| | #                                                 | 3 Elapsed Time: 0:00:00

Epoch 4 error: 0.0078
Epoch 4 time : 103.8910



| |                             #                  | 1874 Elapsed Time: 0:01:42
| | #                                                 | 3 Elapsed Time: 0:00:00

Epoch 5 error: 0.0062
Epoch 5 time : 106.6296



\ |                              #                  | 773 Elapsed Time: 0:00:42

#### (optional) write results to file

In [0]:
# # write results to file (optional)
# from google.colab import files
# import pickle

# fname = 'results.pkl'
# pickle.dump(res, open(fname, 'wb'))
# files.download(fname)

#### Print out results

In [0]:
# Keep the raw list in `res` untouched and build the frame under its own name,
# so this cell can be re-run and `res` can still be pickled as collected.
res_df = pd.DataFrame(res)

# One column per (trial, model) run, one row per epoch.
acc_df = pd.DataFrame(res_df['accuracy'].values.tolist()).T
time_df = pd.DataFrame(res_df['time'].values.tolist()).T

# Columns are labelled by model name, so names repeat once per trial and
# groupby(level=0) below aggregates the trials of each model.
acc_df.columns = res_df.model.values
time_df.columns = res_df.model.values

print('ERROR')
print('Minimum values:')
# Best (lowest) error of each run, averaged over trials.
print(1-acc_df.max().groupby(level=0).mean())
print()

print('std:')
print(acc_df.max().groupby(level=0).std())

print('\n\nTIME')
print('Mean values:')
# Total time of each run (sum over epochs), averaged over trials.
print(time_df.sum().groupby(level=0).mean())
print()

print('std:')
print(time_df.sum().groupby(level=0).std())