In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchdiffeq
import tqdm

Using TensorFlow backend.


In [2]:
(X_train,Y_train),(X_test,Y_test) = keras.datasets.mnist.load_data()
X_train,X_test = X_train[:,np.newaxis]/255,X_test[:,np.newaxis]/255

In [3]:
class ODEFunc(nn.Module):
    
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim+1,dim,3,padding=1)
        self.conv2 = nn.Conv2d(dim,dim,3,padding=1)
        self.conv3 = nn.Conv2d(dim,dim,3,padding=1)
    
    def forward(self, t, x):
        t = t.expand(x.size()[:1] + (1,) + x.size()[2:])
        x = torch.cat([x,t], dim=1)
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.conv3(x)
        return x

In [4]:
class ODEBlock(nn.Module):
    
    def __init__(self, dim, rtol=1e-4, atol=1e-4):
        super().__init__()
        self.func = ODEFunc(dim)
        self.itime = torch.Tensor([0,1])
        self.rtol = rtol
        self.atol = atol
    
    def forward(self, x):
        itime = self.itime.type_as(x)
        return torchdiffeq.odeint_adjoint(
            self.func, x, itime, rtol=self.rtol, atol=self.atol)[1]

In [5]:
class Net(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.norm0 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1,8,7,padding=0)
        self.norm1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8,16,7,padding=0)
        self.norm2 = nn.BatchNorm2d(16)
        self.ode3 = ODEBlock(16)
        self.conv3 = nn.Conv2d(16,32,3,padding=1)
        self.norm3 = nn.BatchNorm2d(32)
        self.ode4 = ODEBlock(32)
        self.conv4 = nn.Conv2d(32,64,3,padding=1)
        self.norm4 = nn.BatchNorm2d(64)
        self.ode5 = ODEBlock(64)
        self.conv5 = nn.Conv2d(64,128,3,padding=1)
        self.norm5 = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128,10)
    
    def forward(self, x):
        x = self.norm0(x)
        x = self.conv1(x)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = self.ode3(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = self.ode4(x)
        x = self.conv4(x)
        x = self.norm4(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = self.ode5(x)
        x = self.conv5(x)
        x = self.norm5(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = x.view((-1,128))
        x = self.fc(x)
        x = F.softmax(x, dim=-1)
        return x

In [6]:
device = torch.device('cuda')
data,test,label = X_train,X_test,Y_train
data = torch.Tensor(data).to(device)
test = torch.Tensor(test).to(device)
label = torch.Tensor(label).to(device).type(torch.long)
net = Net().to(device)
los = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(net.parameters())

In [7]:
net = net.train()
for _ in range(3):
    loss_avg = 0
    loss_cnt = 0
    for batch_idx in tqdm.tqdm_notebook(range(0,data.shape[0],128)):
        batch = data[batch_idx:batch_idx+128]
        opt.zero_grad()
        output = net(batch)
        loss = los(output, label[batch_idx:batch_idx+128])
        loss.backward()
        opt.step()
        loss_avg += loss.item()
        loss_cnt += 1
    print(loss_avg/loss_cnt)

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))


1.546300677094124


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))


1.480781925766707


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))


1.4761413678939916


In [8]:
net = net.eval()
test_hit = 0
for batch_idx in tqdm.tqdm_notebook(range(0,test.shape[0],128)):
    batch = test[batch_idx:batch_idx+128]
    pred = net(batch)
    pred = torch.argmax(pred, dim=-1)
    pred = pred.cpu().detach().numpy()
    test_hit += np.sum(pred == Y_test[batch_idx:batch_idx+128])
test_hit/Y_test.shape[0]

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))




0.9867