In [1]:
# Stretch the notebook's cell container to the full browser-window width.
# `display`/`HTML` are imported from IPython.display; importing them from
# IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [206]:
import os
import time
from pathlib import Path

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import vjp

# Use the first CUDA GPU when one is available. In that case CUDA float tensors
# also become the default tensor type, so newly created tensors land on the GPU.
# Otherwise everything stays on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [207]:
class Block(nn.Module):
    """ODE dynamics f(x, t, net_params): two 3x3 convolutions, each followed by ReLU.

    Both convolutions map C -> C channels (C = x.size(1)) with padding 1, so the
    spatial size is preserved. The module owns no weights. They are sliced from
    the flat ``net_params`` tensor: entries [0, 9*C*C) are the first kernel and
    [9*C*C, 18*C*C) are the second. ``t`` is accepted to match the solver's
    f(x, t, ...) calling convention but is not used.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, t, net_params):
        channels = x.size(1)
        per_conv = 9 * channels * channels  # entries in one (C, C, 3, 3) kernel
        kernel_shape = (channels, channels, 3, 3)
        first_kernel = net_params[:per_conv].view(kernel_shape)
        second_kernel = net_params[per_conv:2 * per_conv].view(kernel_shape)
        hidden = F.relu(F.conv2d(x, first_kernel, padding=1))
        return F.relu(F.conv2d(hidden, second_kernel, padding=1))

In [208]:
# Butcher tableaus in the layout explicit_RK reads:
#   first row    -> (c_1,)
#   middle rows  -> (c_i, a_i1, ..., a_i(i-1))
#   last row     -> (b_1, ..., b_s), the weights of the final update

# Classic fourth-order Runge-Kutta.
RK4 = (
    (0,),
    (1/2, 1/2),
    (1/2, 0, 1/2),
    (1, 0, 0, 1),
    (1/6, 1/3, 1/3, 1/6),
)

# Forward (explicit) Euler.
EF = (
    (0,),
    (1,),
)

"""
General explicit Runge-Kutta solver.
https://en.wikipedia.org/wiki/Runge–Kutta_methods

b_tableau, nested tuple, is the Butcher tableau of the method: the first row is (c_1,),
each middle row is (c_i, a_i1, ..., a_i(i-1)), and the last row holds the weights (b_1, ..., b_s).
f, function, is the function to iterate. It should depend only on x and t.
x0, FloatTensor, is the initial condition.
t0, FloatTensor, is the start time of integration.
t1, FloatTensor, is the end time of integration.
N, int, is the desired number of timesteps.
"""

def explicit_RK(b_tableau, f, x0, t0, t1, N):
    """Integrate dx/dt = f(x, t) from t0 to t1 with N fixed steps of an explicit RK method.

    Args:
        b_tableau: Butcher tableau as a nested tuple:
            b_tableau[0] is (c_1,); each b_tableau[i] for 0 < i < len-1 is
            (c_i, a_i1, ..., a_i(i-1)); b_tableau[-1] is (b_1, ..., b_s).
        f: callable f(x, t) returning dx/dt.
        x0: initial condition (e.g. a FloatTensor).
        t0: start time of integration.
        t1: end time of integration. t1 < t0 gives a negative step, i.e.
            integrates backwards in time.
        N: number of timesteps; must be >= 1.

    Returns:
        The state after N steps (the approximation of x(t1)).

    Raises:
        ValueError: if N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive number of steps, got {N!r}")

    h = (t1 - t0) / float(N)  # step size (negative when integrating backwards)
    c_first = b_tableau[0][0]
    stage_rows = b_tableau[1:-1]
    weights = b_tableau[-1]

    x = x0
    for step in range(N):
        # `t`, not `time`: avoids shadowing the imported `time` module.
        t = t0 + h * step
        k = [f(x, t + h * c_first)]  # first stage (first tableau row)
        for row in stage_rows:
            c_i, a_row = row[0], row[1:]
            # Stage input combines the previously computed k's with the a_ij weights.
            x_stage = x + sum(a * k[j] * h for j, a in enumerate(a_row))
            k.append(f(x_stage, t + c_i * h))
        x = x + sum(w * k_i * h for k_i, w in zip(k, weights))  # weighted update
    return x

In [209]:
# Convenience tuple -> tensor function
def flatten(*args):
    """Concatenate the given tensors, each flattened, into a single (1, total_numel) row."""
    return torch.cat([torch.flatten(arg) for arg in args], dim=0).view(1, -1)

# Convenience tensor -> tuple function
def unflatten(x, n_e, sizes):
    """Split a (1, total) row tensor, as produced by ``flatten``, back into pieces.

    Args:
        x: row tensor; only x[0, :] is read.
        n_e: element counts of the pieces. Only the first len(sizes) - 1 are used,
            because the last piece takes whatever remains of the row.
        sizes: target shape for each piece, in order.

    Returns:
        A tuple of len(sizes) tensors (views into ``x``).

    Raises:
        ValueError: if ``n_e`` has fewer than len(sizes) - 1 entries.
    """
    if len(n_e) < len(sizes) - 1:
        raise ValueError(f"need at least {len(sizes) - 1} element counts, got {len(n_e)}")
    pieces = []
    start = 0
    for count, size in zip(n_e[:len(sizes) - 1], sizes[:-1]):
        pieces.append(x[0, start:start + count].view(size))
        start += count
    pieces.append(x[0, start:].view(sizes[-1]))  # remainder
    return tuple(pieces)

class Integrate(torch.autograd.Function):
    """Neural-ODE layer whose gradients are computed with the adjoint method.

    Use through ``Integrate.apply(Integrator, f, x0, t0, t1, N, net_params)``.
    forward() integrates dz/dt = f(z, t, net_params) from t0 to t1 with
    ``Integrator`` and the module-level ``RK4`` tableau. backward() integrates
    the augmented system [z, dL/dz, dL/dtheta, dL/dt] from t1 back to t0,
    instead of backpropagating through the solver's internal steps.
    """

    def __deepcopy__(self, memo):
        # forward/backward are static and the object holds no state of its own,
        # so returning this same object is a valid deep copy. It also avoids
        # re-instantiating an autograd Function.
        return self

    @staticmethod
    def forward(ctx, Integrator, f, x0, t0, t1, N, net_params):
        """Integrate f from t0 to t1 and return the final state.

        Args:
            Integrator: solver called as Integrator(b_tableau, g, x0, t0, t1, N),
                where g(x, t) is the dynamics.
            f: dynamics, called as f(x, t, net_params).
            x0: initial state z(t0).
            t0, t1: start / end time tensors.
            N: number of solver steps.
            net_params: parameter tensor passed through to f.
        """
        solution = Integrator(RK4, lambda x, t: f(x, t, net_params), x0, t0, t1, N)

        # Save tensors, including the output, with save_for_backward. Storing the
        # output directly on ctx creates a reference cycle (output -> grad_fn -> ctx
        # -> output) that keeps memory alive.
        ctx.save_for_backward(x0, t0, t1, net_params, solution)
        ctx.Integrator = Integrator
        ctx.N = N
        ctx.f = f

        return solution

    @staticmethod
    def backward(ctx, dL_dz1):
        """Adjoint backward pass. ``dL_dz1`` is the loss gradient w.r.t. forward's output."""
        z0, t0, t1, net_params, z1 = ctx.saved_tensors
        N = ctx.N
        f = ctx.f

        batch_size = z0.size()[0]

        # Derivative of the loss w.r.t. the end time t1: a(t1)^T f(z(t1), t1),
        # one value per sample, shaped (batch, 1, 1).
        dL_dt1 = dL_dz1.view(batch_size, 1, -1).bmm(f(z1, t1, net_params).view(batch_size, -1, 1))

        # Augmented state s = [z, a = dL/dz, a_theta = dL/dtheta (per sample), a_t = dL/dt],
        # stored flattened in one row so the generic solver can integrate it.
        num_elements = (z1.numel(), dL_dz1.numel(), batch_size * net_params.numel(), dL_dt1.numel())
        sizes = (z1.size(), dL_dz1.size(), (batch_size, net_params.numel()), dL_dt1.size())
        # new_zeros follows net_params' device/dtype instead of the global defaults.
        s0 = flatten(z1, dL_dz1, net_params.new_zeros((batch_size, net_params.numel())), -dL_dt1)

        def aug_dynamics(s, t, theta):
            """Time derivative of the flattened augmented state."""
            parts = unflatten(s, num_elements, sizes)
            z, a = parts[0], parts[1]

            # vjp needs autograd recording even though backward() runs with grad
            # disabled. With v = -a it yields (-a^T df/dz, -a^T df/dt, -a^T df/dtheta).
            # One call per sample, so the cost scales with batch_size.
            with torch.enable_grad():
                gradients = [vjp(f,
                                 (z[i].unsqueeze(0), t, theta),
                                 v=-a[i].unsqueeze(0),
                                 )[1] for i in range(batch_size)]

            return flatten(f(z, t, theta),
                           torch.cat([gradient[0] for gradient in gradients], dim=0),
                           torch.cat([gradient[2].unsqueeze(0) for gradient in gradients], dim=0),
                           torch.cat([gradient[1].reshape(1, 1) for gradient in gradients], dim=0),
                           )

        # Integrate backwards from t1 to t0. This deliberately runs without
        # recording a graph: nothing differentiates through the adjoint solve.
        s = ctx.Integrator(RK4, lambda x, t: aug_dynamics(x, t, net_params), s0, t1, t0, N)

        _, dL_dz0, dL_dtheta, dL_dt0 = unflatten(s, num_elements, sizes)

        # The theta and time adjoints are per sample. Sum them over the batch so
        # each gradient has exactly its input's shape.
        dL_dtheta = dL_dtheta.sum(dim=0)
        dL_dt0 = dL_dt0.sum().reshape_as(t0)
        dL_dt1 = dL_dt1.sum().reshape_as(t1)

        # Must return something for every input to forward; None for the
        # non-tensor inputs (Integrator, f, N).
        return None, None, dL_dz0, dL_dt0, dL_dt1, None, dL_dtheta

In [210]:
class ODENet(nn.Module):
    """Image classifier whose body is a neural ODE trained with the adjoint method.

    Pipeline: two 5x5 convolutions (head), then integration of ``f`` from t0 to
    t1 (body), then a 2x2 max-pool, two fully connected layers and log-softmax
    over 10 classes (tail).

    Args:
        solver: integrator called as solver(b_tableau, g, x0, t0, t1, N),
            e.g. ``explicit_RK``.
        f: dynamics *class*. It is instantiated once here and then called as
            f(x, t, net_params).
        solver_params: dict with keys "t0" and "t1" (integration bounds) and
            "N" (number of solver steps).
        image_size: side length of the square input images. When omitted, the
            notebook-level ``img_size`` global is used (the original behavior).
    """

    def __init__(self, solver, f, solver_params, image_size=None):
        super().__init__()

        self.f = f()

        # Controls the number of parameters (channel width of the ODE body).
        self.body_channels = 32

        # Head: mixes the 3 input channels and expands them to body_channels.
        self.conv1 = nn.Conv2d(3, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, self.body_channels, 5, padding=2)

        # Body
        self.int_f = solver
        # autograd Functions are invoked through the class (Integrate.apply).
        # Instantiating one is deprecated, so store the class itself.
        self.Integrate = Integrate
        self.solver_params = solver_params
        self.N = solver_params["N"]
        # Step size; kept as an attribute but not used by forward().
        self.h = (solver_params["t1"] - solver_params["t0"]) / solver_params["N"]
        self.t0 = torch.tensor(float(solver_params["t0"]), requires_grad=True)
        self.t1 = torch.tensor(float(solver_params["t1"]), requires_grad=True)
        # Flat weights sized for two body_channels -> body_channels 3x3 convolutions.
        n_body_params = self.body_channels * self.body_channels * 9 * 2
        self.net_params = nn.Parameter(torch.Tensor(n_body_params).normal_(mean=0, std=0.1))

        # Tail. NOTE: despite its name, `avg_pool` is a 2x2 *max* pool. The
        # attribute name is kept so existing references to it keep working.
        self.avg_pool = torch.nn.MaxPool2d(2, stride=2, padding=0)
        self.img_size = img_size if image_size is None else image_size
        # MaxPool2d(2, stride=2) floors odd sizes, so use integer halving.
        # img_size**2 / 4 would over-count the features when img_size is odd.
        self.flat_features = self.body_channels * (self.img_size // 2) ** 2
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of 3-channel square images to (batch, 10) log-probabilities."""
        x = self.conv2(F.relu(self.conv1(x)))  # initial dimensionality expansion
        x = self.Integrate.apply(self.int_f, self.f, x, self.t0, self.t1, self.N, self.net_params)  # ODE body (RK4)
        x = self.avg_pool(x)
        # Keep the batch dimension explicit, so a feature-size mismatch raises
        # in fc1 instead of silently regrouping samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [211]:
def train(net, train_loader, test_loader, hyperparameters):
    """Train ``net`` with SGD and evaluate it on ``test_loader`` after every epoch.

    Each batch gets two all-zero copies of its channels appended along dim 1
    ("augmented" channels) before being fed to ``net``. Batches are moved to the
    device of ``net``'s parameters.

    Args:
        net: model mapping an augmented batch to log-probabilities.
        train_loader, test_loader: iterables of (data, label) batches.
        hyperparameters: dict with keys "lr", "n_epochs", "momentum", "weight_decay".

    Returns:
        One [train_losses, test_losses] pair of per-batch loss lists per epoch.
    """
    lr = hyperparameters["lr"]
    n_epochs = hyperparameters["n_epochs"]
    momentum = hyperparameters["momentum"]
    weight_decay = hyperparameters["weight_decay"]
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    # Loader batches start on the CPU; move them to wherever the model lives.
    first_param = next(net.parameters(), None)
    net_device = first_param.device if first_param is not None else None

    losses = []
    for _ in range(n_epochs):

        # Train
        net.train()
        train_losses = []
        for data, label in train_loader:
            if net_device is not None:
                data, label = data.to(net_device), label.to(net_device)
            optimizer.zero_grad()
            # zeros_like keeps the padding on data's device and dtype.
            augmented = torch.cat((data, torch.zeros_like(data), torch.zeros_like(data)), dim=1)
            output = net(augmented)
            loss = F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # If accuracy is poor, inspect the gradients here, e.g.:
        #   for name, param in net.named_parameters():
        #       if param.requires_grad:
        #           print(name, param.data, param.grad)

        num_correct, test_losses = test(net, test_loader)
        losses.append([train_losses, test_losses])

        # Report (NaN instead of ZeroDivisionError for empty loaders/datasets)
        avg_train = sum(train_losses) / len(train_losses) if train_losses else float("nan")
        avg_test = sum(test_losses) / len(test_losses) if test_losses else float("nan")
        n_test = len(test_loader.dataset)
        accuracy = float(num_correct) / n_test * 100 if n_test else float("nan")
        print(
          "Avg Train Loss", avg_train, "\n"
          "Avg Test Loss", avg_test, "\n"
          "Test Accuracy", accuracy, "%"
         )
        print("----------------------------------------")

    return losses

In [212]:
def test(net, test_loader):
    # Test
    net.eval()
    test_losses = []
    num_correct = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(test_loader):
            output = net(torch.cat((data, torch.zeros(data.size()), torch.zeros(data.size())), dim=1))
            
            print(output == output)
            loss = F.nll_loss(output, label)
            test_losses.append(loss.item())
            num_correct += label.eq(torch.max(output, 1, keepdim=False, out=None).indices).sum()

    
    return num_correct, test_losses 

In [216]:
batch_size_train = 64
batch_size_test = 1000
img_size = 28
img_len = img_size * img_size  # 784 pixels per image

# Where MNIST is downloaded/cached. Set the NEURALODE_DATA_DIR environment
# variable to override; defaults to ~/Desktop/neuralODE instead of a hard-coded
# absolute user path.
DATA_DIR = os.environ.get("NEURALODE_DATA_DIR", str(Path.home() / "Desktop" / "neuralODE"))

# Shared by both splits. (0.1307,) / (0.3081,) are the commonly used MNIST mean/std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=True, download=True, transform=mnist_transform),
    # NOTE(review): shuffle=False for training is unusual. It may be deliberate,
    # because RandomSampler can fail when the default tensor type is CUDA. Confirm.
    batch_size=batch_size_train, shuffle=False)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)

In [217]:
# SGD settings read by train().
hyperparameters = dict(
    lr=0.01,
    n_epochs=1,
    momentum=0.5,
    weight_decay=0.0,
)

# ODE settings read by ODENet: integrate from t0 to t1 in N solver steps.
solver_params = dict(t0=0, t1=1, N=2)

In [218]:
# Seed before building the model so weight initialisation is repeatable.
# torch.manual_seed seeds the generators on all devices; it does not make
# nondeterministic GPU kernels deterministic.
SEED = 0
torch.manual_seed(SEED)

TestNetwork = ODENet(explicit_RK, Block, solver_params).to(device)
losses = train(TestNetwork, train_loader, test_loader, hyperparameters)
torch.save(TestNetwork.state_dict(), "test.pth")

KeyboardInterrupt: 