This is normal SNN network with Stashing Randomly

In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools
import copy

In [2]:
# Hyperparameters
batch_size = 128
epochs = 3
beta = 0.95
# The beta defined here = (1 - delta(t)/Tau)
# Here delta(t) is the clk period
# Tau = Time constant = RC
# This is the approximation of beta = exp(-delta(t)/Tau)
num_steps = 25

In [3]:
# Perform learning over cuda
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST("./../data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST("./../data", train=False, download=True, transform=transform)

In [5]:
# Loading the training and test data
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [6]:
# Here we are creating a single hidden layer NN, with 150 nodes
class Net(nn.Module):
    def __init__(self, num_hidden_layer = 150):
        super().__init__()

        # Initialize layers
        self.input_layer = nn.ModuleList([nn.Linear(in_features=784, out_features=1) for i in range(num_hidden_layer)])
        self.lif1 = snn.Leaky(beta=beta)
        self.output_layer = nn.Linear(in_features=num_hidden_layer, out_features=10)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = [self.input_layer[i](x) for i in range(num_hidden_layer)]
            cur1 = torch.hstack(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.output_layer(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
num_hidden_layer = 150
model = Net(num_hidden_layer=num_hidden_layer).to(device)
next(model.parameters()).is_cuda


True

In [7]:
next(model.parameters()).is_cuda

True

In [8]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [9]:
def accuracy():
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for images, label in test_loader:
            images = images.reshape(-1, 28*28).to(device)
            label = label.to(device)

            # Test set forward pass
            test_spk, test_mem = model(images.view(batch_size, -1))
            n_samples += label.size(0)
            _, idx = test_spk.sum(dim=0).max(1)
            n_correct+= (label == idx).sum().item()
            return(n_correct/n_samples * 100)

In [10]:
# Generating non duplicate random nodes to be stashed out of hidden layer per epoch 
import random
stashed = random.sample(range(num_hidden_layer), 6)
stashed

[38, 68, 1, 148, 101, 4]

In [11]:
%%time
j=0
n_total_steps = len(train_loader)
for epoch_iter in range(epochs):
    for i,(data, target) in enumerate(train_loader):
        x = data.to(device)
        y = target.to(device)
        spk_rec, mem_rec = model(x.view(batch_size, -1))
        loss_val = torch.zeros((1),device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], y)
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
             print (f'Epoch [{epoch_iter+1}/{epochs}], Step[{i+1}/{n_total_steps}], Loss: {loss_val.item():.4f}')
             print("-------------------------")
        
        if (i+1) % 200 == 0:
            print("Accuracy is", accuracy())
            for params in model.input_layer[stashed[j]].parameters():
                params.requires_grad = False
            print("Node stashed", stashed[j])
            print("-------------------------")
            j+=1

        

Epoch [1/3], Step[100/468], Loss: 22.0693
-------------------------
Epoch [1/3], Step[200/468], Loss: 19.1155
-------------------------
Accuracy is 82.8125
Node stashed 38
-------------------------
Epoch [1/3], Step[300/468], Loss: 18.0457
-------------------------
Epoch [1/3], Step[400/468], Loss: 16.9595
-------------------------
Accuracy is 83.59375
Node stashed 68
-------------------------
Epoch [2/3], Step[100/468], Loss: 12.3625
-------------------------
Epoch [2/3], Step[200/468], Loss: 9.4228
-------------------------
Accuracy is 88.28125
Node stashed 1
-------------------------
Epoch [2/3], Step[300/468], Loss: 9.8779
-------------------------
Epoch [2/3], Step[400/468], Loss: 15.9770
-------------------------
Accuracy is 92.96875
Node stashed 148
-------------------------
Epoch [3/3], Step[100/468], Loss: 11.9606
-------------------------
Epoch [3/3], Step[200/468], Loss: 6.9147
-------------------------
Accuracy is 90.625
Node stashed 101
-------------------------
Epoch [3/3

In [12]:
# Assertion to see if our nodes have really been stashed even after the training is over
univ = [i for i in range(num_hidden_layer)]
not_stashed = [i for i in univ if i not in stashed]

for i in stashed:
    for j in model.input_layer[i].parameters():
        assert j.requires_grad == False

for i in not_stashed:
    for j in model.input_layer[i].parameters():
        assert j.requires_grad == True

In [13]:
accuracy()

92.96875

In [14]:
torch.save(model,'SNN_with_stashing_randomly.pt')

In [15]:
# #Creating computational graph to see, whether all the neurons are contributing or not
# batch = next(iter(train_loader))
# yhat = net(batch[0].reshape(-1,28*28).to(device)) # Give dummy batch to forward().

# from torchviz import make_dot

# make_dot(yhat, params=dict(list(net.named_parameters()))).render("trial_1", format="png")