In [1]:
# Widen the notebook cells to 80% of the browser window.
# `display` and `HTML` are taken from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import vjp
import time

# Select the compute device once: the first CUDA GPU when one is visible,
# otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # On GPU machines, make torch.cuda.FloatTensor the default tensor type.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [3]:
class Block(nn.Module):
    """Batch norm followed by two 3x3 convolutions whose kernels come from a flat vector.

    The module owns only the batch-norm layer; the convolution kernels are
    sliced out of the `net_params` tensor passed to `forward`.
    `out_channels` is accepted but not used.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.bnorm = torch.nn.BatchNorm2d(in_channels)

    def forward(self, x, t, net_params):
        """Return conv(relu(conv(bnorm(x)))); `t` is ignored.

        The channel count C is read from `x.size(1)`. `net_params` must hold at
        least 18*C*C values: the first 9*C*C form the first (C, C, 3, 3) kernel,
        the next 9*C*C the second one.
        """
        channels = x.size(1)
        kernel_len = 9 * channels * channels
        kernel_shape = (channels, channels, 3, 3)
        first_kernel = net_params[:kernel_len].view(kernel_shape)
        second_kernel = net_params[kernel_len:2 * kernel_len].view(kernel_shape)

        normed = self.bnorm(x)
        hidden = F.relu(F.conv2d(normed, first_kernel, padding=1))
        return F.conv2d(hidden, second_kernel, padding=1)

    def num_params(self):
        """Number of entries `forward` reads from `net_params` (two C x C x 3 x 3 kernels)."""
        return 2 * 9 * self.in_channels * self.in_channels
    
class BigBlock(nn.Module):
    """Sum of the input and four two-convolution branches, with kernels read from a flat vector.

    The module has no parameters of its own; all eight (C, C, 3, 3) kernels are
    sliced out of the `net_params` tensor given to `forward`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x, t, net_params):
        """Return x + sum over 4 branches of relu(conv(relu(conv(x)))); `t` is ignored.

        Every branch starts from the block input `x` (not from the running sum).
        Branch b uses kernel slots 2*b and 2*b + 1 of `net_params`, each slot
        being 9*C*C consecutive values with C = x.size(1).
        """
        channels = x.size(1)
        kernel_len = 9 * channels * channels

        def kernel(slot):
            # Slot `slot` of the flat vector, viewed as a 3x3 conv kernel.
            start = slot * kernel_len
            return net_params[start:start + kernel_len].view(channels, channels, 3, 3)

        total = x
        for branch in range(4):
            hidden = F.relu(F.conv2d(x, kernel(2 * branch), padding=1))
            total = total + F.relu(F.conv2d(hidden, kernel(2 * branch + 1), padding=1))
        return total

    def num_params(self):
        """Number of entries `forward` reads from `net_params` (eight C x C x 3 x 3 kernels)."""
        return 8 * 9 * self.in_channels * self.in_channels
 

class BigBlockClone(nn.Module):
    """Two stacked residual stages built from ordinary `nn.Conv2d` layers.

    Unlike the functional blocks, the weights live in the module itself;
    the `t` and `net_params` arguments of `forward` are ignored.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Four 3x3 convolutions; padding=1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, t, net_params):
        """Apply stage 1 (conv1, conv2) with a skip from `x`, then stage 2 (conv3, conv4) with a skip from stage 1."""
        hidden = F.relu(self.conv1(x))
        stage1 = x + F.relu(self.conv2(hidden))

        hidden = F.relu(self.conv3(stage1))
        return stage1 + F.relu(self.conv4(hidden))

    def num_params(self):
        """Return in_channels**2 * 9 * 4 (the size of four C x C x 3 x 3 kernels; biases are not counted)."""
        return 4 * 9 * self.in_channels * self.in_channels
    
    
class DepthWise(nn.Module):
    """Two residual stages of depthwise-separable convolutions with kernels read from a flat vector.

    The module has no parameters of its own. With C = x.size(1), `net_params`
    is laid out as:
      * four depthwise kernels of 9*C values each (viewed as (C, 1, 3, 3)),
      * followed by four pointwise kernels of C*C values each (viewed as (C, C, 1, 1)).
    """

    def __init__(self, in_channels, out_channels):
        super(DepthWise, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x, t, net_params):
        """Run the four separable convolutions; `t` is ignored.

        Each separable convolution is a depthwise 3x3 conv followed by a
        pointwise 1x1 conv and a ReLU. Stages 2 and 4 add a skip connection,
        mirroring `DepthwiseClone`.

        Fix: the first pointwise convolution used to be applied to `x` instead
        of to the output of the first depthwise convolution, which silently
        discarded that depthwise kernel.
        """
        channels = x.size(1)
        dw_len = 9 * channels            # length of one depthwise kernel
        pw_len = channels * channels     # length of one pointwise kernel
        pw_base = 4 * dw_len             # pointwise kernels start after the four depthwise ones

        def separable(inp, idx):
            # idx-th depthwise 3x3 kernel, one filter per channel (groups=channels).
            dw_kernel = net_params[idx * dw_len:(idx + 1) * dw_len].view(channels, 1, 3, 3)
            # idx-th pointwise 1x1 kernel mixing the channels.
            pw_start = pw_base + idx * pw_len
            pw_kernel = net_params[pw_start:pw_start + pw_len].view(channels, channels, 1, 1)
            spatial = F.conv2d(inp, dw_kernel, padding=1, groups=channels)
            return F.relu(F.conv2d(spatial, pw_kernel, padding=0))

        x1 = separable(x, 0)
        x2 = x + separable(x1, 1)

        x3 = separable(x2, 2)
        x4 = x2 + separable(x3, 3)
        return x4

    def num_params(self):
        """Number of entries `forward` reads from `net_params` (4 depthwise + 4 pointwise kernels)."""
        return(self.in_channels * 9 * 4 + self.in_channels * self.in_channels * 1 * 4)
        
class DepthwiseClone(nn.Module):
    """Two residual stages of depthwise-separable convolutions built from `nn.Conv2d` layers.

    Layers are named conv{k}a (depthwise 3x3, groups=in_channels, padding=1) and
    conv{k}b (pointwise 1x1) for k = 1..4. The `t` and `net_params` arguments
    of `forward` are ignored. `out_channels` is stored but every layer maps
    in_channels -> in_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Created in the order 1a, 1b, 2a, 2b, ... so that registration order is unchanged.
        for stage in ("1", "2", "3", "4"):
            depthwise = nn.Conv2d(in_channels, in_channels, 3, groups=in_channels, padding=1)
            setattr(self, "conv" + stage + "a", depthwise)
            pointwise = nn.Conv2d(in_channels, in_channels, 1)
            setattr(self, "conv" + stage + "b", pointwise)

    def forward(self, x, t, net_params):
        """Apply separable convs 1-2 with a skip from `x`, then 3-4 with a skip from the first stage."""

        def separable(inp, stage):
            # Depthwise conv, then pointwise conv, then ReLU.
            depthwise = getattr(self, "conv" + stage + "a")
            pointwise = getattr(self, "conv" + stage + "b")
            return F.relu(pointwise(depthwise(inp)))

        first = separable(x, "1")
        stage1 = x + separable(first, "2")

        third = separable(stage1, "3")
        return stage1 + separable(third, "4")

    def num_params(self):
        """Return C*9*4 + C*C*4 with C = in_channels (kernel sizes only; biases are not counted)."""
        channels = self.in_channels
        return(channels * 9 * 4 + channels * channels * 1 * 4)
        

In [4]:
# Butcher tableaux stored row by row as nested tuples, in the layout that
# explicit_RK reads:
#   * every row except the last is (c_i, a_i1, a_i2, ...): c_i scales the step
#     size for the stage time, the a_ij weight the previously computed stages;
#   * the last row holds the weights used to combine the stages into the update.

# Classical fourth-order Runge-Kutta tableau.
RK4 = (
    (0,),
    (1/2, 1/2),
    (1/2, 0, 1/2),
    (1, 0, 0, 1),
    (1/6, 1/3, 1/3, 1/6),
)

# Single-stage tableau: one evaluation of f at the start of the step (Euler step).
EF = (
    (0,),
    (1,),
)

"""
General Runge-Kutta Solver.
https://en.wikipedia.org/wiki/Runge–Kutta_methods

b_tableau, nested tuple, is the Butcher tableau: each row but the last is (c_i, a_i1, a_i2, ...), and the last row contains the weights of integration.
f, function, is the function to iterate. Should only be a function of x, t.
x0, torch.Tensor, is the initial condition.
t0, torch.Tensor, is the start time of integration.
t1, torch.Tensor, is the end time of integration.
N, int, is the desired number of timesteps.

returns x, torch.Tensor, estimated solution of dx/dt = f(x,t) at time t1.
"""

def explicit_RK(b_tableau, f, x0, t0, t1, N):
    """Integrate dx/dt = f(x, t) from t0 to t1 with a fixed-step explicit Runge-Kutta scheme.

    See https://en.wikipedia.org/wiki/Runge–Kutta_methods

    Args:
        b_tableau: nested tuple holding the Butcher tableau. Every row except
            the last is (c_i, a_i1, a_i2, ...); the last row holds the weights
            that combine the stages into the update.
        f: callable f(x, t) returning the time derivative of x.
        x0: initial condition (anything supporting + and * with the step size,
            e.g. a torch.Tensor or a float).
        t0: start time of integration.
        t1: end time of integration; may be smaller than t0, in which case the
            step size is negative and the integration runs backwards in time.
        N: int, number of equally sized steps; must be at least 1.

    Returns:
        The estimated solution at time t1.

    Raises:
        ValueError: if N is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be a positive number of timesteps, got {}".format(N))

    h = (t1 - t0) / float(N)  # step size (negative when integrating backwards)
    x = x0
    for step in range(N):
        t_n = t0 + h * step  # start time of the current step

        # First stage: first row of the tableau.
        k = [f(x, t_n + h * b_tableau[0][0])]
        # Remaining stages: middle rows, each built from the stages computed so far.
        for row in b_tableau[1:-1]:
            increment = sum(w * k_j * h for w, k_j in zip(row[1:], k))
            k.append(f(x + increment, t_n + row[0] * h))
        # Combine the stages with the weights from the last row.
        x = x + sum(w * k_j * h for k_j, w in zip(k, b_tableau[-1]))
    return x

In [5]:
# Convenience tuple -> tensor function
def flatten(*args):
    """Flatten every tensor argument and concatenate them, in order, into one row of shape (1, total)."""
    pieces = [torch.flatten(tensor) for tensor in args]
    row = torch.cat(pieces, dim=0)
    return row.view(1, -1)

# Convenience tensor -> tuple function
def unflatten(x, n_e, sizes):
    """Split the row vector `x` (shape (1, total)) back into tensors of the given sizes.

    Inverse of `flatten`. The original implementation was hard-coded to exactly
    four segments; this version accepts any number and returns the same result
    for four.

    Args:
        x: tensor indexed as x[0, start:stop]; the segments are laid out
            consecutively along the second dimension.
        n_e: sequence with the element count of each segment. The count of the
            last segment is not used: the last segment takes all remaining
            elements.
        sizes: sequence of target shapes, one per segment.

    Returns:
        Tuple with one tensor per entry of `sizes`.
    """
    pieces = []
    start = 0
    last = len(sizes) - 1
    for idx, size in enumerate(sizes):
        if idx == last:
            # The final segment runs to the end of the row.
            chunk = x[0, start:]
        else:
            stop = start + n_e[idx]
            chunk = x[0, start:stop]
            start = stop
        pieces.append(chunk.view(size))
    return tuple(pieces)

class Integrate(torch.autograd.Function):
    """Custom autograd function that integrates `f` forward and obtains gradients by a second integration.

    forward: integrates x' = f(x, t, net_params) from t0 to t1 in N steps by
        calling `Integrator` with the `EF` tableau (the tableau is fixed here,
        whatever integrator is passed in).
    backward: builds an augmented state (z1, dL/dz1, zeros for the parameter
        gradient, -dL/dt1), integrates it from t1 back to t0 with the same
        integrator and tableau, and reads the gradients w.r.t. x0, t0 and
        net_params off the final augmented state.
    """

    def __deepcopy__(self, memo):
        # NOTE(review): this constructs a new Integrate from a deep copy of `memo`;
        # it is unclear whether this path is ever exercised - confirm before relying on it.
        import copy  # local import: the notebook only imports `copy` in a later cell
        return Integrate(copy.deepcopy(memo))

    @staticmethod
    def forward(ctx, Integrator, f, x0, t0, t1, N, net_params):
        """Return the integrator's estimate of the state at t1, starting from x0 at t0."""
        solution = Integrator(EF, lambda x, t: f(x, t, net_params), x0, t0, t1, N)

        # Save for jacobian calculations in backward()
        ctx.save_for_backward(x0,t0,t1,net_params)
        ctx.solution = solution
        ctx.Integrator = Integrator
        ctx.N = N
        ctx.f = f

        return solution

    @staticmethod
    def backward(ctx, dL_dz1):
        """Return gradients for (Integrator, f, x0, t0, t1, N, net_params); None for the non-tensor inputs."""
        # Get all saved context
        z0, t0, t1, net_params = ctx.saved_tensors
        z1 = ctx.solution
        N = ctx.N
        f = ctx.f

        # Convenience sizes
        batch_size = z0.size()[0]

        # Compute derivative w.r.t. to end time of integration:
        # per-sample dot product of dL/dz1 with f evaluated at (z1, t1).
        dL_dt1 = dL_dz1.view(batch_size,1,-1).bmm(f(z1, t1, net_params).view(batch_size,-1,1))  # Derivative of loss w.r.t t1

        # Initial Condition
        num_elements = (z1.numel(), dL_dz1.numel(), batch_size * net_params.numel(), dL_dt1.numel())
        sizes = (z1.size(), dL_dz1.size(), (batch_size, net_params.numel()), dL_dt1.size())
        # One zero row per sample for the parameter gradient; created with the
        # dtype and device of net_params instead of a hard-coded float32 on the
        # default device.
        dL_dtheta0 = net_params.new_zeros((batch_size, net_params.numel()))
        s0 = flatten(z1, dL_dz1, dL_dtheta0, -dL_dt1) # initial augmented state

        # augmented dynamics function
        def aug_dynamics(s, t, theta):
            s = unflatten(s, num_elements, sizes)

            # One vector-Jacobian product per sample; each result is a tuple of
            # gradients w.r.t. (state, t, theta), taken with v = -(current dL/dz of that sample).
            with torch.enable_grad():
                gradients = [vjp(f,
                                 (s[0][i].unsqueeze(0), t, theta),
                                  v=-s[1][i].unsqueeze(0),
                                 )[1] for i in range(batch_size)]

            # Time derivative of the augmented state, in the same layout as s0.
            return flatten(f(s[0],t,theta),
                    torch.cat([gradient[0] for gradient in gradients], dim=0),
                    torch.cat([gradient[2].unsqueeze(0) for gradient in gradients], dim=0),
                    torch.cat([gradient[1].reshape(1,1) for gradient in gradients], dim=0),
                   )

        # Integrate backwards
        with torch.enable_grad():
            s = ctx.Integrator(EF, lambda x, t: aug_dynamics(x, t, net_params), s0, t1, t0, N)

        # Extract derivatives
        _, dL_dz0, dL_dtheta, dL_dt0 = unflatten(s, num_elements, sizes)

        # must return something for every input to forward, None for non-tensors
        # NOTE(review): dL_dtheta, dL_dt0 and dL_dt1 carry a leading batch
        # dimension that the corresponding forward inputs do not have - confirm
        # that autograd reduces them to the input shapes as intended.
        return None, None, dL_dz0, dL_dt0, dL_dt1, None, dL_dtheta

In [14]:
class ODENet(nn.Module):
    """Image classifier: convolutional head -> ODE block integrated by `Integrate` -> two linear layers.

    Args:
        solver: integrator callable handed to `Integrate.apply` (the notebook
            passes `explicit_RK`).
        f: class of the dynamics module. It is instantiated as
            `f(body_channels, body_channels)` and must provide `num_params()`.
        in_channels: number of channels of the input images.
        body_channels: number of channels inside the ODE block.
        solver_params: dict with the keys "t0", "t1" and "N".
        img_size: side length of the square input images. When omitted, the
            notebook-level global `img_size` is used, which is what this class
            previously read implicitly.
    """

    def __init__(self, solver, f, in_channels, body_channels, solver_params, img_size=None):
        super(ODENet, self).__init__()

        if img_size is None:
            # Backward-compatible fallback to the notebook global.
            try:
                img_size = globals()["img_size"]
            except KeyError:
                raise NameError("name 'img_size' is not defined; define it or pass img_size=... to ODENet") from None
        self.img_size = img_size

        # Controls amount of parameters
        self.body_channels = body_channels
        # Number of features entering the first linear layer.
        self.flat_features = int(img_size * img_size * self.body_channels)

        # Head, used to mix stuff around and improve accuracy
        mid_channels = body_channels // 2 if body_channels // 2 > in_channels else body_channels
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 5, padding=2)
        self.conv2 = nn.Conv2d(mid_channels, self.body_channels, 5, padding=2)

        self.f = f(body_channels, body_channels)

        # Body
        self.int_f = solver
        self.Integrate = Integrate
        self.solver_params = solver_params
        self.N = solver_params["N"]
        self.h = (solver_params["t1"] - solver_params["t0"]) / solver_params["N"]
        # NOTE(review): t0 and t1 are plain tensor attributes, not registered
        # parameters or buffers, so they are not part of state_dict and are not
        # moved by Module.to().
        self.t0 = torch.tensor(float(solver_params["t0"]), requires_grad=True)
        self.t1 = torch.tensor(float(solver_params["t1"]), requires_grad=True)

        # Flat parameter vector consumed by the dynamics module, drawn from N(0, 0.1^2).
        self.net_params = torch.nn.parameter.Parameter(
            torch.empty(self.f.num_params()).normal_(mean=0, std=0.1), requires_grad=True)

        # Tails
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes for a batch of images."""
        x = self.conv2(F.relu(self.conv1(x))) #initial dimensionality expansion
        # Integrate the dynamics from t0 to t1. The stepping scheme is whatever
        # Integrate.forward selects (it currently hard-codes the EF tableau).
        x = self.Integrate.apply(self.int_f, self.f, x, self.t0, self.t1, self.N, self.net_params)
        x = x.view(-1, self.flat_features)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [15]:
def train(net, train_loader, test_loader, hyperparameters):
    """Train `net` with SGD and evaluate it on `test_loader` after every epoch.

    Args:
        net: model returning log-probabilities (the loss is `F.nll_loss`).
        train_loader: iterable of (data, label) batches; must expose `.dataset`
            for the accuracy report.
        test_loader: same, used for the per-epoch evaluation via `test`.
        hyperparameters: dict with the keys "lr", "n_epochs", "momentum" and
            "weight_decay".

    Returns:
        A list with one `[train_losses, test_losses]` pair per epoch, each a
        list of per-batch loss values.
    """
    lr = hyperparameters["lr"]
    n_epochs = hyperparameters["n_epochs"]
    momentum = hyperparameters["momentum"]
    weight_decay = hyperparameters["weight_decay"]
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    # Batches are moved to wherever the model lives, so that a model placed on
    # the GPU is not fed CPU tensors. (SGD above already rejects a model
    # without parameters, so next() cannot fail here.)
    model_device = next(net.parameters()).device

    losses = []
    for epoch in range(n_epochs):

        # Train (test() below switches the model to eval mode, so re-enable every epoch)
        net.train()
        train_losses = []
        train_num_correct = 0
        for data, label in train_loader:
            data = data.to(model_device)
            label = label.to(model_device)

            optimizer.zero_grad()
            output = net(data)
            loss = F.nll_loss(output, label)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            train_num_correct += label.eq(output.argmax(dim=1)).sum()

        num_correct, test_losses = test(net, test_loader)
        losses.append([train_losses, test_losses])

        # Report
        print(
          "Avg Train Loss", sum(train_losses)/len(train_losses), "\n",
        "Train Accuracy", (train_num_correct / float(len(train_loader.dataset)) * 100).item(), "%\n",
          "Avg Test Loss", sum(test_losses)/len(test_losses), "\n",
          "Test Accuracy", (num_correct / float(len(test_loader.dataset)) * 100).item(), "%"
         )
        print("----------------------------------------")

    return losses

In [16]:
def test(net, test_loader):
    # Test
    net.eval()
    test_losses = []
    num_correct = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(test_loader):
            #data = torch.cat((data, torch.zeros(data.size()), torch.zeros(data.size())), dim=1)
            output = net(data)
            loss = F.nll_loss(output, label)
            test_losses.append(loss.item())
            num_correct += label.eq(torch.max(output, 1, keepdim=False, out=None).indices).sum()

    
    return num_correct, test_losses 

Here the dataset is loaded. I am fairly confident that this notebook will work for any image dataset. Make sure to adjust the batch sizes and the image size to avoid running out of memory. On my computer, this notebook uses 8 GB of RAM plus 20 GB of swap, which is huge, but I have experienced vanishing gradients in training runs with a batch size of 1. I will do a hyperparameter search, and am working on that script.

In [17]:
batch_size_train = 64
batch_size_test = 1000
img_size = 28                   # side length of the square input images
img_len = img_size * img_size   # pixels per image

# Dataset location, relative to the working directory so the cell runs on any
# machine (previously an absolute, user-specific path).
data_root = './data'

# Identical preprocessing for both splits: convert to a tensor, then normalise
# with mean 0.1307 and standard deviation 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training batches are reshuffled every epoch and the evaluation order is
# fixed, matching the CIFAR-10 variant kept below.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=False)

# Alternative CIFAR-10 configuration, kept disabled inside a string literal.
# NOTE(review): as the last expression of the cell, this string is echoed as the cell output.
"""
batch_size_train = 64
batch_size_test = 64
img_size = 32
img_len = 32 * 32

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=torchvision.transforms.Compose([
                                            torchvision.transforms.Pad(4),
                                            torchvision.transforms.RandomCrop(32),
                                            torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                            torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

test_loader = torch.utils.data.DataLoader(testset, batch_size= batch_size_test, shuffle=False, num_workers=2)"""

"\nbatch_size_train = 64\nbatch_size_test = 64\nimg_size = 32\nimg_len = 32 * 32\n\ntrainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n                                        download=True, transform=torchvision.transforms.Compose([\n                                            torchvision.transforms.Pad(4),\n                                            torchvision.transforms.RandomCrop(32),\n                                            torchvision.transforms.RandomHorizontalFlip(p=0.5),\n                                            torchvision.transforms.ToTensor(),\n                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))\n\ntrain_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True, num_workers=2)\n\ntestset = torchvision.datasets.CIFAR10(root='./data', train=False,\n                                       download=True, transform=torchvision.transforms.Compose([\n               

This is where the training hyperparameters and the solver parameters are set.

In [24]:
# Optimiser settings read by train(): learning rate, number of epochs,
# SGD momentum and weight decay.
hyperparameters = dict(
    lr=0.01,
    n_epochs=30,
    momentum=0.5,
    weight_decay=0.0,
)

# Integration window [t0, t1] and number of fixed steps N, read by ODENet.
solver_params = dict(t0=0, t1=1, N=5)

In [25]:
# Build the ODE network with BigBlock dynamics (1 input channel, 64 body channels)
# and move it to the selected device.
TestNetwork = ODENet(explicit_RK, BigBlock, 1, 64, solver_params).to(device)
# Train it; `losses` receives one [train_losses, test_losses] pair per epoch.
# NOTE(review): the recorded output of this cell shows NaN train/test losses and
# roughly 10% accuracy before the run was interrupted.
losses = train(TestNetwork, train_loader, test_loader, hyperparameters)
# Persist the model's state_dict.
torch.save(TestNetwork.state_dict(), "testgensolver2.pth")

Avg Train Loss nan 
 Train Accuracy 9.873333930969238 %
 Avg Test Loss nan 
 Test Accuracy 9.800000190734863 %
----------------------------------------


KeyboardInterrupt: 

In [None]:
def count_parameters(model):
    """Return the total number of elements across the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build a second, small network (Block dynamics, 1 input channel, 3 body channels)
# and report how many trainable parameters it has.
TestNetwork2 = ODENet(explicit_RK, Block, 1, 3, solver_params).to(device)
#TestNetwork.load_state_dict(torch.load("testgensolver.pth"))
print("ODENet", count_parameters(TestNetwork2), "parameters")

Here I convert the Block module from above into one that uses convolutions from the torch.nn package. Since most external packages are made to work with torch.nn, this is convenient for working with MemTorch or torch.nn.quantized. After you make this switch, the backward-pass code will no longer work, so the converted model can only be used for inference.

In [None]:
"""
Conversion to something MemTorch & Pytorch's quantization can handle. 
"""

class Block_nn(nn.Module):
    """Two 3x3 `nn.Conv2d` layers, each followed by a ReLU, owning their own weights.

    The layer sizes come from the constructor arguments. Previously they were
    read from the global `TestNetwork`, which made the class unusable without
    that global and ignored the requested channel counts.

    Args:
        in_channels: channels of the input to the first convolution.
        body_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, body_channels):
        super(Block_nn, self).__init__()
        self.in_channels = in_channels
        self.conv1 = nn.Conv2d(in_channels, body_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(body_channels, body_channels, 3, padding=1)

    def forward(self, x, t, net_params):
        """Return relu(conv2(relu(conv1(x)))); `t` and `net_params` are ignored."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

    def num_params(self):
        """Return in_channels**2 * 9 * 2 (the size of two C x C x 3 x 3 kernels; biases are not counted)."""
        return self.in_channels * self.in_channels * 9 * 2

# Build an ODENet whose dynamics module is the nn.Conv2d-based Block_nn (3 body channels).
qnet = ODENet(explicit_RK, Block_nn, 1, 3, solver_params)
sz = qnet.body_channels
# Replace the two convolution kernels with consecutive slices of TestNetwork's
# flat parameter vector, viewed as (sz, sz, 3, 3) kernels.
# NOTE(review): TestNetwork is built earlier with BigBlock and 64 body channels,
# whereas sz is 3 here - confirm that these slices are the intended weights
# (TestNetwork2, built with Block and 3 channels, may have been meant).
qnet.f.conv1.weight = nn.Parameter(TestNetwork.net_params[0:9*sz*sz].view(sz,sz,3,3))
qnet.f.conv2.weight = nn.Parameter(TestNetwork.net_params[9*sz*sz:18*sz*sz].view(sz,sz,3,3))

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities
import memtorch
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Memristor model class and the keyword arguments used to instantiate it.
reference_memristor = memtorch.bh.memristor.VTEAM
reference_memristor_params = {'time_series_resolution': 1e-10, 
                              'r_off': memtorch.bh.StochasticParameter(200, std=20, min=2),
                              'r_on': memtorch.bh.StochasticParameter(100, std=10, min=1)}
# Instantiate one device with these parameters and plot its hysteresis loop.
memristor = reference_memristor(**reference_memristor_params)
memristor.plot_hysteresis_loop()

# Patch a deep copy of the trained network (TestNetwork itself is left untouched),
# asking MemTorch to patch its Linear and Conv2d modules.
patched_model = patch_model(copy.deepcopy(TestNetwork),
                          memristor_model=reference_memristor,
                          memristor_model_params=reference_memristor_params,
                          module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                          mapping_routine=naive_map,
                          transistor=True,
                          programming_routine=None)

# NOTE(review): presumably calibrates the patched layers - confirm against the MemTorch documentation.
patched_model.tune_()

In [None]:
# Sweep the `conductance_states` setting from 250 down to 30 in steps of 20.
# For each value, apply the FiniteConductanceStates non-ideality to a fresh deep
# copy of the patched model and record the fraction of correctly classified test samples.
accuracy = []
for i in range(250, 10, -20):
    model = apply_nonidealities(copy.deepcopy(patched_model),
                                      non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],
                                      conductance_states = i)
    # test() returns (num_correct, losses); only the count is used here.
    accuracy.append((test(model, test_loader)[0] / float(len(test_loader.dataset))).item())

In [None]:
# Plot the recorded accuracy against the number of conductance states;
# the x-values repeat the range used by the sweep that fills `accuracy`.
plt.plot(list(range(250, 10, -20)), accuracy)

In [None]:
# Single evaluation with conductance_states = 1024: apply the FiniteConductanceStates
# non-ideality to a deep copy of the patched model and print the fraction of
# correctly classified test samples.
model = apply_nonidealities(copy.deepcopy(patched_model),
                                      non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],
                                      conductance_states = 1024)
print((test(model, test_loader)[0] / float(len(test_loader.dataset))).item())