In [None]:
%matplotlib inline


In [None]:
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import torch
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np

import torch.nn as nn

from tqdm import tqdm_notebook

from torchvision import transforms
from torchvision import datasets
from os.path import join

from deepsith import DeepSITH

from tqdm.notebook import tqdm

import random

from csv import DictWriter
# Pick GPU tensor types when CUDA is available, otherwise the CPU equivalents.
use_cuda = torch.cuda.is_available()

if use_cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

# Tensor type handed to the DeepSITH layer configuration further down.
ttype = FloatTensor

import seaborn as sn
print(use_cuda)
import pickle

sn.set_context("poster")

In [None]:
class DeepSITH_Classifier(nn.Module):
    """DeepSITH stack followed by a linear readout.

    Parameters
    ----------
    out_features : int
        Width of the final linear layer (number of output scores).
    layer_params : list of dict
        Per-layer configuration forwarded unchanged to ``DeepSITH``. The
        ``'hidden_size'`` entry of the last dict sets the readout input width.
    dropout : float, optional
        Forwarded unchanged to ``DeepSITH``.
    """

    def __init__(self, out_features, layer_params, dropout=.1):
        super().__init__()
        # Read the readout width before handing the dicts to DeepSITH.
        readout_width = layer_params[-1]['hidden_size']
        self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)
        self.to_out = nn.Linear(readout_width, out_features)

    def forward(self, inp):
        """Run ``inp`` through the DeepSITH stack, then the linear readout."""
        return self.to_out(self.hs(inp))

# Load Stimuli

In [None]:
# Seed torch's global RNG *before* drawing the permutation so the pixel ordering
# is reproducible across runs. The seed is intended to match the coRNN paper
# (NOTE(review): not verifiable from this notebook alone).
torch.manual_seed(12008)
# Random ordering of the 784 flattened pixel positions; used below to index the
# last axis of the flattened images.
permute = torch.randperm(784)

In [None]:
# Show the permutation so the exact pixel ordering is recorded in the cell output.
print(permute)

In [None]:
# Standalone normalisation transform using the same mean/std constants as the
# Normalize step in the loader pipeline below.
# NOTE(review): `norm` is not referenced anywhere else in the visible notebook —
# confirm whether it is still needed.
norm = transforms.Normalize((.1307,), (.3081,), )

In [None]:
batch_size = 64

# ToTensor followed by a fixed mean/std normalisation, applied to both splits.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.1307,), (.3081,)),
])

# Training and held-out MNIST splits, downloaded into ../data when missing.
ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
ds2 = datasets.MNIST('../data', train=False, download=True, transform=transform)

# Both loaders use identical settings (note that the test loader is shuffled too).
loader_kwargs = dict(batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=True)
train_loader = torch.utils.data.DataLoader(ds1, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(ds2, **loader_kwargs)

In [None]:
# Grab the images of one test batch (the labels are discarded).
testi = next(iter(test_loader))[0]

# Flatten the first image, reorder its pixels with `permute`, and fold it back
# to 28x28 to visualise what the permuted input looks like.
permuted_digit = testi[0].reshape(-1)[permute].reshape(28, 28)
plt.imshow(permuted_digit)
plt.colorbar()

In [None]:
# Inspect the shape of the image batch pulled from the test loader above.
testi.shape

# Define test and train

In [None]:

def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,
          permute=None, loss_buffer_size=800, batch_size=4, device='cuda',
          prog_bar=None):
    """Train ``model`` for one pass over ``train_loader`` and log running stats.

    Parameters
    ----------
    model : torch.nn.Module
        Called on the flattened, permuted batch; the scores used for the loss
        and accuracy are ``out[:, -1, :]``.
    ttype :
        Unused here; kept so existing callers keep working.
    train_loader : iterable
        Yields ``(data, target)`` batches.
    test_loader :
        Unused here (in-loop validation is disabled); kept for compatibility.
    optimizer, loss_func :
        Optimiser stepped once per batch and the loss applied to the scores.
    epoch : int
        Epoch number written to the progress text and to the CSV log.
    perf_file : str
        CSV file that running statistics are appended to. Its directory must
        already exist.
    permute : index tensor, optional
        Ordering applied to the last axis of the flattened input. Defaults to
        the identity over 784 positions.
    loss_buffer_size : int
        Number of samples in the running-average window; must be a multiple of
        ``batch_size``. A CSV row is written each time this many samples
        (counted as ``batch_idx * batch_size``) have been processed.
    batch_size : int
        Nominal batch size of ``train_loader``.
    device : str or torch.device
        Device the batches are moved to.
    prog_bar : object with ``set_description``, optional
        Progress bar whose description is refreshed after every batch.
    """
    assert(loss_buffer_size%batch_size==0)
    if permute is None:
        permute = torch.LongTensor(list(range(784)))

    # Number of most-recent batches kept for the running statistics.
    window = loss_buffer_size // batch_size
    losses = []
    perfs = []   # correct predictions per batch
    seen = []    # actual sample count per batch (the last batch may be short)
    # In-loop validation is disabled, so the "valid" field below stays at 0.
    last_test_perf = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        model.train()
        data = data.to(device).view(data.shape[0],1,1,-1)
        target = target.to(device)
        optimizer.zero_grad()
        out = model(data[:, :, :, permute])
        scores = out[:, -1, :]
        loss = loss_func(scores, target)

        loss.backward()
        optimizer.step()

        perfs.append((torch.argmax(scores, dim=-1) == target).sum().item())
        perfs = perfs[-window:]
        seen.append(target.shape[0])
        seen = seen[-window:]
        losses.append(loss.detach().cpu().numpy())
        losses = losses[-window:]

        # Normalise by the samples actually seen rather than assuming that
        # every batch holds exactly `batch_size` items.
        train_perf = np.sum(perfs) / np.sum(seen)

        if prog_bar is not None:
            # Use the `epoch` argument here, not a name from the caller's scope.
            s = "{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}"
            prog_bar.set_description(s.format(epoch, batch_idx*batch_size,
                                              np.mean(losses), train_perf,
                                              last_test_perf))

        if batch_idx != 0 and (batch_idx*batch_size) % loss_buffer_size == 0:
            loss_track = {'avg_loss': np.mean(losses),
                          'epoch': epoch,
                          'batch_idx': batch_idx,
                          'train_perf': train_perf}
            # newline='' is what the csv module expects for files it writes.
            with open(perf_file, 'a+', newline='') as fp:
                csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))
                # A freshly created (empty) file still needs its header row.
                if fp.tell() == 0:
                    csv_writer.writeheader()
                csv_writer.writerow(loss_track)
                fp.flush()

                
def test(model, device, test_loader, batch_size=4, permute=None):
    if permute is None:
        permute = torch.LongTensor(list(range(784)))
        
    model.eval()
    correct = 0
    count = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device).view(data.shape[0],1,1,-1)
            target = target.to(device)
            
            out = model(data[:,:,:, permute])
            pred = out[:, -1].argmax(dim=-1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += 1
    return correct / len(test_loader.dataset)

# Set up the model

In [None]:
g = 0.0


def _sith_layer(in_features, tau_max, buff_max, k):
    """Build the parameter dict for one DeepSITH layer.

    Only the four arguments differ between the layers in this notebook; every
    other entry is shared. The meaning of each key is defined by the
    ``deepsith`` package. A new ``nn.ReLU()`` is created on every call so the
    layers do not share an activation module instance.
    """
    return {"in_features": in_features,
            "tau_min": 1, "tau_max": tau_max, "buff_max": buff_max,
            "k": k, 'dt': 1,
            "ntau": 20, 'g': g,
            "ttype": ttype, "batch_norm": True,
            "hidden_size": 60, "act_func": nn.ReLU()}


# Three stacked layers; each one takes the previous layer's hidden size as input.
sith_params1 = _sith_layer(1, 30.0, 50, 125)
sith_params2 = _sith_layer(sith_params1['hidden_size'], 150.0, 250, 61)
sith_params3 = _sith_layer(sith_params2['hidden_size'], 750.0, 1500, 35)

layer_params = [sith_params1, sith_params2, sith_params3]

# 10 output scores; the model is moved to the GPU unconditionally.
model = DeepSITH_Classifier(10,
                            layer_params=layer_params,
                            dropout=0.2).cuda()

# Count every element of every parameter tensor.
tot_weights = sum(p.numel() for p in model.parameters())
print("Total Weights:", tot_weights)
print(model)

In [None]:
epochs = 90
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
# Step the learning rate down by a factor of 10 every 30 scheduler steps.
sched = StepLR(optimizer, step_size=30, gamma=0.1)
progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')

# Running statistics from train() are appended to this CSV file.
perf_csv = join('perf', 'pmnist_deepsith_78.csv')

# One {"epoch", "test"} record per epoch.
test_perf = []
for e in progress_bar:
    train(model, ttype, train_loader, test_loader, optimizer, loss_func,
          epoch=e, perf_file=perf_csv, permute=permute,
          loss_buffer_size=64*32, batch_size=batch_size,
          prog_bar=progress_bar)
    # Evaluate on the test loader once the epoch has finished.
    last_test_perf = test(model, 'cuda', test_loader,
                          batch_size=batch_size,
                          permute=permute)
    test_perf.append(dict(epoch=e, test=last_test_perf))
    sched.step()

# Find Errors

In [None]:
# Accuracy of the trained model on the test loader, using the same pixel
# permutation that was passed to train(). The value is the number of correct
# predictions divided by len(test_loader.dataset).
test(model, 'cuda', test_loader, 
     batch_size=batch_size, 
     permute=permute)

In [None]:
def conf_mat_gen(model, device, test_loader, batch_size=4, permute=None):
    """Collect per-sample predictions and labels over ``test_loader``.

    Inputs are flattened and reordered exactly as in ``test``; the predicted
    class is the argmax of ``out[:, -1]``. ``batch_size`` is accepted for call
    compatibility but is not used.

    Returns
    -------
    dict
        ``{'pred': [...], 'actual': [...]}`` — one NumPy array per sample in
        each list. ``'pred'`` entries have shape ``(1,)`` (argmax is taken with
        ``keepdim=True``); ``'actual'`` entries are 0-d.
    """
    pixel_order = torch.arange(784) if permute is None else permute
    predicted, actual = [], []

    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            flat = images.to(device).view(images.shape[0], 1, 1, -1)
            labels = labels.to(device)
            actual.extend(lbl.detach().cpu().numpy() for lbl in labels)

            scores = model(flat[:, :, :, pixel_order])
            guesses = scores[:, -1].argmax(dim=-1, keepdim=True)
            predicted.extend(g.detach().cpu().numpy() for g in guesses)
    return {'pred': predicted, 'actual': actual}

In [None]:
# Gather per-sample predictions and true labels for the confusion matrix below.
# conf_mat_gen does not use its batch_size argument, so the value 4 has no effect.
evals = conf_mat_gen(model, 'cuda', test_loader, batch_size=4, permute=permute)

In [None]:
fig = plt.figure(figsize=(10,11))

# scikit-learn's signature is confusion_matrix(y_true, y_pred): rows index the
# actual class and columns the predicted class.
y_true = np.array(evals['actual'])
# Each prediction was stored as a length-1 array, hence the [:, 0].
y_pred = np.array(evals['pred'])[:, 0]

plt.imshow(confusion_matrix(y_true, y_pred), cmap='coolwarm')
plt.colorbar()
plt.xticks(list(range(10)));
plt.yticks(list(range(10)));
# Label the axes so the matrix orientation is unambiguous in the saved figure.
plt.xlabel('Predicted digit')
plt.ylabel('Actual digit')
# The keyword is `bbox_inches`; the misspelt `bboxinches` never tightened the box.
plt.savefig(join('figs', 'sMNIST_LoLa_conf'), dpi=200, bbox_inches='tight')