In [1]:
!pip install torchdiffeq --quiet

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time
import os
import matplotlib.pyplot as plt
from torchdiffeq import odeint_adjoint as odeint

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free pointwise (1x1) ``nn.Conv2d``; ``stride`` > 1 subsamples spatially."""
    kernel_size = 1
    return nn.Conv2d(in_planes, out_planes, kernel_size, stride, bias=False)


def norm(dim):
    """GroupNorm over ``dim`` channels using ``dim`` groups, capped at 32 groups."""
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups, dim)

In [None]:
class NormResBlock(nn.Module):
    """Pre-activation residual block: norm -> ReLU -> conv3x3 -> norm -> ReLU -> conv3x3.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution (2 halves H and W, rounding up).
        downsample: optional module applied to the pre-activated input to form the
            shortcut. It must be given whenever ``stride != 1`` or
            ``inplanes != planes``; otherwise ``out + shortcut`` has mismatched shapes.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(NormResBlock, self).__init__()

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        # Pre-activation of the block input. Its output feeds both the residual
        # branch and (when present) the downsampling shortcut. The in-place ReLU
        # acts on the norm's fresh output, so the caller's `x` is not modified.
        self.layers = nn.Sequential(
            norm(inplanes),
            self.relu,
        )

        # Residual branch. Only the first conv carries the stride, so the block
        # reduces the spatial size exactly once, matching `downsample`.
        self.middle_layers = nn.Sequential(
            conv3x3(inplanes, planes, stride),
            norm(planes),
            self.relu,
            conv3x3(planes, planes),
        )

    def forward(self, x):
        shortcut = x
        out = self.layers(x)

        if self.downsample is not None:
            # Taken from the pre-activated input, before the strided conv, so
            # the shortcut is downsampled once, just like the residual branch.
            shortcut = self.downsample(out)

        out = self.middle_layers(out)
        return out + shortcut

In [5]:
class Conv2d(nn.Module):
    """Convolution whose input is augmented with a constant time channel.

    ``forward(t, x)`` prepends one channel filled with ``t`` to ``x`` along dim 1
    and applies an ``nn.Conv2d``, or an ``nn.ConvTranspose2d`` when
    ``transpose=True``. The wrapped layer therefore takes ``dim_in + 1`` input
    channels.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(Conv2d, self).__init__()
        layer_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        self.conv_layer = layer_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        time_channel = torch.ones_like(x[:, :1, :, :]) * t
        return self.conv_layer(torch.cat([time_channel, x], 1))

In [None]:
class ODEfunc(nn.Module):
    """Time-dependent vector field f(t, x), used as the right-hand side of the ODE.

    Stack: norm -> ReLU -> Conv2d(t, .) -> norm -> ReLU -> Conv2d(t, .) -> norm.
    The two Conv2d layers are 3x3 with stride 1 and padding 1, mapping ``dim``
    to ``dim`` channels, so the output has the same shape as ``x``.
    """

    # Positions inside `self.layers` whose modules are called as layer(t, h).
    _TIME_CONDITIONED = (2, 5)

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        stages = [
            norm(dim),
            nn.ReLU(inplace=True),
            Conv2d(dim, dim, 3, 1, 1),
            norm(dim),
            nn.ReLU(inplace=True),
            Conv2d(dim, dim, 3, 1, 1),
            norm(dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, t, x):
        h = x
        for position, layer in enumerate(self.layers):
            if position in self._TIME_CONDITIONED:
                h = layer(t, h)
            else:
                h = layer(h)
        return h

class ODEBlock(nn.Module):
    """Integrate ``odefunc`` from t=0 to t=1 with ``odeint`` and return the state at t=1.

    Args:
        odefunc: module called as ``odefunc(t, x)``; passed to ``odeint``.
        rtol: relative tolerance forwarded to ``odeint`` (default 1e-3).
        atol: absolute tolerance forwarded to ``odeint`` (default 1e-3).
    """

    def __init__(self, odefunc, rtol=1e-3, atol=1e-3):
        super(ODEBlock, self).__init__()
        # A non-persistent buffer follows `.to(device)` together with the module.
        # forward() therefore never has to reassign module state, and no new key
        # appears in state_dict().
        self.register_buffer('integration_time', torch.tensor([0, 1]).float(), persistent=False)
        self.odefunc = odefunc
        self.rtol = rtol
        self.atol = atol

    def forward(self, x):
        # Match x's device *and* dtype (e.g. float64 inputs) without mutating the buffer.
        t = self.integration_time.to(x)
        out = odeint(self.odefunc, x, t, rtol=self.rtol, atol=self.atol)
        # odeint returns the solution at every time in `t`; index 1 is t=1.
        return out[1]

In [None]:
def inf_loop(iterable):
    """Yield items from ``iterable`` forever, restarting it each time it is exhausted.

    ``iterable`` must be re-iterable (e.g. a DataLoader or a list), because a
    fresh iterator is requested after every full pass.

    Raises:
        ValueError: if a full pass yields no items. This covers an empty
            iterable or a one-shot iterator that is already used up; without
            the check the loop would spin forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("inf_loop: iterable produced no items on a full pass")


def lr_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, base_lr=0.1):
    """Build a piecewise-constant learning-rate schedule indexed by iteration.

    The initial rate is ``base_lr * batch_size / batch_denom``. While the
    iteration lies in the k-th interval delimited by ``boundary_epochs``
    (converted to iterations via ``batches_per_epoch``), the rate is
    ``initial * decay_rates[k]``. From the last boundary on, the final rate applies.

    Args:
        batch_size: training batch size.
        batch_denom: reference batch size at which the initial rate equals ``base_lr``.
        batches_per_epoch: iterations per epoch, used to convert epochs to iterations.
        boundary_epochs: ascending epoch boundaries where the rate changes.
        decay_rates: multipliers, one more than ``boundary_epochs``.
        base_lr: base learning rate (default 0.1).

    Returns:
        ``lr_fn(itr) -> float`` giving the learning rate for iteration ``itr``.

    Raises:
        ValueError: if ``len(decay_rates) != len(boundary_epochs) + 1``.
    """
    if len(decay_rates) != len(boundary_epochs) + 1:
        raise ValueError(
            f"decay_rates needs len(boundary_epochs) + 1 = {len(boundary_epochs) + 1} "
            f"entries, got {len(decay_rates)}"
        )

    initial = base_lr * batch_size / batch_denom
    ends = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial * decay for decay in decay_rates]

    def lr_fn(itr):
        # Index of the first boundary not yet reached; len(ends) once all are passed.
        i = next((k for k, end in enumerate(ends) if itr < end), len(ends))
        return vals[i]

    return lr_fn

In [6]:
def mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0, root='.data/mnist', num_workers=2):
    """Create MNIST DataLoaders for training, testing and training-set evaluation.

    Args:
        batch_size: batch size of the shuffled, augmented training loader.
        test_batch_size: batch size of both evaluation loaders.
        perc: accepted for interface compatibility; currently unused.
        root: directory MNIST is downloaded to / read from.
        num_workers: worker processes per loader.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)``:
        - ``train_loader``: training set with ``RandomCrop(28, padding=4)``
          augmentation, shuffled, with the last incomplete batch dropped.
        - ``test_loader``: test set, no augmentation, fixed order.
        - ``train_eval_loader``: training set, no augmentation, fixed order.

    The two evaluation loaders keep their final partial batch (``drop_last=False``),
    so every sample is scored even when ``test_batch_size`` does not divide the
    dataset size.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    test_loader = DataLoader(
        datasets.MNIST(root=root, train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=num_workers, drop_last=False
    )

    train_loader = DataLoader(
        datasets.MNIST(root=root, train=True, download=True, transform=transform_train),
        batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root=root, train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=num_workers, drop_last=False
    )

    return train_loader, test_loader, train_eval_loader

In [None]:
def accuracy(model, dataset_loader, device=None):
    """Return the top-1 classification accuracy of ``model`` over ``dataset_loader``.

    Args:
        model: module mapping an input batch to per-class scores of shape
            (batch, num_classes); the predicted class is the argmax over dim 1.
        dataset_loader: iterable of ``(inputs, integer_labels)`` batches, e.g. a DataLoader.
        device: device inputs are moved to. Defaults to the device of the
            model's first parameter, or the CPU if the model has none.

    Returns:
        The fraction (float in [0, 1]) of the samples actually yielded by the
        loader that were classified correctly. Returns 0.0 if the loader
        yields nothing. Counting seen samples, rather than
        ``len(loader.dataset)``, keeps the result correct when the loader
        drops a last partial batch.

    The model runs in eval mode under ``torch.no_grad()``, and its previous
    train/eval mode is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for x, y in dataset_loader:
                logits = model(x.to(device))
                y = y.to(logits.device)
                correct += (logits.argmax(dim=1) == y).sum().item()
                seen += y.size(0)
    finally:
        model.train(was_training)
    return correct / seen if seen else 0.0

In [8]:
def _build_model(model_name):
    """Assemble the 'ODENet' or 'ResNet' MNIST classifier (not yet moved to a device)."""
    if model_name == 'ODENet':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
        feature_layers = [ODEBlock(ODEfunc(64))]
    elif model_name == 'ResNet':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            NormResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            NormResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]
        feature_layers = [NormResBlock(64, 64) for _ in range(6)]
    else:
        raise ValueError(f"model_name must be 'ODENet' or 'ResNet', got {model_name!r}")

    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(64, 10)]
    return nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers)


def train_and_evalute(model_name, epochs, batch_size=128, test_batch_size=1000):
    """Train an 'ODENet' or 'ResNet' MNIST classifier and record accuracy per epoch.

    Args:
        model_name: 'ODENet' or 'ResNet'; any other value raises ValueError.
        epochs: number of passes over the training set.
        batch_size: training batch size. The learning rate is scaled by
            ``batch_size / 128`` through ``lr_decay``.
        test_batch_size: batch size of the evaluation loaders.

    Returns:
        ``(model, train_acc_l, val_acc_l)``. Each history is a list of
        ``(accuracy, epoch)`` tuples recorded after the last batch of every
        epoch, with epochs numbered from 1.
    """
    model = _build_model(model_name).to(device)

    train_loader, test_loader, train_eval_loader = mnist_loaders(batch_size, test_batch_size)
    data_gen = inf_loop(train_loader)
    batches_per_epoch = len(train_loader)

    criterion = nn.CrossEntropyLoss().to(device)
    # The lr passed here is overwritten every iteration from lr_fn below.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # Step schedule: the rate drops 10x at epochs 60, 100 and 140.
    lr_fn = lr_decay(batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch,
                     boundary_epochs=[60, 100, 140], decay_rates=[1, 0.1, 0.01, 0.001])

    train_acc_l = []
    val_acc_l = []
    model.train()
    for itr in range(epochs * batches_per_epoch):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(device)
        y = y.to(device)
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

        # Evaluate after the *last* batch of each epoch, so the final history
        # entry reflects the fully trained model.
        if (itr + 1) % batches_per_epoch == 0:
            epoch = (itr + 1) // batches_per_epoch
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
            model.train()
            train_acc_l.append((train_acc, epoch))
            val_acc_l.append((val_acc, epoch))
            print(f"Epoch {epoch:04d} | Train Acc {train_acc:.4f} | Test Acc {val_acc:.4f}")

    return model, train_acc_l, val_acc_l

In [None]:
# Full comparison run: `epochs` passes per model (long). Lower it for a quick smoke test.
SEED = 0
epochs = 500

# Re-seed before each model so both runs start from the same, reproducible RNG
# state (weight init, shuffling, random crops). GPU kernels may still introduce
# small nondeterminism.
print("ODENET-------------------------------")
torch.manual_seed(SEED)
np.random.seed(SEED)
model_odenet, train_acc_ode, val_acc_ode = train_and_evalute('ODENet', epochs)

print("RESNET-------------------------------")
torch.manual_seed(SEED)
np.random.seed(SEED)
model_resnet, train_acc_res, val_acc_res = train_and_evalute('ResNet', epochs)

In [None]:
# Training accuracy per epoch for both models.
# The first SKIP_FIRST evaluations are omitted, presumably so the low early
# accuracies do not compress the y-axis; set it to 0 to show every point.
SKIP_FIRST = 2

fig, ax = plt.subplots()
for history, label in ((train_acc_ode, 'ODE Training Accuracy'),
                       (train_acc_res, 'RES Training Accuracy')):
    # Each history entry is an (accuracy, epoch) tuple.
    epoch_idx = [epoch for _, epoch in history][SKIP_FIRST:]
    acc = [value for value, _ in history][SKIP_FIRST:]
    line, = ax.plot(epoch_idx, acc, label=label, linewidth=1.5)
    # Reuse the line's colour so each model's markers match its curve.
    ax.scatter(epoch_idx, acc, s=5, color=line.get_color())
ax.legend()
ax.set_title("Training Accuracy")
ax.set_ylabel("Accuracy")
ax.set_xlabel("Epochs")
plt.show()

In [None]:
# Test-set accuracy per epoch for both models.
# The first SKIP_FIRST evaluations are omitted, presumably so the low early
# accuracies do not compress the y-axis; set it to 0 to show every point.
SKIP_FIRST = 2

fig, ax = plt.subplots()
for history, label in ((val_acc_ode, 'ODE Testing Accuracy'),
                       (val_acc_res, 'RES Testing Accuracy')):
    # Each history entry is an (accuracy, epoch) tuple.
    epoch_idx = [epoch for _, epoch in history][SKIP_FIRST:]
    acc = [value for value, _ in history][SKIP_FIRST:]
    line, = ax.plot(epoch_idx, acc, label=label, linewidth=1.5)
    # Reuse the line's colour so each model's markers match its curve.
    ax.scatter(epoch_idx, acc, s=5, color=line.get_color())
ax.legend()
ax.set_title("Testing Accuracy")
ax.set_ylabel("Accuracy")
ax.set_xlabel("Epochs")
plt.show()