<a href="https://colab.research.google.com/github/hsc-2752/CLEVER_RNN/blob/master/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchdiffeq

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdiffeq
  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3


In [2]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import torchvision
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from keras.models import Sequential
from keras.layers import LSTM, Dense
import numpy as np
%matplotlib inline

In [4]:
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5],std=[0.5])])

data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

data_test = datasets.MNIST(root="./data/",
                           transform = transform,
                           train = False)
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size = 64,
                                                shuffle = True,
                                                 num_workers=2)

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size = 64,
                                               shuffle = True,
                                                num_workers=2)
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1,2,0)
std = [0.5]
mean = [0.5]
img = img*std+mean
#print([labels[i] for i in range(64)])
#plt.imshow(img)

In [7]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     """3x3 convolution with padding"""    
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                      padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):    
  """1x1 convolution"""    
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [10]:
## Import the Adjoint Method (ODE Solver)
from torchdiffeq import odeint_adjoint as odeint

## Normal Residual Block Example

class ResBlock(nn.Module):

    #init a block - Convolve, pool, activate, repeat
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

    #Forward pass - pass output of one layer to the input of the next 
    def forward(self, x):
        shortcut = x
        out = self.relu(self.norm1(x))
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut

## Ordinary Differential Equation Definition     

class ODEfunc(nn.Module):

    # init ODE variables
    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(dim, dim)
        self.norm2 = nn.BatchNorm2d(dim)
        self.conv2 = conv3x3(dim, dim)
        self.norm3 = nn.BatchNorm2d(dim)
        self.nfe = 0

    # init ODE operations 
    def forward(self, t, x):
      #nfe = number of function evaluations per timestep
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm3(out)
        return out


 ## ODE block
class ODEBlock(nn.Module):

    #initialized as an ODE Function
    #count the time
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    #foorward pass 
    #input the ODE function and input data into the ODE Solver (adjoint method)
    # to compute a forward pass
    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=1e-7, atol=1e-9)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


## Main Method

if __name__ == '__main__':

    #Add Pooling
    downsampling_layers = [
         nn.Conv2d(1, 64, 3, 1),
         ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
         ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
     ]

    # Initialize the network as 1 ODE Block
    feature_layers = [ODEBlock(ODEfunc(64))] 
    # Fully connected Layer at the end
    fc_layers = [nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(64, 10)]
  
    #The Model consists of an ODE Block, pooling, and a fully connected block at the end
    #model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(torch.device)
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers)

    #Declare Gradient Descent Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    n_epochs = 5
    for epoch in range(n_epochs):
      running_loss = 0.0
      running_correct = 0
      print("Epoch {}/{}".format(epoch, n_epochs))
      print("-"*10)
      for data in data_loader_train:
          X_train, y_train = data
          X_train, y_train = Variable(X_train), Variable(y_train)
          outputs = model(X_train)
          _,pred = torch.max(outputs.data, 1)
          optimizer.zero_grad()
          loss = nn.CrossEntropyLoss(outputs, y_train)
          
          loss.backward()
          optimizer.step()
          running_loss += loss.data[0]
          running_correct += torch.sum(pred == y_train.data)
      testing_correct = 0
      for data in data_loader_test:
          X_test, y_test = data
          X_test, y_test = Variable(X_test), Variable(y_test)
          outputs = model(X_test)
          _, pred = torch.max(outputs.data, 1)
          testing_correct += torch.sum(pred == y_test.data)
      print("Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}".format(running_loss/len(data_train),
                                                                                        100*running_correct/len(data_train),
                                                                                        100*testing_correct/len(data_test)))


Epoch 0/5
----------


RuntimeError: ignored