In [1]:
# Make boxes window width. 
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
torch.cuda.is_available()

True

In [14]:
class Block(nn.Module):
    def __init__(self):
        super(Block, self).__init__()
    def forward(self, x, t, net_params):
        x = F.relu(F.conv2d(x, net_params[0:9].view(1,1,3,3), padding=1))
        x = F.relu(F.conv2d(x, net_params[9:18].view(1,1,3,3), padding=1))
        return x

In [15]:
def RK4(f, x0, t0, t1, N,  net_params):
    h = (t1 - t0) / float(N) # calculate step size
    solution = [x0] # initialize saved dynamics
    t = t0
    for i in range(N):
        k1 = f(solution[i], t, net_params)
        k2 = f(solution[i] + h * k1 / 2.0, t + h/2.0, net_params)
        k3 = f(solution[i] + h * k2 / 2.0, t + h/2.0, net_params)
        k4 = f(solution[i] + h * k3, t + h, net_params)
        solution.append(solution[i] + h/6.0 * (k1 + k2 + k3 + k4))
        t = t + h  
    
    return solution

In [16]:
# dumb but will be replaced when ragged_tensor is a native pytorch feature
def tuple_add(*tuples):
    return tuple(sum(i) for i in zip(*tuples))
def tw(a, weight):
    return tuple(i * weight for i in a)

def RK4_backward(f, x0, t0, t1, N,  net_params):
    h = (t1 - t0) / float(N) # calculate step size
    solution = [x0] # initialize saved dynamics
    t = t0
    for i in range(N):
        k1 = f(solution[i], t, net_params)        
        k2 = f(tuple_add(solution[i], tuple(h/2.0 * j for j in k1)), t + h/2.0, net_params)
        k3 = f(tuple_add(solution[i], tuple(h/2.0 * j for j in k2)), t + h/2.0, net_params)
        k4 = f(tuple_add(solution[i], tuple(h * j for j in k2)), t + h, net_params)
        solution.append(tuple_add(solution[i], tw(k1,h/6.0), tw(k2,h/6.0), tw(k3,h/6.0), tw(k4,h/6.0)))
        t = t + h  
    return solution

In [17]:
class Integrate(torch.autograd.Function):
    def __deepcopy__(self, memo):
        return Integrate(copy.deepcopy(memo))
    
    @staticmethod
    def forward(ctx, Integrator, Integrator_backwards, f, x0, t0, t1, N, net_params):
        
        # Forward, Runge-Kutta 4th Order. 
        
        # Forward integration
        solution = Integrator(f, x0, t0, t1, N, net_params)
            
        # Save for jacobian calculations in backward()
        ctx.save_for_backward(x0,t0,t1)
        ctx.net_params = net_params
        ctx.solution = solution
        ctx.Integrator_backwards = Integrator_backwards
        ctx.N = N
        ctx.f = f
        
        return solution[-1]
    
    @staticmethod
    def backward(ctx, dL_dz1):
        # Get all saved context
        z0, t0, t1 = ctx.saved_tensors
        net_params = ctx.net_params
        dynamics = ctx.solution
        z1 = dynamics[-1]
        N = ctx.N
        f = ctx.f
        
        # Convenience sizes
        batch_size = z0.size()[0]
        img_size = z0.size()[3]
        img_len = img_size ** 2

        # Compute derivative w.r.t. to end time of integration
        dL_dt1 = dL_dz1.view(batch_size,1,-1).bmm(f(z1, t1, net_params).view(batch_size,-1,1))  # Derivative of loss w.r.t t1
        
        # Initial Condition
        s0 = (z1, dL_dz1, torch.zeros((batch_size, net_params.numel()), dtype=torch.float32).to(torch.device("cuda:0")), -dL_dt1) # initial augmented state
        
        # augmented dynamics function
        # what I really want is a Tensorflow Ragged Tensor, and pytorch's implementation really isn't there yet
        def aug_dynamics(s, t, theta):
           
            with torch.enable_grad():
                gradients = [torch.autograd.functional.vjp(f, 
                                                           (s[0][i,:,:].unsqueeze(0), t, theta), 
                                                           v=-s[1][i,:,:].unsqueeze(0)
                                                          )[1] for i in range(batch_size)]
            
            return (f(s[0],t,theta),
                    torch.cat([gradient[0] for gradient in gradients], dim=0), 
                    torch.cat([gradient[2].reshape(1,18) for gradient in gradients], dim=0), 
                    torch.cat([gradient[1].reshape(1,1) for gradient in gradients], dim=0).to(torch.device("cuda:0")),
                   )
        
        # Integrate backwards
        
        with torch.no_grad(): back_dynamics = ctx.Integrator_backwards(aug_dynamics, s0, t1, t0, N, net_params)
        # Extract derivatives
        _, dL_dz0, dL_dtheta, dL_dt0 = back_dynamics[-1]
        
        # must return something for every input to forward, None for non-tensors
        return None, None, None, dL_dz0, dL_dt0, dL_dt1, None, dL_dtheta

In [18]:
class ODENet(nn.Module):
    def __init__(self, solver, solver_b, f, solver_params):
        super(ODENet, self).__init__()
        
        self.f = f()
        
        self.int_f = solver
        self.int_b = solver_b
        self.Integrate = Integrate()
        
        self.solver_params = solver_params
        self.N = solver_params["N"]
        self.h = (solver_params["t1"] - solver_params["t0"]) / solver_params["N"]
        self.t0 = torch.tensor(float(solver_params["t0"]), requires_grad=True)
        self.t1 = torch.tensor(float(solver_params["t1"]), requires_grad=True)
        self.net_params = torch.nn.parameter.Parameter(torch.Tensor(18).normal_(mean=0, std=0.1,generator=None), requires_grad=True)

        self.avg_pool = torch.nn.MaxPool2d(2, stride=2, padding=0)
        self.fc1 = nn.Linear(196, 50)
        self.fc2 = nn.Linear(50, 10)
        
    def forward(self, x):
        x = self.Integrate.apply(self.int_f, self.int_b, self.f, x, self.t0, self.t1, self.N, self.net_params) # Vanilla RK4
        x = self.avg_pool(x)
        x = x.view(-1, 196) 
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)

In [19]:
def train(net, train_loader, test_loader, hyperparameters):
    
    lr = hyperparameters["lr"]
    n_epochs = hyperparameters["n_epochs"]
    momentum = hyperparameters["momentum"]
    
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    start_TIME = time.time()
    losses = []
    for i in range(n_epochs):
        
        
        
        # Train
        net.train()
        train_losses = []
        for j, (data, label) in enumerate(train_loader):
            data = data.to(torch.device("cuda:0"))
            label = label.to(torch.device("cuda:0"))
            optimizer.zero_grad()
            output = net(data)
            loss = F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        
        """ Check ur gradients if the accuracy sucks
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(name, param.data, param.grad)
        """
        
        num_correct, test_losses = test(net, test_loader)
        losses.append([train_losses, test_losses])
        
        # Report
        print(
          "EPOCH", i, 
          "time", print(datetime.datetime.now()), "\n"
          "Avg Train Loss", sum(train_losses)/len(train_losses), "\n"
          "Avg Test Loss", sum(test_losses)/len(test_losses), "\n"
          "Test Accuracy", (num_correct / float(len(test_loader.dataset)) * 100).item(), "%"
         )
        print("----------------------------------------")
        
        torch.save(net.state_dict(), "Test1/test_e" + str(i) + ".pth")
        
    return losses    

In [20]:
def test(net, test_loader):
    # Test
    net.eval()
    test_losses = []
    num_correct = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(test_loader):
            data = data.to(torch.device("cuda:0"))
            label = label.to(torch.device("cuda:0"))
            output = net(data)
            loss = F.nll_loss(output, label)
            test_losses.append(loss.item())
            num_correct += label.eq(torch.max(output, 1, keepdim=False, out=None).indices).sum()

    
    return num_correct, test_losses 

In [21]:
batch_size_train = 64
batch_size_test = 1000
img_size = 28
img_len = 784

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('.', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('.', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [22]:
hyperparameters = {
    "lr":  0.01,
    "n_epochs": 10,
    "momentum": 0.5,
}

solver_params = {
    "t0": 0,
    "t1": 3,
    "N": 3,
}

In [23]:
TestNetwork = ODENet(RK4, RK4_backward, Block, solver_params)

This behaves weirdly. The first time you run it, it will throw an error saying that it expected a cpu tensor but got a gpu tensor. Just run everything except the imports again, and it will work. I don't know why it does this.


In [None]:
TestNetwork = TestNetwork.to(torch.device("cuda:0"))
torch.set_default_tensor_type(torch.cuda.FloatTensor)
losses = train(TestNetwork, train_loader, test_loader, hyperparameters)


for name, param in TestNetwork.named_parameters():
    if param.requires_grad:
        print(name, param.data, param.data.type())

2020-06-01 01:05:44.864906
EPOCH 0 time None 
Avg Train Loss 0.44757418585087316 
Avg Test Loss 0.19449862092733383 
Test Accuracy 93.86000061035156 %
----------------------------------------
2020-06-01 01:14:35.999520
EPOCH 1 time None 
Avg Train Loss 0.18004053369609277 
Avg Test Loss 0.15600185543298722 
Test Accuracy 95.1199951171875 %
----------------------------------------
2020-06-01 01:23:25.932958
EPOCH 2 time None 
Avg Train Loss 0.14352391686425534 
Avg Test Loss 0.12495857700705529 
Test Accuracy 96.1500015258789 %
----------------------------------------
2020-06-01 01:32:10.327703
EPOCH 3 time None 
Avg Train Loss 0.12284196385823841 
Avg Test Loss 0.13107190355658532 
Test Accuracy 95.86000061035156 %
----------------------------------------
2020-06-01 01:40:55.840195
EPOCH 4 time None 
Avg Train Loss 0.11117728205441411 
Avg Test Loss 0.10996657982468605 
Test Accuracy 96.5999984741211 %
----------------------------------------
