In [None]:
# Environment setup (shell commands; versions are not pinned).
# NOTE(review): --force-reinstall presumably also reinstalls matplotlib's
# dependencies (e.g. numpy), which would explain the pip conflict messages
# in this cell's output — confirm, and consider pinning versions.
!pip3 install -q idx2numpy
!pip3 install -q --upgrade --force-reinstall matplotlib

[31mERROR: tensorflow 2.3.0 has requirement numpy<1.19.0,>=1.16.0, but you'll have numpy 1.19.4 which is incompatible.[0m
[31mERROR: nbclient 0.5.1 has requirement jupyter-client>=6.1.5, but you'll have jupyter-client 5.3.5 which is incompatible.[0m
[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.[0m
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m


### Import libraries

In [None]:
import argparse, torch, random
import numpy as np
from torchvision import datasets, transforms
import torch.nn.functional as F
from models import SNN, SpikeCELoss

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Function

import idx2numpy
import numpy as np
import matplotlib.pyplot as plt

### WrapperFunction Class

In [None]:
class WrapperFunction(Function):
    """Autograd bridge that delegates both passes to caller-supplied callables.

    ``forward`` must return ``(pack, output)`` where ``pack`` is an iterable of
    tensors to stash for the backward pass.  ``backward`` is later invoked as
    ``backward(grad_output, *pack)`` and must return exactly two values: the
    gradient for ``input`` and the gradient for ``params``.
    """

    @staticmethod
    def forward(ctx, input, params, forward, backward):
        """Run the user forward pass and remember what backward will need."""
        saved, result = forward(input)
        # The backward callable is not a tensor, so it lives directly on ctx.
        ctx.backward = backward
        ctx.save_for_backward(*saved)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate to the stored backward callable.

        The two callables passed to ``forward`` are not differentiable inputs,
        hence the two trailing ``None`` gradients.
        """
        d_input, d_params = ctx.backward(grad_output, *ctx.saved_tensors)
        return d_input, d_params, None, None

### FirstSpikeTime Class

In [None]:
class FirstSpikeTime(Function):
    """Autograd function reducing a spike train to first-spike time indices.

    The forward pass collapses dim 2 of ``input`` to the index of its earliest
    non-zero entry.  The backward pass routes the incoming gradient only to
    that time step.
    """

    @staticmethod
    def forward(ctx, input):
        """Return, per (dim 0, dim 1) entry, the index of the first spike on dim 2.

        Args:
            input: 3-D tensor whose last axis is time.  Assumed to hold 0/1
                spike indicators — TODO confirm against the producing layer.

        Returns:
            Float tensor of shape ``input.shape[:2]``.  Entries whose argmax is
            0 (no spike at all, or a spike at step 0) are reported as the last
            step, ``input.shape[2] - 1``.
        """
        steps = input.shape[2]
        # Descending weights (steps, ..., 1): the earliest spike yields the
        # largest product, so argmax picks it.  The tensor is created on the
        # input's device instead of forcing CUDA, so CPU execution works too.
        idx = torch.arange(steps, 0, -1, device=input.device).unsqueeze(0).unsqueeze(0).float()
        first_spike_times = torch.argmax(idx * input, dim=2).float()
        # Save the raw argmax (before the "silent -> last step" remap below);
        # backward one-hot encodes these saved indices.
        ctx.save_for_backward(input, first_spike_times.clone())
        first_spike_times[first_spike_times == 0] = steps - 1
        return first_spike_times

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` onto the saved first-spike step of each entry."""
        input, first_spike_times = ctx.saved_tensors
        k = F.one_hot(first_spike_times.long(), input.shape[2]).float()
        grad_input = k * grad_output.unsqueeze(-1)
        return grad_input

### SpikingLinear Class

In [None]:
class SpikingLinear(nn.Module):
    """Fully connected spiking layer with hand-written forward/backward passes.

    The layer is simulated over ``int(T / dt)`` discrete steps.  Autograd is
    bypassed: ``forward`` routes through ``WrapperFunction`` so that
    ``manual_forward`` and ``manual_backward`` define the computation and its
    gradients.

    Args:
        input_dim: number of input channels (columns of ``weight``).
        output_dim: number of output neurons (rows of ``weight``).
        T: simulated duration; together with ``dt`` fixes the step count.
        dt: step size.
        tau_m: time constant used in the ``V`` update.
        tau_s: time constant used in the ``I`` update.
        mu: used as both mean and standard deviation of the normal weight init.
    """

    def __init__(self, input_dim, output_dim, T, dt, tau_m, tau_s, mu):
        super(SpikingLinear, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.T = T
        self.dt = dt
        self.tau_m = tau_m
        self.tau_s = tau_s

        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        nn.init.normal_(self.weight, mu, mu)

    def forward(self, input):
        """Apply the layer; gradients come from ``manual_backward``.

        Defined as a regular method (rather than a lambda stored on the
        instance) so the module stays picklable.
        """
        return WrapperFunction.apply(input, self.weight, self.manual_forward, self.manual_backward)

    def manual_forward(self, input):
        """Simulate the layer and return ``((input, I, output), output)``.

        Args:
            input: tensor indexed as ``input[:, :, step]`` and fed to
                ``F.linear`` with ``weight``, i.e. (batch, input_dim, steps).

        Returns:
            A pair ``(pack, output)``: ``pack`` holds the tensors needed by
            ``manual_backward`` and ``output`` is the (batch, output_dim, steps)
            0/1 spike tensor.

        In training mode, rows of ``weight`` belonging to neurons that stayed
        silent for at least one sample are increased by 0.1 and the simulation
        is repeated until no neuron is silent.
        NOTE(review): this loop does not terminate if a sample can never drive
        a neuron to spike (e.g. an all-zero input) — confirm inputs always spike.
        """
        steps = int(self.T / self.dt)
        batch = input.shape[0]
        # Allocate state on the input's device instead of forcing CUDA.
        dev = input.device

        V = torch.zeros(batch, self.output_dim, steps, device=dev)
        I = torch.zeros(batch, self.output_dim, steps, device=dev)
        output = torch.zeros(batch, self.output_dim, steps, device=dev)

        # Loop-invariant update coefficients.
        keep_m = 1 - self.dt / self.tau_m
        gain_m = self.dt / self.tau_m
        keep_s = 1 - self.dt / self.tau_s

        while True:
            for i in range(1, steps):
                V[:,:,i] = keep_m * V[:,:,i-1] + gain_m * I[:,:,i-1]
                I[:,:,i] = keep_s * I[:,:,i-1] + F.linear(input[:,:,i-1].float(), self.weight)
                spikes = (V[:,:,i] > 1.0).float()
                output[:,:,i] = spikes
                # Reset V to zero wherever a spike was emitted.
                V[:,:,i] = (1-spikes) * V[:,:,i]

            if self.training:
                # A neuron counts as silent if it emitted no spike for at
                # least one sample in the batch (min over the batch axis).
                is_silent = output.sum(2).min(0)[0] == 0
                self.weight.data[is_silent] = self.weight.data[is_silent] + 1e-1
                if is_silent.sum() == 0:
                    break
            else:
                break

        return (input, I, output), output

    def manual_backward(self, grad_output, input, I, post_spikes):
        """Backward recursion matching ``manual_forward``.

        Args:
            grad_output: gradient w.r.t. the spike output, indexed as
                ``grad_output[:, :, step]``.
            input, I, post_spikes: the tensors packed by ``manual_forward``.

        Returns:
            ``(grad_input, grad_weight)``.  ``grad_weight`` keeps a leading
            batch axis: shape (batch, output_dim, input_dim).
        """
        steps = int(self.T / self.dt)
        batch = input.shape[0]
        # Allocate on the incoming gradient's device instead of forcing CUDA.
        dev = grad_output.device

        lV = torch.zeros(batch, self.output_dim, steps, device=dev)
        lI = torch.zeros(batch, self.output_dim, steps, device=dev)

        grad_input = torch.zeros(batch, input.shape[1], steps, device=dev)
        grad_weight = torch.zeros(batch, *self.weight.shape, device=dev)

        # Walk time backwards; index steps-1 keeps its zero initial value.
        for i in range(steps-2, -1, -1):
            delta = lV[:,:,i+1] - lI[:,:,i+1]
            grad_input[:,:,i] = F.linear(delta, self.weight.t())
            # The 1e-10 term guards the division when I equals 1 exactly.
            lV[:,:,i] = (1 - self.dt / self.tau_m) * lV[:,:,i+1] + post_spikes[:,:,i+1] * (lV[:,:,i+1] + grad_output[:,:,i+1]) / (I[:,:,i] - 1 + 1e-10)
            lI[:,:,i] = lI[:,:,i+1] + (self.dt / self.tau_s) * (lV[:,:,i+1] - lI[:,:,i+1])
            spike_bool = input[:,:,i].float()
            grad_weight -= (spike_bool.unsqueeze(1) * lI[:,:,i].unsqueeze(2))

        return grad_input, grad_weight

### SNN Class

In [None]:
class SNN(nn.Module):
    """Single-layer spiking classifier: one ``SpikingLinear`` layer followed by
    a first-spike-time readout.

    Args:
        input_dim: number of input channels.
        output_dim: number of output neurons.
        T, dt, tau_m, tau_s: forwarded unchanged to ``SpikingLinear``.
    """

    def __init__(self, input_dim, output_dim, T, dt, tau_m, tau_s):
        super(SNN, self).__init__()
        # Honour output_dim (previously ignored in favour of a hard-coded 10).
        # 0.1 is the ``mu`` argument of SpikingLinear's weight initialisation.
        self.slinear1 = SpikingLinear(input_dim, output_dim, T, dt, tau_m, tau_s, 0.1)
        self.outact = FirstSpikeTime.apply

    def forward(self, input):
        """Return the first-spike-time readout of the spiking layer's output."""
        u = self.slinear1(input)
        u = self.outact(u)
        return u

### SpikeCELoss Class

In [None]:
class SpikeCELoss(nn.Module):
    """Cross-entropy over negated, rescaled spike times.

    Smaller values in ``input`` become larger logits, so the class with the
    smallest input value is favoured.

    Args:
        T: accepted for interface compatibility; not used by this module.
        xi: scale factor, multiplied with ``tau_s`` to form the divisor.
        tau_s: time constant, multiplied with ``xi`` to form the divisor.
    """

    def __init__(self, T, xi, tau_s):
        super().__init__()
        self.xi = xi
        self.tau_s = tau_s
        self.celoss = nn.CrossEntropyLoss()

    def forward(self, input, target):
        """Return ``CrossEntropyLoss(-input / (xi * tau_s), target)``."""
        scale = self.xi * self.tau_s
        logits = -input / scale
        return self.celoss(logits, target)

### Download data

In [None]:
# Recursively fetch the gzipped MNIST idx files, move them into a fresh
# ./data directory, remove the mirrored host directory and unpack the archives.
# NOTE(review): datasets.MNIST(..., download=True) further down downloads its
# own copy under data_folder, so this cell looks redundant — confirm before
# relying on the files it produces.
!wget -q -r -A '*ubyte.gz' --no-parent  'http://yann.lecun.com/exdb/mnist/'
!rm -rf data
!mkdir data
!mv yann.lecun.com/exdb/mnist/* data/
!rm -rf yann.lecun.com
!gunzip data/*

In [None]:
# Root directory passed to torchvision's MNIST dataset (absolute Colab path).
data_folder = '/content/data/'
# Device the model and every batch are moved to with .to(device).
device = 'cuda'
# Logging / reproducibility switches:
#   print_freq    - train() prints a progress line every print_freq batches.
#   deterministic - when True, the cuDNN determinism flags are set below.
print_freq = 100
deterministic = True

# Training settings
# lr is the initial SGD learning rate; the StepLR scheduler created later
# multiplies it by 0.1 every 10 epochs.
num_epochs = 100
lr = 1.0
batch_size = 128

# Loss settings (specific for SNNs)
# xi    - SpikeCELoss divides spike times by (xi * tau_s) before cross-entropy.
# alpha - weight of the extra term added in train(); 0 disables it.
# beta  - that term is exp(t_target / (beta * tau_s)) - 1, averaged over the batch.
xi = 0.4
alpha = 1e-2
beta = 2.0

# Spiking Model settings
# T, dt        - the layer simulates int(T / dt) steps.
# tau_m, tau_s - time constants of the V and I updates in SpikingLinear
#                (presumably membrane / synaptic — confirm).
# t_max, t_min - encode_data() assigns spike step t_max to pixels < 0.5 and
#                t_min to all other pixels.
T = 20
dt = 1
tau_m = 20
tau_s = 5
t_max = 12
t_min = 2

### Reproducibility settings

In [None]:
if deterministic:
    # The cuDNN flags alone do not make a run repeatable: weight initialisation
    # and DataLoader shuffling draw from RNGs that must be seeded as well.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Force deterministic cuDNN kernels and disable the autotuner, which may
    # otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Encode data method

In [None]:
def encode_data(data):
    """Turn a batch of images into one-hot spike trains.

    Each sample is flattened; every pixel below 0.5 is assigned spike step
    ``t_max`` and every other pixel ``t_min``.  The step indices are then
    one-hot encoded over ``int(T)`` classes, adding a trailing time axis.
    Reads the module-level settings ``t_min``, ``t_max`` and ``T``.
    """
    batch = data.shape[0]
    is_dark = (data < 0.5).view(batch, -1)
    spike_steps = t_min + (t_max - t_min) * is_dark
    return F.one_hot(spike_steps.long(), int(T))

### Train method

In [None]:
def train(model, criterion, optimizer, loader):
    """Run one training epoch and print progress plus an epoch summary.

    Args:
        model: network called on the encoded batch; its output is treated as
            per-class spike times (the smallest value is the prediction).
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        loader: iterable yielding ``(input, target)`` batches.

    Reads the module-level settings ``device``, ``alpha``, ``beta``, ``tau_s``
    and ``print_freq`` and the helper ``encode_data``.

    The progress line printed every ``print_freq`` batches shows the metrics
    of that batch alone.  The final "Train:" line shows sample-weighted
    averages over the whole epoch.  Previously the accumulators were reset
    inside the loop, so the summary only reflected the last batch.
    """
    total_correct = 0.
    total_loss = 0.
    total_samples = 0.
    model.train()

    for batch_idx, (input, target) in enumerate(loader):
        input, target = input.to(device), target.to(device)
        input = encode_data(input)

        output = model(input)

        loss = criterion(output, target)

        if alpha != 0:
            # Extra term that grows with the spike time of the target neuron.
            target_first_spike_times = output.gather(1, target.view(-1, 1))
            loss += alpha * (torch.exp(target_first_spike_times / (beta * tau_s)) - 1).mean()

        # Predicted class = output neuron with the smallest value.
        predictions = output.data.min(1, keepdim=True)[1]
        batch_correct = predictions.eq(target.data.view_as(predictions)).sum().item()
        batch_loss = loss.item()
        batch_samples = len(target)

        # Epoch-level accumulators (loss is weighted by batch size because the
        # last batch may be smaller).
        total_correct += batch_correct
        total_loss += batch_loss * batch_samples
        total_samples += batch_samples

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        if batch_idx % print_freq == 0:
            print('\tBatch {:03d}/{:03d}: \tAcc {:.2f}  Loss {:.3f}'.format(batch_idx, len(loader), 100*batch_correct/batch_samples, batch_loss))

    print('\t\tTrain: \tAcc {:.2f}  Loss {:.3f}'.format(100*total_correct/total_samples, total_loss/total_samples))

### Test method

In [None]:
def test(model, loader):
    """Evaluate ``model`` on ``loader`` and print the overall accuracy.

    The model is switched to eval mode and run without gradient tracking.
    For each sample the predicted class is the output neuron with the
    smallest value.  Reads the module-level ``device`` and ``encode_data``.
    """
    model.eval()
    n_correct = 0.
    n_seen = 0.

    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)

            first_post_spikes = model(encode_data(data))
            _, predictions = first_post_spikes.data.min(1, keepdim=True)
            hits = predictions.eq(target.data.view_as(predictions))
            n_correct += hits.sum().item()
            n_seen += len(target)

        print('\t\tTest: \tAcc {:.2f}'.format(100*n_correct/n_seen))

### Create dataset and loader objects

In [None]:
# MNIST training split, converted to tensors by ToTensor; downloaded into
# data_folder if missing. Batches are reshuffled every epoch.
train_dataset = datasets.MNIST(data_folder, train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# MNIST test split; order is kept fixed for evaluation.
test_dataset = datasets.MNIST(data_folder, train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        
# 784 inputs (flattened by encode_data) and 10 output neurons.
model = SNN(784, 10, T, dt, tau_m, tau_s).to(device)
criterion = SpikeCELoss(T, xi, tau_s)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 10 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/train-images-idx3-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/data/MNIST/raw/t10k-images-idx3-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/data/MNIST/raw
Processing...
Done!


### Epoch training

In [None]:
# Main loop: one training pass and one evaluation pass per epoch, then advance
# the learning-rate schedule. The printed epoch counter is zero-based.
for epoch in range(num_epochs):
    print('Epoch {:03d}/{:03d}'.format(epoch, num_epochs))
    train(model, criterion, optimizer, train_loader)
    test(model, test_loader)
    scheduler.step()

Epoch 000/100
	Batch 000/469: 	Acc 9.38  Loss 2.440
	Batch 100/469: 	Acc 50.78  Loss 1.344
	Batch 200/469: 	Acc 75.78  Loss 1.060
	Batch 300/469: 	Acc 69.53  Loss 0.848
	Batch 400/469: 	Acc 82.03  Loss 0.716
		Train: 	Acc 69.79  Loss 1.049
		Test: 	Acc 69.97
Epoch 001/100
	Batch 000/469: 	Acc 71.09  Loss 0.924
	Batch 100/469: 	Acc 80.47  Loss 0.684
	Batch 200/469: 	Acc 84.38  Loss 0.593
	Batch 300/469: 	Acc 82.81  Loss 0.618
	Batch 400/469: 	Acc 87.50  Loss 0.527
		Train: 	Acc 81.25  Loss 0.537
		Test: 	Acc 83.18
Epoch 002/100
	Batch 000/469: 	Acc 85.16  Loss 0.591
	Batch 100/469: 	Acc 78.91  Loss 0.838
	Batch 200/469: 	Acc 86.72  Loss 0.427
	Batch 300/469: 	Acc 80.47  Loss 0.693
	Batch 400/469: 	Acc 82.81  Loss 0.579
		Train: 	Acc 82.29  Loss 0.679
		Test: 	Acc 84.56
Epoch 003/100
	Batch 000/469: 	Acc 84.38  Loss 0.597
	Batch 100/469: 	Acc 85.16  Loss 0.511
	Batch 200/469: 	Acc 87.50  Loss 0.408
	Batch 300/469: 	Acc 86.72  Loss 0.479
	Batch 400/469: 	Acc 84.38  Loss 0.626
		Train: 	Ac