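Implement a spiking neural network on the MNIST dataset (note: although this title originally mentioned convolution layers, the model defined below uses only fully connected layers)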

In [1]:
# !pip3 install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


## hidden layer 

In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import math
# import matplotlib
# TODO: code cleaning
# Number of simulation steps; RNN_custom.forward runs its loop T times on the same image.
T = 20
'''
STEP 1: LOADING DATASET
'''
batch_size = 128
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
'''
STEP 2: MAKING DATASET ITERABLE
'''

# NOTE(review): n_iters, decay and thresh are not read anywhere in the visible code.
n_iters = 300000
decay = 0.1  # neuron decay rate
thresh = 0.5  # neuronal threshold
lens = 0.3  # width (sigma) of the surrogate-gradient gaussians in ActFun_adp
num_epochs = 150  # n_iters / (len(train_dataset) / batch_size)

'''
STEP 3a: CREATE spike MODEL CLASS
'''

b_j0 = 0.01  # neural threshold baseline
# NOTE(review): the module-level tau_m and dt are shadowed by same-named
# function parameters wherever they appear in the visible code.
tau_m = 20  # ms membrane potential constant
R_m = 1  # membrane resistance
dt = 1  #
gamma = .5  # gradient scale

def gaussian(x, mu=0., sigma=.5):
    """Evaluate the normal density N(mu, sigma**2) element-wise at ``x``.

    Args:
        x: Tensor of evaluation points.
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    # The normaliser is a plain Python float: no tensor is allocated per call,
    # and float64 inputs are not limited to float32 accuracy by a float32
    # sqrt(2*pi) tensor.
    norm = sigma * math.sqrt(2 * math.pi)
    return torch.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / norm
# define approximate firing function
class ActFun_adp(torch.autograd.Function):
    """Step-function spike generator with a hand-written surrogate gradient.

    The forward pass is a hard threshold at zero. Because that step has no
    useful derivative, the backward pass substitutes a mixture of three
    gaussians (one positive lobe at 0, two small negative lobes at +/- ``lens``)
    scaled by the module-level ``gamma``.
    """

    @staticmethod
    def forward(ctx, input):  # input = membrane potential - threshold
        # Keep the pre-threshold value: the surrogate gradient is evaluated at it.
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (pre_threshold,) = ctx.saved_tensors
        centre_lobe = gaussian(pre_threshold, mu=0., sigma=lens) * 1.15
        right_lobe = gaussian(pre_threshold, mu=lens, sigma=1.75 * lens) * .15
        left_lobe = gaussian(pre_threshold, mu=-lens, sigma=1.75 * lens) * .15
        surrogate = centre_lobe - right_lobe - left_lobe
        return gamma * grad_output * surrogate.float()


# Functional handle used by the membrane update.
act_fun_adp = ActFun_adp.apply
# membrane potential update


def mem_update_NU_adp(inputs, mem, spike, tau_adp, tau_m, b, dt=1, isAdapt=1):
    """Advance one layer of adaptive-threshold spiking neurons by one step.

    Args:
        inputs: Input current for this step.
        mem: Membrane potential from the previous step.
        spike: Spikes emitted at the previous step.
        tau_adp: Tensor of threshold-adaptation time constants.
        tau_m: Tensor of membrane time constants.
        b: Adaptation state from the previous step (number or tensor).
        dt: Step size used in both decay factors and in the reset term.
        isAdapt: If truthy the threshold adapts (beta = 1.8); otherwise it
            stays at the baseline ``b_j0``.

    Returns:
        Tuple ``(mem, spike, B, b)``: updated membrane potential, new spikes,
        the threshold used for this step, and the updated adaptation state.
    """
    # Decay factors follow the device of the incoming current instead of a
    # hard-coded ``.cuda()``, so the update also runs on CPU-only machines.
    alpha = torch.exp(-1. * dt / tau_m).to(inputs.device)
    ro = torch.exp(-1. * dt / tau_adp).to(inputs.device)
    beta = 1.8 if isAdapt else 0.

    # Low-pass filter of past spikes; it raises the threshold after firing.
    b = ro * b + (1. - ro) * spike
    B = b_j0 + beta * b

    # Leaky integration; the last term subtracts the threshold after a spike.
    mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt
    spike = act_fun_adp(mem - B)  # fires where the potential exceeds the threshold
    return mem, spike, B, b

def integrator(inputs, mem, tau_m):
    """Accumulate ``inputs`` into ``mem`` without any leak.

    Despite its name, ``tau_m`` is used directly as a gain on the input and
    not as a time constant.

    Returns:
        ``mem + tau_m * inputs``.
    """
    gain = tau_m
    return mem + gain * inputs

def std_cross_entropy(x, y):
    """Return the mean negative log-softmax score of the target classes.

    Args:
        x: Logits; softmax is taken over dimension 1.
        y: Integer class indices, one per row of ``x``.
    """
    neg_log_prob = -1.0 * F.log_softmax(x, 1)
    # Pick, for every row, the entry belonging to its target class.
    target_nll = torch.gather(neg_log_prob, 1, y.unsqueeze(1))
    return target_nll.mean()


class RNN_custom(nn.Module):
    """Spiking network with one feed-forward and one recurrent adaptive layer.

    Per simulation step (the loop runs ``T`` times, ``T`` being a module-level
    constant):

    * the flattened image, perturbed by dropout, drives hidden layer 1 through
      ``i2h``;
    * hidden layer 2 is split into a 10-unit "p" population (columns ``[:10]``)
      and an "r" population (columns ``[10:]``) that are recurrently coupled
      through ``r2r``, ``r2p``, ``p2r`` and ``p2p``; only "r" receives the
      layer-1 spikes (via ``h2r``);
    * the "p" spikes are accumulated by ``integrator`` and turned into class
      log-probabilities with ``log_softmax``.

    Args:
        input_size: Accepted for interface compatibility but not used; ``i2h``
            is fixed to 784 input features.
        output_size: Width of the readout accumulator.
            NOTE(review): the "p" population is hard-coded to 10 units, so this
            presumably has to be 10 - confirm before using another value.
        criterion: Loss applied to the per-step log-probabilities.
        hidden_size: ``(layer-1 width, layer-2 width)``; the layer-2 width
            includes the 10 "p" units.
        filters: Unused.
        kernels: Unused.
        is_rec: Stored on the instance; not read by ``forward``.
    """

    def __init__(self, input_size, output_size, criterion=nn.NLLLoss(),
                 hidden_size=(128, 10), filters=(15, 40), kernels=(5, 5), is_rec=0):
        super().__init__()
        self.criterion = criterion

        # Copy into a list: sequence defaults are immutable tuples, and the
        # caller's own list is never aliased.
        self.hidden_size = list(hidden_size)
        self.output_size = output_size

        n_h1, n_h2 = self.hidden_size[0], self.hidden_size[1]
        n_p = 10            # size of the "p" population
        n_r = n_h2 - n_p    # remaining "r" units of hidden layer 2

        self.i2h = nn.Linear(784, n_h1)

        # h2o is not used by forward() but is kept: the optimizer set-up in
        # this file references its parameters.
        self.h2o = nn.Linear(n_h2, output_size)

        self.h2r = nn.Linear(n_h1, n_r)

        self.r2r = nn.Linear(n_r, n_r)
        self.r2p = nn.Linear(n_r, n_p)
        self.p2r = nn.Linear(n_p, n_r)
        self.p2p = nn.Linear(n_p, n_p)

        # Trainable time constants, created already filled with their start
        # values (adaptation: 200, membrane: 20, readout gain: 0.5).
        self.tau_adp_h = nn.Parameter(torch.full((n_h1,), 200.))
        self.tau_adp_h2 = nn.Parameter(torch.full((n_h2,), 200.))
        self.tau_adp_o = nn.Parameter(torch.full((output_size,), 200.))

        self.tau_m_h = nn.Parameter(torch.full((n_h1,), 20.))
        self.tau_m_h2 = nn.Parameter(torch.full((n_h2,), 20.))
        self.tau_m_o = nn.Parameter(torch.full((output_size,), 0.5))

        nn.init.xavier_uniform_(self.i2h.weight)
        nn.init.xavier_uniform_(self.h2o.weight)
        nn.init.xavier_uniform_(self.h2r.weight)
        nn.init.xavier_uniform_(self.r2r.weight)
        nn.init.xavier_uniform_(self.r2p.weight)
        nn.init.xavier_uniform_(self.p2p.weight)
        nn.init.xavier_uniform_(self.p2r.weight)

        nn.init.constant_(self.i2h.bias, 0)

        self.b_c = self.b_h = self.b_o = 0

        self.is_rec = is_rec

    def forward(self, input, labels):
        """Simulate ``T`` steps on one batch and accumulate the loss.

        Args:
            input: 3-D tensor; it is flattened to ``28 * input.shape[2]``
                features per row before being fed to ``i2h``.
            labels: Class indices passed to ``self.criterion``.

        Returns:
            Tuple ``([predictions, h2o_mem_, h2o_spike_, std_], loss)``:

            * ``predictions``: CPU tensor of class probabilities, stacked with
              time as the first axis, then batch, then class;
            * ``h2o_mem_``: per-step readout accumulator values (NumPy arrays);
            * ``h2o_spike_``: per-step NumPy arrays that are always zero, since
              the readout spike tensor is never updated;
            * ``std_``: per-step standard deviation over the last (up to) three
              log-probability outputs;
            * ``loss``: mean over the ``T`` steps of ``self.criterion``.
        """
        batch_size, _, input_dim = input.shape
        # Allocate state on the device of the data instead of assuming CUDA.
        device = input.device
        n_h1, n_h2 = self.hidden_size[0], self.hidden_size[1]
        n_p = 10

        # Adaptation states restart from the baseline threshold on every call.
        self.b_c1 = self.b_c2 = self.b_h = self.b_h2 = self.b_o = b_j0

        # Random initial potentials are drawn on the CPU generator, as before,
        # and then moved to the compute device.
        h1_mem = torch.rand(batch_size, n_h1).to(device)
        h1_spike = torch.zeros(batch_size, n_h1, device=device)

        h2_mem = torch.rand(batch_size, n_h2).to(device)
        h2_spike = torch.zeros(batch_size, n_h2, device=device)

        h2o_mem = torch.rand(batch_size, self.output_size).to(device)
        h2o_spike = torch.zeros(batch_size, self.output_size, device=device)

        log_probs_ = []
        h2o_spike_ = []
        h2o_mem_ = []
        std_ = []
        loss = 0

        # The flattened image is the same at every step; only the dropout
        # mask changes.
        flat = input.view(-1, 28 * input_dim)

        for _ in range(T):
            # Dropout is deliberately forced on (training=True), so it is also
            # active when the model is evaluated.
            x = F.dropout(flat, 0.2, training=True)

            h1_mem, h1_spike, _, self.b_h = mem_update_NU_adp(
                self.i2h(x), h1_mem, h1_spike,
                self.tau_adp_h, self.tau_m_h, self.b_h)

            # Previous-step spikes of the two layer-2 populations.
            p_prev = h2_spike[:, :n_p]
            r_prev = h2_spike[:, n_p:]
            r_input = self.h2r(h1_spike) + self.r2r(r_prev) + self.p2r(p_prev)
            p_input = self.p2p(p_prev) + self.r2p(r_prev)
            h2_input = torch.cat((p_input, r_input), dim=1)

            h2_mem, h2_spike, _, self.b_h2 = mem_update_NU_adp(
                h2_input, h2_mem, h2_spike,
                self.tau_adp_h2, self.tau_m_h2, self.b_h2)

            # Readout: leak-free accumulation of the "p" spikes.
            h2o_mem = integrator(h2_spike[:, :n_p], h2o_mem, self.tau_m_o)
            output_sumspike = F.log_softmax(h2o_mem, dim=1)
            loss += self.criterion(output_sumspike, labels) / T

            log_probs_.append(output_sumspike.detach().cpu().numpy())

            # Spread of the most recent outputs. Only the last three entries
            # are converted, not the whole history. At the first step the
            # window holds a single sample.
            window = torch.tensor(np.array(log_probs_[-3:]))
            std_.append(torch.std(window, dim=0).numpy())

            h2o_spike_.append(h2o_spike.detach().cpu().numpy())
            h2o_mem_.append(h2o_mem.detach().cpu().numpy())

        # Log-probabilities -> probabilities.
        predictions = torch.tensor(np.exp(np.array(log_probs_)))
        return [predictions, h2o_mem_, h2o_spike_, std_], loss

    def predict(self, input, lablel):
        """Run ``forward`` without building an autograd graph.

        Returns only the first element of ``forward``'s result, i.e. the list
        ``[predictions, h2o_mem_, h2o_spike_, std_]``.
        """
        with torch.no_grad():
            prediction, _ = self.forward(input, lablel)
        return prediction


'''
STEP 4: INSTANTIATE MODEL CLASS
'''
input_dim = 28
hidden_dim = [155, 155 + 10]
layer_dim = 1  # layer number
output_dim = 10
seq_dim = 784 // input_dim  # Number of steps to unroll

criterion = nn.NLLLoss()

# `criterion` is not handed to the model, so RNN_custom uses its own default loss.
model = RNN_custom(input_dim, output_dim, hidden_size=hidden_dim)
#model = torch.load('./model/model_97.98-v7 (copy).pth')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
model.to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 3e-3

#optimizer = torch.optim.Adamax(model.parameters(), lr=learning_rate)
# Weights and biases of every linear layer train at the base learning rate.
_dense_layers = (
    model.i2h,
    model.r2r,
    model.r2p,
    model.h2r,
    model.p2r,
    model.p2p,
    model.h2o,
)
base_params = [
    param
    for layer in _dense_layers
    for param in (layer.weight, layer.bias)
]

# Each time-constant vector gets its own group with a scaled learning rate.
_tau_lr_scale = (
    (model.tau_adp_h, 2),
    (model.tau_adp_h2, 2),
    (model.tau_adp_o, 3),
    (model.tau_m_h, 2),
    (model.tau_m_h2, 2),
    (model.tau_m_o, 1),
)
optimizer = torch.optim.Adamax(
    [{'params': base_params}]
    + [{'params': tau, 'lr': learning_rate * scale} for tau, scale in _tau_lr_scale],
    lr=learning_rate,
)
# The scheduler is built here, but the scheduler.step() call in train() is commented out.
scheduler = StepLR(optimizer, step_size=3000, gamma=.5)

'''
STEP 7: TRAIN THE MODEL
'''

def train(model, num_epochs=150):
    """Train ``model`` on ``train_loader`` and evaluate it after every epoch.

    Relies on the module-level ``train_loader``, ``optimizer``, ``device``,
    ``seq_dim``, ``input_dim`` and ``T``, and on ``test`` for evaluation.

    Args:
        model: Module whose call returns ``(outputs, loss)``, where
            ``outputs[0]`` holds per-step class scores with time as the first
            axis, then batch, then class.
        num_epochs: Number of passes over the training set.

    Returns:
        List with the test accuracy (in percent) measured after each epoch.
    """
    acc = []
    best_accuracy = 0

    for epoch in range(num_epochs):
        train_correct = 0
        train_loss_sum = 0.0
        total = 0
        for images, labels in train_loader:
            # NOTE(review): input gradients are requested here, but nothing in
            # the visible code reads them.
            images = images.view(-1, seq_dim, input_dim).requires_grad_().to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predictions, train_loss = model(images, labels)
            train_loss.backward()
            optimizer.step()
            # scheduler.step()

            # Accumulate a plain float: summing the loss tensor itself would
            # chain every batch's autograd history together for the epoch.
            train_loss_sum += train_loss.item()

            # Arg-max over classes, then transpose to (batch, time).
            _, predicted = torch.max(predictions[0], 2)
            predicted = predicted.t()
            total += labels.size(0)
            # A prediction is counted once per time step; the (batch, 1)
            # labels broadcast across the time axis.
            hits = predicted.cpu() == labels.cpu().view(-1, 1)
            train_correct += hits.sum().item()

        # Fraction of correct (sample, step) pairs; guarded for an empty loader.
        train_acc = train_correct / T / total if total else 0.0

        accuracy = test(model)
        if accuracy > best_accuracy and accuracy > 95.:
            # torch.save(model,'./model/cnn_srnn_'+str(accuracy)+'-v0.pth')
            best_accuracy = accuracy
        acc.append(accuracy)
        print('epoch: ', epoch, '. Loss: ', train_loss_sum, '. tr Acc: ', train_acc, ' ,ts Acc:', accuracy)
    return acc

def test(model):
    """Return the per-time-step test accuracy of ``model`` in percent.

    Every (sample, time step) pair is scored separately, so the result is the
    accuracy averaged over all ``T`` steps. Relies on the module-level
    ``test_loader``, ``device``, ``seq_dim``, ``input_dim`` and ``T``.

    Returns:
        Accuracy as a NumPy float in the range 0-100.
    """
    correct = 0
    total = 0
    # Evaluation needs no gradients.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(-1, seq_dim, input_dim).to(device)
            labels = labels.to(device)
            outputs = model.predict(images, labels)[0]
            # Arg-max over classes, then transpose to (batch, time).
            _, predicted = torch.max(outputs, 2)
            predicted = predicted.t()
            total += labels.size(0)
            # One code path for CPU and GPU. The former CPU-only branch
            # compared (batch, time) against (batch,), which cannot broadcast.
            hits = predicted.cpu() == labels.cpu().view(-1, 1)
            correct += hits.sum().item()

    accuracy = 100. * np.float64(correct) / total / T
    return accuracy
    
def test_real(model):
    """Return the test accuracy when class scores are summed over time.

    Unlike ``test``, each sample yields one prediction: the class with the
    largest probability summed across all time steps. Relies on the
    module-level ``test_loader``, ``device``, ``seq_dim`` and ``input_dim``.

    Returns:
        Accuracy as a NumPy float in the range 0-100.
    """
    correct = 0
    total = 0
    # Evaluation needs no gradients.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(-1, seq_dim, input_dim).to(device)
            labels = labels.to(device)
            outputs = model.predict(images, labels)[0]
            # outputs is (time, batch, class), so time is dim 0. Summing over
            # dim 1 collapsed the batch axis instead.
            pred_sum = outputs.sum(dim=0)
            _, predicted = torch.max(pred_sum, 1)
            total += labels.size(0)
            correct += (predicted.cpu() == labels.cpu()).sum().item()

    accuracy = 100. * np.float64(correct) / total
    return accuracy
###############################
# Run the full training loop; `acc` holds one test accuracy per epoch.
acc = train(model,num_epochs)
# Evaluate once more after the last epoch. The forward pass draws fresh random
# initial states and dropout masks, so this value can differ from acc[-1].
accuracy = test(model)
print('. Accuracy: ',accuracy)


###################
##  Accuracy  curve
###################
# Test accuracy (percent) against epoch index.
plt.plot(acc)
plt.title('Learning Curve -- Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy: %')
plt.show()

#



device: cuda:0
epoch:  0 . Loss:  800.4718017578125 . tr Acc:  0.41514  ,ts Acc: 57.479
epoch:  1 . Loss:  560.6456909179688 . tr Acc:  0.6280125  ,ts Acc: 67.029
epoch:  2 . Loss:  478.5580749511719 . tr Acc:  0.6910241666666667  ,ts Acc: 71.35900000000001
epoch:  3 . Loss:  427.5786437988281 . tr Acc:  0.7268116666666666  ,ts Acc: 74.1255
epoch:  4 . Loss:  390.7110900878906 . tr Acc:  0.7520683333333333  ,ts Acc: 76.46849999999999
epoch:  5 . Loss:  360.2369689941406 . tr Acc:  0.7712825  ,ts Acc: 77.4235
epoch:  6 . Loss:  337.2833557128906 . tr Acc:  0.7847708333333333  ,ts Acc: 78.8025
epoch:  7 . Loss:  317.8880920410156 . tr Acc:  0.7965166666666667  ,ts Acc: 80.067
epoch:  8 . Loss:  299.4064025878906 . tr Acc:  0.8084441666666667  ,ts Acc: 81.074
epoch:  9 . Loss:  282.916015625 . tr Acc:  0.81756  ,ts Acc: 81.9835
epoch:  10 . Loss:  268.64422607421875 . tr Acc:  0.8260175000000001  ,ts Acc: 82.5635
epoch:  11 . Loss:  255.95419311523438 . tr Acc:  0.8335425000000001  ,ts Ac

In [None]:
# Run the first 20 test images through the trained model for inspection.
images, labels = next(iter(test_loader))
images = images.view(-1, seq_dim, input_dim)[:20].to(device)
labels = labels[:20].to(device)
print(images.shape, labels.shape)
# Call the module itself rather than .forward, and skip autograd bookkeeping:
# nothing is trained in this cell.
with torch.no_grad():
    a = model(images, labels)
# a[0][0] is the tensor of class probabilities returned by the model.
print(a[0][0].cpu().numpy().shape)



In [None]:
# Class probabilities over time for sample 1 of the batch: one curve per digit.
pred = a[0][0].cpu().numpy()
pred1 = pred[:, 1, :]
# plt.imshow(pred1)
plt.figure(figsize=(10, 5))
for i in range(10):
    class_curve = pred1[:, i]
    plt.plot(class_curve, label=str(i))
plt.legend()

In [None]:
# Left: input image of sample 1. Right: its readout accumulator over time,
# one curve per class.
pred = np.array(a[0][1])
pred1 = pred[:, 1, :]
# plt.imshow(pred1)
plt.figure(figsize=(10, 5))
plt.subplot(121)
sample_image = images.cpu().numpy()[1, :, :]
plt.imshow(sample_image)
plt.subplot(122)
for i in range(10):
    class_curve = pred1[:, i]
    plt.plot(class_curve, label=str(i))
plt.legend()

In [None]:
# Same inspection as before, but on randomly masked images.
images, labels = next(iter(test_loader))
images = images.view(-1, seq_dim, input_dim)[:20].to(device)
# Keep a pixel only where an independent standard-normal draw exceeds 0.2
# (about 42% of the pixels); all others are zeroed.
images = images * torch.randn_like(images).gt(0.2)
labels = labels[:20].to(device)
print(images.shape, labels.shape)
# Call the module itself rather than .forward, and skip autograd bookkeeping:
# nothing is trained in this cell.
with torch.no_grad():
    a = model(images, labels)
print(a[0][0].cpu().numpy().shape)

In [None]:
# Left: masked input image of sample 1. Right: its readout accumulator over
# time, one curve per class.
pred = np.array(a[0][1])
pred1 = pred[:, 1, :]
# plt.imshow(pred1)
plt.figure(figsize=(10, 5))
plt.subplot(121)
sample_image = images.cpu().numpy()[1, :, :]
plt.imshow(sample_image)
plt.subplot(122)
for i in range(10):
    class_curve = pred1[:, i]
    plt.plot(class_curve, label=str(i))
plt.legend()

In [None]:
# Inspection run on randomly masked images, with a slightly lower cut-off.
images, labels = next(iter(test_loader))
images = images.view(-1, seq_dim, input_dim)[:20].to(device)
# Keep a pixel only where an independent standard-normal draw exceeds 0.1
# (about 46% of the pixels); all others are zeroed.
images = images * torch.randn_like(images).gt(0.1)
labels = labels[:20].to(device)
print(images.shape, labels.shape)
# Call the module itself rather than .forward, and skip autograd bookkeeping:
# nothing is trained in this cell.
with torch.no_grad():
    a = model(images, labels)
print(a[0][0].cpu().numpy().shape)


In [None]:
# 2 x 5 grid: the top row shows five masked input images, the bottom row the
# matching readout-accumulator traces (one curve per class).
pred = np.array(a[0][1])
# plt.imshow(pred1)
plt.figure(figsize=(40, 5))

batch_images = images.cpu().numpy()
for column, sample in enumerate((1, 3, 5, 7, 12), start=1):
    pred1 = pred[:, sample, :]
    plt.subplot(2, 5, column)
    plt.imshow(batch_images[sample, :, :])
    plt.subplot(2, 5, column + 5)
    for i in range(10):
        plt.plot(pred1[:, i], label=str(i))
# The legend is drawn only on the last trace panel.
plt.legend(loc=2)

In [None]:
def my_cross_entropy(x, y):
    """Hand-written cross entropy: mean negative log-softmax of the targets.

    Args:
        x: Logits; softmax is taken over dimension 1.
        y: Integer class indices, one per row of ``x``.
    """
    neg_log_prob = -1.0 * F.log_softmax(x, 1)
    # Select the entry of each row that belongs to its target class.
    target_nll = torch.gather(neg_log_prob, 1, y.unsqueeze(1))
    return target_nll.mean()


# Sanity check: my_cross_entropy should reproduce nn.CrossEntropyLoss on random data.
criterion = nn.CrossEntropyLoss()

# NOTE(review): this rebinds the module-level `batch_size` that was used to
# build the DataLoaders; it only matters if earlier cells are re-run afterwards.
batch_size = 5
nb_classes = 10
# Random logits and random integer targets in [0, nb_classes).
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.randint(0, nb_classes, (batch_size,))

loss_reference = criterion(x, y)
loss = my_cross_entropy(x, y)

# Difference between the reference and the hand-written loss.
print(loss_reference - loss)

In [None]:
def std_cross_entropy(x, y, std_mat):
    """Cross entropy whose per-class terms are scaled element-wise by ``std_mat``.

    Args:
        x: Logits; softmax is taken over dimension 1.
        y: Integer class indices, one per row of ``x``.
        std_mat: Weights multiplied with the negative log-probabilities before
            the target entries are selected.

    Returns:
        Mean of the weighted target entries. The shape of the weighted matrix
        is printed as a side effect.
    """
    neg_log_prob = -1.0 * F.log_softmax(x, 1)
    weighted = std_mat * neg_log_prob
    print(weighted.shape)
    target_terms = torch.gather(weighted, 1, y.unsqueeze(1))
    return target_terms.mean()

# Compare the weighted loss with the unweighted reference on random data.
# NOTE(review): rebinds the module-level `batch_size` used for the DataLoaders.
batch_size = 5
nb_classes = 10
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.randint(0, nb_classes, (batch_size,))
# Random weights drawn from a standard normal, so they can be negative.
std_mat = torch.randn(batch_size, nb_classes)

# `criterion` is whatever loss object an earlier cell bound to that name.
loss_reference = criterion(x, y)
loss = std_cross_entropy(x, y,std_mat)

# With random weights the two losses are not expected to match.
print(loss_reference - loss)