In [1]:
# Make boxes window width. 
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import time
from torch.autograd.functional import vjp

In [201]:
class Block(nn.Module):
    def __init__(self):
        super(Block, self).__init__()
    def forward(self, x, t, net_params):
        size = x.size()
        x = F.relu(F.conv2d(x.view(-1,1,28,28), net_params[0:9].view(1,1,3,3), padding=1))
        x = F.relu(F.conv2d(x, net_params[9:18].view(1,1,3,3), padding=1))
        return x.view(size)

In [251]:
def flatten(*args):
    print("flattening")
    return(torch.cat(tuple(torch.flatten(arg) for arg in args), dim=0).view(1,-1))

def unflatten(x, n_e, sizes):
    print("unflattening")
    return (x[0, 0:n_e[0]].view(sizes[0]),
            x[0, n_e[0]:n_e[0] + n_e[1]].view(sizes[1]),
            x[0, n_e[0] + n_e[1]:n_e[0] + n_e[1] + n_e[2]].view(sizes[2]),
            x[0, n_e[0] + n_e[1] + n_e[2]:].view(sizes[3]),
            )

class Integrate(torch.autograd.Function):
    def __deepcopy__(self, memo):
        return Integrate(copy.deepcopy(memo))
    
    @staticmethod
    def forward(ctx, Integrator, f, x0, t0, t1, N, net_params):
        
        # Forward, Runge-Kutta 4th Order. 
        
        # Forward integration
        with torch.enable_grad(): #necessary for gradient calculations
            solution = Integrator(f, x0, t0, t1, N, net_params)
            
        # Save for jacobian calculations in backward()
        ctx.save_for_backward(x0,t0,t1)
        ctx.net_params = net_params
        ctx.solution = solution
        ctx.Integrator = Integrator
        ctx.N = N
        ctx.f = f
        
        return solution
    
    @staticmethod
    def backward(ctx, dL_dz1):
        print("backward")
        # Get all saved context
        z0, t0, t1 = ctx.saved_tensors
        net_params = ctx.net_params
        z1 = ctx.solution
        
        N = ctx.N
        f = ctx.f
        
        # Convenience sizes
        batch_size = z0.size()[0]
        print(z0.size())
        img_len = int(z0.numel() / batch_size)

        # Compute derivative w.r.t. to end time of integration
        dL_dt1 = dL_dz1.view(batch_size,1,-1).bmm(f(z1, t1, net_params).view(batch_size,-1,1))  # Derivative of loss w.r.t t1
        
        # Initial Condition
        num_elements = (z1.numel(), dL_dz1.numel(), batch_size * net_params.numel(), dL_dt1.numel())
        sizes = (z1.size(), dL_dz1.size(), (batch_size, net_params.numel()), dL_dt1.size())
        
        s0 = flatten(z1, dL_dz1, torch.zeros((batch_size, net_params.numel()), dtype=torch.float32), -dL_dt1) # initial augmented state
        
        print("s0", s0.size())
        
        # augmented dynamics function
        # what I really want is a Tensorflow Ragged Tensor, and pytorch's implementation really isn't there yet
        def aug_dynamics(s, t, theta):
            print("aug_dynamics")
            s = unflatten(s, num_elements, sizes)
            with torch.enable_grad(): 
                gradients = [vjp(f, 
                                 (s[0][i].unsqueeze(0), t, theta), 
                                  v=-s[1][i].unsqueeze(0)
                                 )[1] for i in range(batch_size)]
            y = flatten(f(s[0],t,theta),
                    torch.cat([gradient[0] for gradient in gradients], dim=0), 
                    torch.cat([gradient[2].reshape(1,18) for gradient in gradients], dim=0), 
                    torch.cat([gradient[1].reshape(1,1) for gradient in gradients], dim=0),
                   ).unsqueeze(2)
            print("finished aug_dynamics")   
            return y
        
        print("integrating backwards dynamics")
        # Integrate backwards
        print(s0.size())
        with torch.no_grad(): back_dynamics = ctx.Integrator(aug_dynamics, s0, t1, t0, N, net_params)
        
        # Extract derivatives
        _, dL_dz0, dL_dtheta, dL_dt0 = back_dynamics[-1]

        # must return something for every input to forward, None for non-tensors
        return None, None, None, dL_dz0, dL_dt0, dL_dt1, None, dL_dtheta 

In [252]:
class ODENet(nn.Module):
    def __init__(self, solver, f, solver_params):
        super(ODENet, self).__init__()
        
        self.f = f()
        
        self.int_f = solver
        self.Integrate = Integrate()
        
        self.solver_params = solver_params
        self.N = solver_params["N"]
        self.h = (solver_params["t1"] - solver_params["t0"]) / solver_params["N"]
        self.t0 = torch.tensor(float(solver_params["t0"]), requires_grad=True)
        self.t1 = torch.tensor(float(solver_params["t1"]), requires_grad=True)
        self.net_params = torch.nn.parameter.Parameter(torch.Tensor(18).normal_(mean=0, std=0.1,generator=None), requires_grad=True)

        self.avg_pool = torch.nn.MaxPool2d(2, stride=2, padding=0)
        self.fc1 = nn.Linear(196, 50)
        self.fc2 = nn.Linear(50, 10)
        
    def forward(self, x):
        x = self.Integrate.apply(self.int_f, self.f, to_v(x), self.t0, self.t1, self.N, self.net_params)
        x = self.avg_pool(x)
        x = x.view(-1, 196) 
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [253]:
def train(net, train_loader, test_loader, hyperparameters):
    
    lr = hyperparameters["lr"]
    n_epochs = hyperparameters["n_epochs"]
    momentum = hyperparameters["momentum"]
    
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    losses = []
    for i in range(n_epochs):
        
        # Train
        net.train()
        train_losses = []
        for j, (data, label) in enumerate(train_loader):
            
            optimizer.zero_grad()
            data.requires_grad = True
            output = net(data)
            print("output", output.size())
            loss = F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        
        """ Check ur gradients if the accuracy sucks
        for name, param in net.named_parameters():
            if param.requires_grad:
                print(name, param.data, param.grad)
        """
        
        num_correct, test_losses = test(net, test_loader)
        losses.append([train_losses, test_losses])
        
        # Report
        print(
          "Avg Train Loss", sum(train_losses)/len(train_losses), "\n"
          "Avg Test Loss", sum(test_losses)/len(test_losses), "\n"
          "Test Accuracy", (num_correct / float(len(test_loader.dataset)) * 100).item(), "%"
         )
        print("----------------------------------------")
        
    return losses    

In [254]:
def test(net, test_loader):
    # Test
    net.eval()
    test_losses = []
    num_correct = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(test_loader):

            output = net(data)
            loss = F.nll_loss(output, label)
            test_losses.append(loss.item())
            num_correct += label.eq(torch.max(output, 1, keepdim=False, out=None).indices).sum()

    
    return num_correct, test_losses 

In [255]:
batch_size_train = 1
batch_size_test = 1000
img_size = 28
img_len = 784


torch.manual_seed(0)


train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/Users/louis/Desktop/neuralODE/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/Users/louis/Desktop/neuralODE/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [256]:
#from torch.autograd.functional import jacobian
from torch import inverse as inv
from torch import bmm
from torch.autograd import grad

def jacobian(f, x):
    """Computes the Jacobian of f w.r.t x.

    This is according to the reverse mode autodiff rule,

    sum_i v^b_i dy^b_i / dx^b_j = sum_i x^b_j R_ji v^b_i,

    where:
    - b is the batch index from 0 to B - 1
    - i, j are the vector indices from 0 to N-1
    - v^b_i is a "test vector", which is set to 1 column-wise to obtain the correct
        column vectors out ot the above expression.

    :param f: function R^N -> R^N
    :param x: torch.tensor of shape [B, N]
    :return: Jacobian matrix (torch.tensor) of shape [B, N, N]
    """

    B, N = x.shape
    y = f(x)
    jacobian = list()
    for i in range(N):
        v = torch.zeros_like(y)
        v[:, i] = 1.
        dy_i_dx = grad(y,
                       x,
                       grad_outputs=v,
                       retain_graph=True,
                       create_graph=True,
                       allow_unused=True)[0]  # shape [B, N]
        jacobian.append(dy_i_dx)

    jacobian = torch.stack(jacobian, dim=2).requires_grad_()

    return jacobian


"""
Batched gradient wrapper for newton().
    Input f is a function
    x should be a 2-D tensor, (batch, -1), use lambda in f to reshape if necessary
    output of f() should be same dimension as x. 
"""  



def b_grad(f, x, x0):
    print("b_grad")
    j_list = []
    for j in range(x.size()[0]):
        with torch.enable_grad():
            print(j)
            j_list += [torch.autograd.functional.jacobian(lambda x: f(x, x0[j:j+1]), x[j:j+1]).view(1, x.size()[1], x.size()[1])]
    return(torch.cat(j_list, dim=0)) 
    

"""
def b_grad(f, x):
    return jacobian(lambda x: f(x), x.view(x.size()[0], -1))
    
"""

"""
Newton's method for functions f: batch x R^n -> batch x R^n
    Input f is a function
    x0 is the initial guess, 2-D tensor, (batch, - 1). 
    output of f() should be same dimension as x. 
"""
def newton(f, x0):
    print("newton")
    x = x0
    tolerance = torch.Tensor((1e-5,))
    while(True):
        gradients = b_grad(f, x, x0)
        print(gradients.size())
        sx, x = x, x - bmm(inv(gradients), f(x, x0))
        if(torch.norm(x - sx, p=2) < tolerance): break
    
    return x

In [257]:
def to_v(x):
    return x.view(x.size()[0], -1, 1)
def to_m(x):
    return x.view(x.size()[0], 1, int(x.size()[1] ** 0.5), int(x.size()[1] ** 0.5))

def euler_backwards(f, x0, t0, t1, N, net_params):
    print("euler_backwards")
    h = (t1 - t0) / float(N) # calculate step size
    x, t = x0, t0
    for i in range(N):
        x = newton(lambda z, x0: z - f(z, t, net_params) - x0.view(z.size()), x)
        t = t + h
    return to_m(x)

In [258]:
hyperparameters = {
    "lr":  0.01,
    "n_epochs": 1,
    "momentum": 0.5,
}

solver_params = {
    "t0": 0,
    "t1": 3,
    "N": 2,
}

In [259]:
TestNetwork = ODENet(euler_backwards,Block, solver_params)
losses = train(TestNetwork, train_loader, test_loader, hyperparameters)
torch.save(TestNetwork.state_dict(), "test.pth")

euler_backwards
newton
b_grad
0
torch.Size([1, 784, 784])
b_grad
0
torch.Size([1, 784, 784])
b_grad
0
torch.Size([1, 784, 784])
b_grad
0
torch.Size([1, 784, 784])
newton
b_grad
0
torch.Size([1, 784, 784])
b_grad
0
torch.Size([1, 784, 784])
b_grad
0
torch.Size([1, 784, 784])
output torch.Size([1, 10])
backward
torch.Size([1, 784, 1])
flattening
s0 torch.Size([1, 1587])
integrating backwards dynamics
torch.Size([1, 1587])
euler_backwards
newton
b_grad
0
aug_dynamics
unflattening
flattening
finished aug_dynamics


RuntimeError: 

In [250]:
def square(x):
    return x**2

def derivative(f, x):
    return jacobian(f, x)

def dxdt(f, x):
    return torch.autograd.functional.jacobian(f, x)

def dzdt(f, z):
    return jacobian(f, z)

y = torch.rand((1,200), requires_grad=True)

start = time.time()
for i in range(0, 1):
    dzdt(lambda x: derivative(square, x), y).view(1,-1)
print(time.time() - start)

start = time.time()
for i in range(0, 1):
    dxdt(lambda x: derivative(square, x), y).view(1,-1)
print(time.time() - start)

1.1684999465942383


KeyboardInterrupt: 