In [1]:
# Render matplotlib figures inline in the notebook output area.
%matplotlib inline

In [36]:
import matplotlib.pylab as plt
import seaborn as sn
import pandas as pd
import os 
from os.path import join
from tqdm.notebook import tqdm
import pickle
sn.set_context("poster")
import itertools
from csv import DictWriter
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

# Legacy typed-tensor classes used for casting later in the notebook: the CUDA
# variants are selected when a GPU is available, otherwise the CPU ones.
# ttype is the double-precision (float64) type used for model inputs.
ttype = torch.cuda.DoubleTensor if torch.cuda.is_available() else torch.DoubleTensor
# ctype is the int64 type the training loop casts the targets to.
ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
from deepsith import DeepSITH


In [12]:
# Data
batch_size = 8
print('==> Preparing data..')

# Per-channel (mean, std) constants handed to Normalize for both splits.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Evaluation pipeline: tensor conversion + normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Training pipeline: random crop / horizontal flip augmentation first, then
# the same tensor conversion + normalisation as the evaluation pipeline.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Arguments shared by both dataset splits and both loaders.
_dataset_kwargs = dict(root='./data', download=True)
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)

trainset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **_dataset_kwargs)
trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, **_loader_kwargs)

testset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **_dataset_kwargs)
testloader = torch.utils.data.DataLoader(
    testset, shuffle=False, **_loader_kwargs)

# Human-readable label names; len(classes) sets the classifier output size.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [34]:
def train(model, ttype, train_loader, optimizer, loss_func, epoch, perf_file,
          loss_buffer_size=200, batch_size=4, device='cuda', reg_loss=None,
          prog_bar=None):
    """Run one pass over ``train_loader`` and log running statistics.

    Each batch is moved to ``device``, reshaped to
    ``(data.shape[0], 1, data.shape[1], -1)`` and cast with ``ttype`` before
    being fed to ``model``. The loss and the predictions are taken from the
    last index along dim 1 of the model output (``out[:, -1, :]``).

    Running loss and accuracy are averaged over a sliding window of the most
    recent ``loss_buffer_size`` samples (``loss_buffer_size // batch_size``
    batches). Every time another ``loss_buffer_size`` samples have been
    processed, one header-less row ``avg_loss, epoch, batch_idx, train_perf``
    is appended to ``perf_file``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise; put into training mode here.
    ttype : tensor type
        Type passed to ``Tensor.type`` to cast the input batch.
    train_loader : iterable of (data, target)
        Source of training batches.
    optimizer : torch.optim.Optimizer
        Optimiser stepped once per batch.
    loss_func : callable
        Called as ``loss_func(out[:, -1, :], target)``.
    epoch : int
        Epoch number, used only for the progress text and the log row.
    perf_file : str or path-like
        CSV file that log rows are appended to. Its parent directory is
        created if it does not exist.
    loss_buffer_size : int
        Window size in samples; must be a multiple of ``batch_size``.
    batch_size : int
        Nominal number of samples per batch (used for the averages).
    device : str or torch.device
        Device the data and targets are moved to.
    reg_loss : object, optional
        Accepted for interface compatibility; currently unused.
    prog_bar : object, optional
        If given, its ``set_description`` is called after every batch.

    Raises
    ------
    AssertionError
        If ``loss_buffer_size`` is not a multiple of ``batch_size``.
    """
    assert loss_buffer_size % batch_size == 0, \
        "loss_buffer_size must be a multiple of batch_size"
    # Number of batches that together hold `loss_buffer_size` samples. Both
    # running lists are trimmed to this length so that the reported loss and
    # accuracy describe the same set of recent batches.
    window = loss_buffer_size // batch_size
    losses = []
    perfs = []

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device).view(data.shape[0], 1, data.shape[1], -1).type(ttype)
        # Cast on the requested device instead of relying on a global tensor
        # type, so the targets always end up where `device` says.
        target = target.to(device).long()
        optimizer.zero_grad()
        out = model(data)
        logits = out[:, -1, :]
        loss = loss_func(logits, target)
        loss.backward()
        optimizer.step()

        # Number of correct predictions in this batch.
        perfs.append((torch.argmax(logits, dim=-1) == target).sum().item())
        perfs = perfs[-window:]
        losses.append(loss.item())
        losses = losses[-window:]

        avg_loss = np.mean(losses)
        train_perf = np.sum(perfs) / (len(perfs) * batch_size)

        if prog_bar is not None:
            # Update progress_bar
            s = "{}:{} Loss: {:.6f}, perf: {:.6f}"
            prog_bar.set_description(
                s.format(epoch, batch_idx, avg_loss, train_perf))

        # Log once per full window of samples. The previous `== 1` test could
        # never be true for an even batch size, so nothing was ever written.
        if ((batch_idx + 1) * batch_size) % loss_buffer_size == 0:
            loss_track = {
                'avg_loss': avg_loss,
                'epoch': epoch,
                'batch_idx': batch_idx,
                'train_perf': train_perf,
            }
            log_dir = os.path.dirname(perf_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            # newline='' is what the csv module expects for files it writes.
            with open(perf_file, 'a+', newline='') as fp:
                csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))
                csv_writer.writerow(loss_track)
                fp.flush()

In [30]:
           
def test(model, device, test_loader, scale, stop_early=None):
    model.eval()
    correct = 0
    count = 0
    with torch.no_grad():
        for data, target in test_loader:
            f = scipy.interpolate.interp1d(np.arange(0, data.view(-1).shape[-1], 1),
                                           data.view(-1).detach().cpu())
            new_y = f(np.arange(0, data.view(-1).shape[-1]-1, 1/scale))
            new_y = ttype(new_y).view(1, 1, 1, -1)
            target = target.to(device)
            
            out = model(new_y)
            pred = out[:, -1].argmax(dim=-1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += 1
            if count %10000 ==0:
                print(scale, correct/count)
            if stop_early:
                if count > stop_early:
                    break
    return correct / count

In [31]:
class DeepSITH_Classifier(nn.Module):
    """A DeepSITH stack followed by a linear read-out layer.

    Parameters
    ----------
    out_features : int
        Size of the read-out layer's output.
    layer_params : list of dict
        Per-layer settings forwarded to ``DeepSITH``. The ``'hidden_size'``
        of the last entry sets the read-out layer's input size.
    dropout : float
        Dropout value forwarded to ``DeepSITH``.
    """

    def __init__(self, out_features, layer_params, dropout=.5):
        super().__init__()
        final_width = layer_params[-1]['hidden_size']
        self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)
        self.to_out = nn.Linear(final_width, out_features)

    def forward(self, inp):
        """Pass ``inp`` through the DeepSITH stack, then the linear read-out."""
        return self.to_out(self.hs(inp))

In [42]:
# Per-layer settings handed to DeepSITH. Each layer's "in_features" is the
# previous layer's "hidden_size" (3 for the first layer), and "tau_max"
# doubles from one layer to the next.
sith_params1 = {"in_features":3, 
                "tau_min":1, "tau_max":256.0, 
                "k":75, 'dt':1,
                "ntau":20, 'g':0.,  
                "ttype":ttype, 
                "hidden_size":10, "act_func":nn.ReLU()
               }
sith_params2 = {"in_features":sith_params1['hidden_size'], 
                "tau_min":1, "tau_max":512.0, 
                "k":75, 'dt':1,
                "ntau":20, 'g':0., 
                "ttype":ttype, 
                "hidden_size":20, "act_func":nn.ReLU()
                }
sith_params3 = {"in_features":sith_params2['hidden_size'], 
                "tau_min":1, "tau_max":1024.0, 
                "k":75, 'dt':1,
                "ntau":20, 'g':0., 
                "ttype":ttype, 
                "hidden_size":30, "act_func":nn.ReLU()
                }
layer_params = [sith_params1, sith_params2, sith_params3]

# Use the GPU only when one is available -- the same rule that selects ttype --
# so this cell does not crash on a CPU-only machine.
model_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DeepSITH_Classifier(out_features=len(classes),
                            layer_params=layer_params, 
                            dropout=.0).to(model_device).double()
print(model)
for i, l in enumerate(model.hs.layers):
    print("Layer {}".format(i), l.sith.tau_star)

# Total number of scalar entries across all model parameters.
tot_weights = sum(p.numel() for p in model.parameters())
print("Total Weights:", tot_weights)

loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

DeepSITH_Classifier(
  (hs): DeepSITH(
    (layers): ModuleList(
      (0): _DeepSITH_core(
        (sith): iSITH(ntau=20, tau_min=1, tau_max=256.0, buff_max=768.0, dt=1, k=75, g=0.0)
        (linear): Sequential(
          (0): Linear(in_features=60, out_features=10, bias=True)
          (1): ReLU()
        )
      )
      (1): _DeepSITH_core(
        (sith): iSITH(ntau=20, tau_min=1, tau_max=512.0, buff_max=1536.0, dt=1, k=75, g=0.0)
        (linear): Sequential(
          (0): Linear(in_features=200, out_features=20, bias=True)
          (1): ReLU()
        )
      )
      (2): _DeepSITH_core(
        (sith): iSITH(ntau=20, tau_min=1, tau_max=1024.0, buff_max=3072.0, dt=1, k=75, g=0.0)
        (linear): Sequential(
          (0): Linear(in_features=400, out_features=30, bias=True)
          (1): ReLU()
        )
      )
    )
    (dropouts): ModuleList(
      (0): Dropout(p=0.0, inplace=False)
      (1): Dropout(p=0.0, inplace=False)
    )
  )
  (to_out): Linear(in_features=30, out_

In [43]:
# Just for visualizing average loss through time. 
# Window size, in samples, over which train() averages its running stats.
loss_buffer_size = 50*8
epochs = 40
progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')

# train() appends its CSV log here; make sure the directory exists first.
perf_file = join('perf', 'cifar10_deepsith_1.csv')
os.makedirs('perf', exist_ok=True)

# Run on the GPU only when one is available (same rule that selects ttype).
train_device = 'cuda' if torch.cuda.is_available() else 'cpu'

for e in progress_bar:
    # Pass loss_buffer_size explicitly; otherwise train() silently falls back
    # to its own default and the value defined above is never used.
    train(model, ttype, trainloader, optimizer, loss_func, batch_size=batch_size,
          epoch=e, perf_file=perf_file,
          loss_buffer_size=loss_buffer_size, device=train_device,
          prog_bar=progress_bar)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))




KeyboardInterrupt: 