In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, sampler
import torch.optim as optim
import os
import time
import pickle
import numpy as np

In [2]:
class PseudoSpikeRect(torch.autograd.Function):
    """Threshold spike function with a rectangular pseudo-gradient.

    The forward pass emits 1.0 wherever the input exceeds ``vth`` and 0.0
    elsewhere. Because that step has no useful derivative, the backward pass
    substitutes a rectangular window: the upstream gradient is scaled by
    ``grad_amp`` where ``|input - vth| < grad_win`` and zeroed elsewhere.
    """

    @staticmethod
    def forward(ctx, input, vth, grad_win, grad_amp):
        """Return a float tensor of spikes (1.0 where ``input > vth``).

        Args:
            ctx: autograd context used to stash values for ``backward``.
            input: tensor compared against the threshold.
            vth: firing threshold.
            grad_win: half-width of the window used in ``backward``.
            grad_amp: gradient scale used in ``backward``.
        """
        ctx.vth, ctx.grad_win, ctx.grad_amp = vth, grad_win, grad_amp
        ctx.save_for_backward(input)
        return (input > vth).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return the windowed gradient for ``input`` and ``None`` for the scalars."""
        (membrane,) = ctx.saved_tensors
        # 1.0 inside the window around the threshold, 0.0 outside it.
        inside_window = (membrane - ctx.vth).abs().lt(ctx.grad_win).float()
        # One slot per forward argument; vth / grad_win / grad_amp get no gradient.
        return ctx.grad_amp * grad_output * inside_window, None, None, None

In [3]:
class LinearIFCell(nn.Module):
    """Leaky integrate-and-fire neuron layer.

    Each call advances the layer by one time step: the previous voltage is
    decayed, cleared for neurons that spiked on the previous step, and the
    new post-synaptic input is added before thresholding.
    """

    def __init__(self, psp_func, pseudo_grad_ops, param):
        """
        Args:
            psp_func: callable mapping the layer input to the post-synaptic
                contribution added to the voltage.
            pseudo_grad_ops: callable ``(volt, vth, grad_win, grad_amp)``
                returning the spike tensor.
            param: 4-item sequence ``(vdecay, vth, grad_win, grad_amp)``.
        """
        super().__init__()
        vdecay, vth, grad_win, grad_amp = param
        self.psp_func = psp_func
        self.pseudo_grad_ops = pseudo_grad_ops
        self.vdecay = vdecay
        self.vth = vth
        self.grad_win = grad_win
        self.grad_amp = grad_amp

    def forward(self, input_data, state):
        """Advance one time step.

        Args:
            input_data: input passed to ``psp_func``.
            state: ``(previous_spike, previous_voltage)`` pair.

        Returns:
            ``(spikes, (spikes, voltage))`` - the output and the next state.
        """
        prev_spike, prev_volt = state
        # (1 - spike) wipes the carried-over voltage of neurons that just fired.
        retained = self.vdecay * prev_volt * (1. - prev_spike)
        volt = retained + self.psp_func(input_data)
        spikes = self.pseudo_grad_ops(volt, self.vth, self.grad_win, self.grad_amp)
        return spikes, (spikes, volt)

In [4]:
class SingleHiddenLayerSNN(nn.Module):
    """SNN with a single hidden layer.

    Two bias-free linear spiking layers (input -> hidden -> output) are
    unrolled over the time axis; the result is the per-output spike count
    accumulated over all time steps.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, param_dict):
        """
        Args:
            input_dim: number of input features per time step.
            output_dim: number of output neurons.
            hidden_dim: number of hidden neurons.
            param_dict: dict with keys ``'hid_layer'`` and ``'out_layer'``,
                each holding the parameter sequence handed to ``LinearIFCell``.
        """
        super(SingleHiddenLayerSNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        pseudo_grad_ops = PseudoSpikeRect.apply
        self.hidden_cell = LinearIFCell(
            nn.Linear(input_dim, hidden_dim, bias=False), pseudo_grad_ops, param_dict['hid_layer'])
        self.output_cell = LinearIFCell(
            nn.Linear(hidden_dim, output_dim, bias=False), pseudo_grad_ops, param_dict['out_layer'])

    def forward(self, spike_data, init_states_dict, batch_size, spike_ts):
        """Run the network over every time step and sum the output spikes.

        Args:
            spike_data: tensor whose last axis is time and whose remaining
                non-batch axes flatten to ``input_dim``.
            init_states_dict: dict with keys ``'hid_layer'`` and
                ``'out_layer'`` holding each layer's initial state.
            batch_size: size of the leading batch axis.
            spike_ts: number of time steps (size of the last axis).

        Returns:
            The output layer's spikes summed over time, one row per sample.
        """
        hidden_state = init_states_dict['hid_layer']
        out_state = init_states_dict['out_layer']
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # tensors that were transposed/permuted upstream; for contiguous
        # inputs it is the same zero-copy view as before.
        spike_data_flatten = spike_data.reshape(batch_size, self.input_dim, spike_ts)
        output_list = []  # output spikes at each time step
        for tt in range(spike_ts):
            input_data = spike_data_flatten[:, :, tt]
            hidden_layer, hidden_state = self.hidden_cell(input_data, hidden_state)
            output_layer, out_state = self.output_cell(hidden_layer, out_state)
            output_list.append(output_layer)

        # Stack along a new leading time axis, then reduce it to spike counts.
        return torch.stack(output_list).sum(dim=0)

In [5]:
class WrapSNN(nn.Module):
    """Wrapper of SNN.

    Builds fresh all-zero neuron states for every forward call and runs the
    wrapped ``SingleHiddenLayerSNN`` on the given spike tensor.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, param_dict, device):
        """
        Args:
            input_dim: number of input features per time step.
            output_dim: number of output neurons.
            hidden_dim: number of hidden neurons.
            param_dict: per-layer neuron parameters, forwarded to the SNN.
            device: device requested at construction time. Kept as an
                attribute for callers; ``forward`` no longer relies on it.
        """
        super(WrapSNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.device = device
        self.snn = SingleHiddenLayerSNN(input_dim, output_dim, hidden_dim, param_dict)

    def forward(self, spike_data):
        """Run the SNN on ``spike_data`` starting from zeroed states.

        Args:
            spike_data: tensor with batch on the first axis and time on the
                last axis.

        Returns:
            Whatever the wrapped SNN returns for this batch.
        """
        batch_size = spike_data.shape[0]
        spike_ts = spike_data.shape[-1]
        # Allocate states next to the data rather than on the device captured
        # in __init__: the module may have been moved with .to(...) since
        # then, which would otherwise mix devices inside the cells.
        state_device = spike_data.device

        def zero_state(dim):
            # (spike, volt) pair, both starting at rest.
            return (torch.zeros(batch_size, dim, device=state_device),
                    torch.zeros(batch_size, dim, device=state_device))

        init_states_dict = {
            'hid_layer': zero_state(self.hidden_dim),
            'out_layer': zero_state(self.output_dim),
        }
        return self.snn(spike_data, init_states_dict, batch_size, spike_ts)

In [6]:
def img_2_event_img(image, device, spike_ts):
    """Encode an image batch as random binary spike trains.

    Every element is compared against ``spike_ts`` independent uniform
    samples from [0, 1); an event (1.0) is emitted where the element is
    larger. For intensities in [0, 1] the firing probability at each time
    step therefore equals the intensity.

    Args:
        image: image tensor of any shape (e.g. batch x channel x H x W).
        device: device on which the noise and the result are created.
        spike_ts: number of time steps to generate.

    Returns:
        Float tensor of shape ``image.shape + (spike_ts,)`` holding 0.0/1.0.
    """
    # Tensor.to() returns a new tensor instead of moving in place, so the
    # noise is created directly on the target device and the image is
    # (re)bound to its moved copy; both operands of gt() then share a device.
    image = image.to(device)
    random_image = torch.rand(*image.shape, spike_ts, device=device)
    # A trailing singleton axis lets the image broadcast across time; this
    # also lifts the old square-image / 4-D-only restriction.
    event_image = torch.gt(image.unsqueeze(-1), random_image).float()

    return event_image

In [7]:
def stbp_snn_training(network, spike_ts, device, batch_size=128, test_batch_size=256, epoch=100):
    """Train ``network`` on MNIST and evaluate it after every epoch.

    MNIST is downloaded into ``./data/`` if needed. Each batch is converted
    to spike events with ``img_2_event_img`` and the network output is
    optimised with cross-entropy loss and SGD (lr=0.001, momentum=0.9).

    Args:
        network: module mapping an event tensor to per-class scores.
        spike_ts: number of time steps used for the spike encoding.
        device: device on which training and evaluation run.
        batch_size: training batch size.
        test_batch_size: evaluation batch size.
        epoch: number of passes over the training set.

    Returns:
        ``(train_loss_list, test_accuracy_list)`` with one entry per epoch:
        mean training loss per batch and test-set accuracy.

    Side effects:
        Prints progress, and leaves ``network`` on the CPU when done.
    """
    try:
        os.mkdir("./data")
        print("Directory data Created")
    except FileExistsError:
        print("Directory data already exists")

    data_path = './data/'
    train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=True,
                                               transform=transforms.ToTensor())
    test_dataset = torchvision.datasets.MNIST(root=data_path, train=False, download=True,
                                              transform=transforms.ToTensor())

    # NOTE(review): the training loader does not shuffle. SGD normally wants
    # shuffle=True; left unchanged so existing runs stay comparable - confirm.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size,
                                 shuffle=False, num_workers=4)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)  # was .01

    train_loss_list, test_accuracy_list = [], []
    test_num = len(test_dataset)
    network.to(device)

    for ee in range(epoch):
        # ---- training pass ----
        running_loss = 0.0
        running_batch_num = 0
        train_start = time.time()

        for data in train_dataloader:
            image = data[0].to(device)
            label = data[1].to(device)
            event_image = img_2_event_img(image, device, spike_ts)
            optimizer.zero_grad()
            # Call the module (not .forward) so registered hooks run; the
            # result is no longer bound to the builtin name `input`.
            outputs = network(event_image)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_batch_num += 1
        train_end = time.time()
        train_loss_list.append(running_loss / running_batch_num)
        print("Epoch %d Training Loss %.4f" % (ee, train_loss_list[-1]), end=" ")

        # ---- evaluation pass ----
        test_correct_num = 0
        test_start = time.time()
        with torch.no_grad():
            for data in test_dataloader:
                # Tensor.to() is not in-place: the moved copies must be bound,
                # otherwise inputs/labels stay on the CPU and clash with a
                # network living on another device.
                image = data[0].to(device)
                label = data[1].to(device)
                event_image = img_2_event_img(image, device, spike_ts)
                outputs = network(event_image)
                _, predicted = torch.max(outputs, 1)
                test_correct_num += (predicted == label).sum().item()

        test_end = time.time()
        test_accuracy_list.append(test_correct_num / test_num)
        print("Test Accuracy %.4f Training Time: %.1f Test Time: %.1f" % (
            test_accuracy_list[-1], train_end - train_start, test_end - test_start))

    print("End Training")
    network.to('cpu')
    return train_loss_list, test_accuracy_list

In [8]:
# ---- Run configuration ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 784   # 28 x 28 MNIST pixels, flattened
output_dim = 10   # one output neuron per digit class
hidden_dim = 256
# Per-layer neuron parameters, ordered as [vdecay, vth, grad_win, grad_amp].
param_dict = {'hid_layer': [.5, .4, .2, 1],'out_layer': [.5, .4, .2, 1]}
spike_ts = 3      # time steps per encoded sample
snn = WrapSNN(input_dim, output_dim, hidden_dim, param_dict, device)
batch_size = 128
test_batch_size = 256
epoch = 5
# Pass the variables above instead of repeating their literal values, so
# editing the configuration actually changes the run.
train_loss_list, test_accuracy_list = stbp_snn_training(
    snn, spike_ts, device,
    batch_size=batch_size, test_batch_size=test_batch_size, epoch=epoch)

Directory data already exists
Epoch 0 Training Loss 1.5788 Test Accuracy 0.7019 Training Time: 10.1 Test Time: 4.1
Epoch 1 Training Loss 1.0729 Test Accuracy 0.7323 Training Time: 10.0 Test Time: 4.2
Epoch 2 Training Loss 0.9869 Test Accuracy 0.7910 Training Time: 10.2 Test Time: 4.5
Epoch 3 Training Loss 0.8592 Test Accuracy 0.8769 Training Time: 10.4 Test Time: 4.7
Epoch 4 Training Loss 0.7004 Test Accuracy 0.8913 Training Time: 10.6 Test Time: 4.2
End Training
