# CRI MNIST Demonstration with snnTorch 

## Training SNN with snnTorch

In [None]:
!pip install snntorch

In [None]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import time
from quant_layer import *

### Import MNIST datasets

In [None]:
# Data-loading configuration
batch_size = 128                    # samples per mini-batch
data_path = '~/justinData/mnist'    # root directory for the MNIST files

# Tensor dtype and compute device: the GPU when CUDA is available, otherwise the CPU
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Define a transform
# Pipeline: resize to 28x28, convert to grayscale, convert to a tensor, then
# normalise with mean 0 and std 1 (which leaves the tensor values unchanged).
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

# Training and test splits, both stored under data_path (download=True is
# passed, so missing files are fetched).
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

In [None]:
# Create DataLoaders
# Both loaders shuffle and use drop_last=True, so the final incomplete batch is
# discarded and every batch holds exactly batch_size samples.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

### Define the network

In [1]:
# Network Architecture
num_inputs = 28*28      # one input per pixel of a flattened 28x28 image
num_hidden_0 = 1000     # width of the single hidden layer
num_outputs = 10        # one output neuron per digit class

# Temporal Dynamics
num_steps = 25          # number of simulation time steps per forward pass
# NOTE(review): beta is the decay factor handed to snn.Leaky; 1.0 presumably
# means no leak (pure integrate-and-fire) — confirm against the snnTorch docs.
beta = 1.0
# Surrogate gradient used for the spike non-linearity during backpropagation
spike_grad = surrogate.fast_sigmoid(slope=25)

NameError: name 'surrogate' is not defined

In [None]:
# Two-layer spiking network: QuantLinear -> Leaky -> QuantLinear -> Leaky.
# QuantLinear is not defined in this file; presumably it comes from the
# `quant_layer` star import — confirm. Both Leaky layers keep their own hidden
# state (init_hidden=True); the last one is created with output=True.
net = nn.Sequential(QuantLinear(num_inputs, num_hidden_0, bias = True), 
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True),
                    QuantLinear(num_hidden_0, num_outputs, bias = True),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,init_hidden=True, output=True)).to(device)

In [None]:
# Pull a single mini-batch from the training loader and move it to the compute
# device; it is used below for a sanity check of the untrained network.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

In [None]:
def forward_pass(net, num_steps, data, batch_size):
    """Present the same static input to ``net`` for ``num_steps`` time steps.

    Args:
        net: network whose call returns a ``(spikes, membrane)`` pair.
        num_steps: number of time steps to simulate.
        data: input batch; everything after the first dimension is flattened.
        batch_size: kept for backward compatibility only. The leading
            dimension is taken from ``data`` itself, so a final batch that is
            smaller than ``batch_size`` no longer breaks the reshape.

    Returns:
        Tuple ``(spk_rec, mem_rec)``: the per-step outputs of ``net`` stacked
        along a new leading time dimension.
    """
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    # The input does not change between steps, so flatten it once.
    flat_data = data.view(data.size(0), -1)

    spk_rec = []
    mem_rec = []
    for _ in range(num_steps):
        spk_out, mem_out = net(flat_data)
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)

    return torch.stack(spk_rec), torch.stack(mem_rec)

In [None]:
# Run the (still untrained) network on the sample batch fetched above.
spk_rec, mem_rec = forward_pass(net, num_steps, data, batch_size)

### Loss Functions

In [None]:
# Loss object from snntorch.functional; it is called below as
# loss_fn(spike_recording, targets).
loss_fn = SF.ce_rate_loss()

In [None]:
# Loss of the untrained network on the sample batch (baseline before training).
loss_val = loss_fn(spk_rec, targets)

print(f"The loss from an untrained network is {loss_val.item():.3f}")

### Accuracy 

In [None]:
# Accuracy of the untrained network on the sample batch; the value is scaled
# by 100 below to print it as a percentage.
acc = SF.accuracy_rate(spk_rec, targets)

print(f"The accuracy of a single batch using an untrained network is {acc*100:.3f}%")

In [None]:
def batch_accuracy(train_loader, net, num_steps, batch_size):
    """Return the sample-weighted mean accuracy of ``net`` over a whole loader.

    Args:
        train_loader: iterable of ``(data, targets)`` batches. Despite the
            name, any loader works (the notebook passes the test loader).
        net: network to evaluate; it is switched to eval mode.
        num_steps: number of time steps per forward pass.
        batch_size: kept for backward compatibility only. The size of each
            batch is read from the batch itself, so a smaller final batch
            (``drop_last=False``) is handled correctly.

    Returns:
        Accuracy in ``[0, 1]``, weighted by the number of samples per batch.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    with torch.no_grad():
        total = 0
        acc = 0
        net.eval()

        for data, targets in train_loader:
            data = data.to(device)
            targets = targets.to(device)
            spk_rec, _ = forward_pass(net, num_steps, data, data.size(0))

            # spk_rec is stacked over time first, so dimension 1 is the batch.
            samples = spk_rec.size(1)
            acc += SF.accuracy_rate(spk_rec, targets) * samples
            total += samples

    return acc/total

### Training

In [None]:
# Optimiser and bookkeeping for the training run
optimizer = torch.optim.Adam(net.parameters(), lr=5e-3, betas=(0.9, 0.999))
num_epochs = 20
loss_hist = []        # training loss, one entry per mini-batch
test_loss_hist = []   # loss on one test batch, one entry per mini-batch
counter = 0           # mini-batches processed across all epochs

for epoch in range(num_epochs):
    iter_counter = 0  # mini-batches processed within the current epoch
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data, targets = data.to(device), targets.to(device)

        # --- training step ---
        net.train()
        spk_rec, mem_rec = forward_pass(net, num_steps, data, batch_size)
        loss_val = torch.zeros((1), dtype=dtype, device=device) + loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        loss_hist.append(loss_val.item())

        # --- monitoring on a single test batch ---
        with torch.no_grad():
            net.eval()
            # A fresh iterator is created on every step; test_loader shuffles,
            # so each step draws from a newly shuffled order.
            test_data, test_targets = next(iter(test_loader))
            test_data, test_targets = test_data.to(device), test_targets.to(device)

            test_spk, test_mem = forward_pass(net, num_steps, test_data, batch_size)
            test_loss = torch.zeros((1), dtype=dtype, device=device) + loss_fn(test_spk, test_targets)
            test_loss_hist.append(test_loss.item())

            # Report every 50th mini-batch (counted across epochs)
            if counter % 50 == 0:
                train_acc = SF.accuracy_rate(spk_rec, targets)
                test_acc = SF.accuracy_rate(test_spk, test_targets)
                print(f"Epoch {epoch}, Iteration {iter_counter}")
                print(f"Train Set Loss: {loss_hist[-1]:.2f}")
                print(f"Test Set Loss: {test_loss_hist[-1]:.2f}")
                print(f"Train set accuracy for a single minibatch: {train_acc*100:.2f}%")
                print(f"Test set accuracy for a single minibatch: {test_acc*100:.2f}%")
                print("\n")

        counter += 1
        iter_counter += 1

In [None]:
# Accuracy of the trained network over the whole test loader.
test_acc = batch_accuracy(test_loader, net, num_steps, batch_size)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

In [None]:
# Plot Loss
# Both histories hold one value per training mini-batch, so they share the
# same x axis (iteration index).
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.title("Loss Curves")
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

### Save Models

In [None]:
def save_checkpoint(state, is_quan, fdir):
    """Save ``state`` as ``checkpoint.pth`` in ``fdir`` and copy it to a named file.

    Args:
        state: object handed to ``torch.save``.
        is_quan: when truthy the copy is named ``model_quantized_6L.pth.tar``,
            otherwise ``model_snnTorch_6L.pth.tar``.
        fdir: directory that receives both files.
    """
    checkpoint_path = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, checkpoint_path)

    copy_name = 'model_quantized_6L.pth.tar' if is_quan else 'model_snnTorch_6L.pth.tar'
    shutil.copyfile(checkpoint_path, os.path.join(fdir, copy_name))

In [None]:
# Directory that receives the checkpoints written by this notebook.
fdir = 'result/'
# exist_ok=True makes the cell safe to re-run and replaces the two redundant
# (and race-prone) exists()/makedirs() pairs for the same directory.
os.makedirs(fdir, exist_ok=True)

In [None]:
# Persist the trained network's state dict.
# NOTE(review): is_quan=1 selects the 'model_quantized_6L.pth.tar' file name
# although this call runs before the integer conversion below, and the later
# save writes to the same name — confirm whether 0 was intended here.
save_checkpoint({'state_dict': net.state_dict(),}, 1, fdir)

### Quantization

In [None]:
def weight_quantization(b):
    """Build a symmetric uniform weight quantiser with a custom gradient.

    Args:
        b: number of magnitude bits. Magnitudes in ``[0, 1]`` are snapped to
            multiples of ``1 / (2**b - 1)``.

    Returns:
        A callable ``quantise(input, alpha)`` that divides ``input`` by
        ``alpha``, clips to ``[-1, 1]``, quantises the magnitude, restores the
        sign and rescales by ``alpha``. ``input`` itself is left untouched.
    """
    levels = 2 ** b - 1

    def uniform_quant(x):
        # Snap values in [0, 1] to the nearest multiple of 1/levels.
        return x.mul(levels).round().div(levels)

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            # Out-of-place division: an in-place div_ here would silently
            # rescale the caller's weight tensor on every call.
            scaled = input.div(alpha)
            clipped = scaled.clamp(min=-1, max=1)
            quantized = uniform_quant(clipped.abs()).mul(clipped.sign())
            ctx.save_for_backward(scaled)
            return quantized.mul(alpha)               # rescale to the original range

        @staticmethod
        def backward(ctx, grad_output):
            scaled, = ctx.saved_tensors
            # 1.0 where the scaled weight was clipped (|w / alpha| > 1), else 0.0
            clipped_mask = (scaled.abs() > 1.).float()
            # alpha only receives gradient from clipped weights, where the
            # output equals sign * alpha; unclipped weights contribute zero.
            grad_alpha = (grad_output * scaled.sign() * clipped_mask).sum()
            # Straight-through for unclipped weights, zero for clipped ones.
            grad_input = grad_output * (1 - clipped_mask)
            return grad_input, grad_alpha

    # apply is a classmethod; autograd Functions are not meant to be instantiated.
    return _pq.apply

class weight_quantize_fn(nn.Module):
    """Module wrapper around the quantiser returned by ``weight_quantization``.

    Args:
        w_bit: total bit width. ``w_bit - 1`` bits are passed on as magnitude
            bits (the sign is handled separately by the quantiser).
        wgt_alpha: clipping / scaling value passed to the quantiser on every
            call. Defaults to ``1.0`` so the module is usable right after
            construction; it can still be reassigned through the
            ``wgt_alpha`` attribute.
    """

    def __init__(self, w_bit, wgt_alpha=1.0):
        super().__init__()
        self.w_bit = w_bit-1
        # Previously never initialised here, which made forward() raise
        # AttributeError unless the attribute was patched in from outside.
        self.wgt_alpha = wgt_alpha
        self.weight_q = weight_quantization(b=self.w_bit)

    def forward(self, weight):
        """Return ``weight`` quantised with the current ``wgt_alpha``."""
        return self.weight_q(weight, self.wgt_alpha)

In [None]:
# Quantisation settings: clipping value and total bit width.
w_alpha=1
w_bits=16
weight_quant = weight_quantize_fn(w_bit= w_bits)  ## define quant function
# wgt_alpha is read by weight_quantize_fn.forward, so it must be set before the call.
weight_quant.wgt_alpha = w_alpha
# Preview: quantise the first layer's weights and express them as multiples of
# the quantisation step w_delta.
fc1_quant      = weight_quant(net[0].weight)
w_delta        = w_alpha/(2**(w_bits-1)-1)
fc1_int        = fc1_quant/w_delta
print("FC1 Weights: \n",fc1_int)

In [None]:

# Convert the network to integer units: weights are quantised and divided by
# the quantisation step, biases and spiking thresholds are divided by the same
# step so the relative scale of all three is preserved.
w_delta = w_alpha / (2 ** (w_bits - 1) - 1)  # quantisation step (loop-invariant)

for layer in net:
    if isinstance(layer, (torch.nn.Linear, torch.nn.Conv2d)):
        layer.weight = Parameter(weight_quant(layer.weight) / w_delta)
        # Layers built with bias=False carry bias=None, which cannot be divided.
        if layer.bias is not None:
            layer.bias = Parameter(layer.bias / w_delta)
    elif isinstance(layer, snn.Leaky):
        layer.threshold = layer.threshold / w_delta

In [None]:
# Re-evaluate on the test loader after the integer conversion above.
test_acc = batch_accuracy(test_loader, net, num_steps, batch_size)

print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

In [None]:
# Persist the converted network; is_quan=1 selects the
# 'model_quantized_6L.pth.tar' file name inside fdir.
save_checkpoint({'state_dict': net.state_dict(),}, 1, fdir)

### Load Saved Model

In [None]:
# NOTE(review): hard-coded absolute path whose file name differs from the one
# save_checkpoint writes ('model_quantized_6L.pth.tar') — confirm it points at
# the intended checkpoint.
best_model_path = '/Volumes/export/isn/keli/Desktop/CRI/result/model_quantized.pth.tar'
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
checkpoint = torch.load(best_model_path, map_location=device)
net.load_state_dict(checkpoint['state_dict'])
net.eval()

In [None]:
# Evaluate the restored network on the test loader.
test_acc = batch_accuracy(test_loader, net, num_steps, batch_size)
print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

### Mapping into CRI

In [None]:
# Even positions of `net` hold the weighted layers; odd positions hold the
# spiking neurons, which are skipped here.
for position, module in enumerate(net):
    if position % 2:
        continue
    print(module.weight.shape)

In [None]:
# extract weights and bias for torchsnn
layers, biases = [], []
for i, layer in enumerate(net):
    if i % 2 == 0:
        layers.append(layer.weight.detach().cpu().numpy())
        biases.append(layer.bias.detach().cpu().numpy())

print(np.min(layers[1]))
print(np.max(layers[1]))

In [None]:
# Shape of every extracted weight matrix followed by its bias vector
for weight_matrix, bias_vector in zip(layers, biases):
    print(weight_matrix.shape)
    print(bias_vector.shape)

In [None]:
# Translate the extracted weight matrices and bias vectors into the axon and
# neuron connectivity dictionaries handed to CRI_network.
#   * layer 0: every input feature becomes an axon 'a<k>' fanning out to the
#     first layer's neurons;
#   * later layers: every neuron of the previous layer fans out to the next one;
#   * every bias becomes an extra axon with a single synapse onto its neuron.


def _to_int_weight(value):
    """Round a weight or bias to the nearest integer.

    Plain int() truncates toward zero, so a float32 value such as 3276.9998
    (an integer level distorted by rounding error) would become 3276.
    """
    return int(round(float(value)))


axonsDict = {}      # axon ID -> list of (post-synaptic neuron ID, integer weight)
neuronsDict = {}    # neuron ID -> list of (post-synaptic neuron ID, integer weight)
outputs = []        # IDs of the output-layer neurons
bias_axon = {}

axonOffset = 0                  # index of the next free axon ID
currLayerNeuronIdxOffset = 0    # ID of the first pre-synaptic neuron of this layer
nextLayerNeuronIdxOffset = 0    # ID of the first post-synaptic neuron of this layer
for layerNum, layer in enumerate(layers):
    inFeatures = layer.shape[1]
    outFeatures = layer.shape[0]
    weight = layer
    bias = biases[layerNum]
    print("Weights shape: ", np.shape(weight))
    if (layerNum == 0):
        print('constructing Axons')
        print("Input layer shape(outfeature, infeature): ", weight.shape)
        # weight is (out, in); iterating the transpose yields one row per input.
        for axonIdx, axon in enumerate(weight.T):
            axonID = 'a'+str(axonIdx)
            axonsDict[axonID] = [(str(postSynapticID), _to_int_weight(synapseWeight))
                                 for postSynapticID, synapseWeight in enumerate(axon)]
        axonOffset += inFeatures
        print("axon offset: ",axonOffset)
        # implementing bias: one axon per bias with a single synapse (neuron, bias_val)
        print('Construct bias axons for first hidden layers:',bias.shape)
        for neuronIdx, bias_value in enumerate(bias):
            biasAxonID = 'a'+str(neuronIdx + axonOffset)
            axonsDict[biasAxonID] = [(str(neuronIdx), _to_int_weight(bias_value))]
        print("number of axons: ", len(axonsDict))
        print("number of neurons: ", len(neuronsDict),"\n")

    elif (layerNum == len(layers)-1):
        print('constructing output layer')
        nextLayerNeuronIdxOffset += inFeatures
        print("output layer shape(outfeature, infeature): ", weight.shape)
        for baseNeuronIdx, neuron in enumerate(weight.T):
            neuronID = str(baseNeuronIdx+currLayerNeuronIdxOffset)
            neuronsDict[neuronID] = [(str(basePostSynapticID+nextLayerNeuronIdxOffset), _to_int_weight(synapseWeight))
                                     for basePostSynapticID, synapseWeight in enumerate(neuron)
                                     if synapseWeight != 0]
        currLayerNeuronIdxOffset += inFeatures
        # instantiate the output neurons (they have no outgoing synapses)
        print('instantiate output neurons')
        for baseNeuronIdx in range(outFeatures):
            neuronID = str(baseNeuronIdx+nextLayerNeuronIdxOffset)
            neuronsDict[neuronID] = []
            outputs.append(neuronID)
        # implementing bias: one axon per bias with a single synapse (neuron, bias_val)
        print('Construct bias axons for output neurons',bias.shape)
        # The previous layer's bias axons occupy inFeatures IDs; skip past them.
        axonOffset += inFeatures
        for neuronIdx, bias_value in enumerate(bias):
            biasAxonID = 'a'+str(neuronIdx + axonOffset)
            axonsDict[biasAxonID] = [(str(neuronIdx+nextLayerNeuronIdxOffset), _to_int_weight(bias_value))]
        print("number of axons: ", len(axonsDict))
        print("number of neurons: ", len(neuronsDict),"\n")

    else:
        print('constructing hidden layer')
        nextLayerNeuronIdxOffset += inFeatures
        for baseNeuronIdx, neuron in enumerate(weight.T):
            neuronID = str(baseNeuronIdx+currLayerNeuronIdxOffset)
            neuronsDict[neuronID] = [(str(basePostSynapticID+nextLayerNeuronIdxOffset), _to_int_weight(synapseWeight))
                                     for basePostSynapticID, synapseWeight in enumerate(neuron)
                                     if synapseWeight != 0]
        currLayerNeuronIdxOffset += inFeatures
        axonOffset += inFeatures
        print("axon offset: ",axonOffset)
        # implementing bias: one axon per bias with a single synapse (neuron, bias_val)
        print('Construct bias axons for hidden layers:',bias.shape)
        for neuronIdx, bias_value in enumerate(bias):
            biasAxonID = 'a'+str(neuronIdx + axonOffset)
            axonsDict[biasAxonID] = [(str(neuronIdx+nextLayerNeuronIdxOffset), _to_int_weight(bias_value))]
        print("number of axons: ", len(axonsDict))
        print("number of neurons: ", len(neuronsDict),"\n")

print("output neurons: ", outputs)
print("number of axons: ", len(axonsDict))
print("number of neurons: ", len(neuronsDict),"\n")

In [None]:
# Connectivity statistics: entry count, total synapse count and largest
# fan-out, first for the axons and then for the neurons.
print("Number of axons: ",len(axonsDict))
axon_fanouts = [len(synapses) for synapses in axonsDict.values()]
totalAxonSyn = sum(axon_fanouts)
maxFan = max(axon_fanouts, default=0)
print("Total number of connections between axon and neuron: ", totalAxonSyn)
print("Max fan out of axon: ", maxFan)
print('---')
print("Number of neurons: ", len(neuronsDict))
neuron_fanouts = [len(synapses) for synapses in neuronsDict.values()]
totalSyn = sum(neuron_fanouts)
maxFan = max(neuron_fanouts, default=0)
print("Total number of connections between hidden and output layers: ", totalSyn)
print("Max fan out of neuron: ", maxFan)

In [None]:
from l2s.api import CRI_network
import cri_simulations

In [None]:
# CRI network configuration: neuron model plus the global firing threshold.
config = {
    'neuron_type': "I&F",
    'global_neuron_params': {'v_thr': 9*10**4},
}
# Only the hardware target is instantiated; the software-simulator variant
# (target='simpleSim', bound to `softwareNetwork`) is currently disabled.
hardwareNetwork = CRI_network(
    axons=axonsDict,
    connections=neuronsDict,
    config=config,
    target='CRI',
    outputs=outputs,
    simDump=False,
)

In [None]:
def input_to_CRI(currentInput, num_steps=10):
    """Rate-encode a batch and express it as lists of active CRI axon IDs.

    Args:
        currentInput: input batch; everything after the first dimension is
            flattened. (The argument itself is now used; previously the
            global ``data`` was read instead.)
        num_steps: number of rate-coded time steps per sample (default 10,
            the value that used to be hard-coded).

    Returns:
        One entry per sample; each entry is a list with one element per time
        step; each element is the list of axon IDs (``'a<k>'``) that fire in
        that step: the spiking input axons followed by all bias axons.
    """
    flat = currentInput.view(currentInput.size(0), -1)

    # Bias axons are numbered directly after the input axons in axonsDict and
    # fire on every step; the list does not depend on the sample, so build it once.
    bias_axons = ['a'+str(idx) for idx in range(flat.size(1), len(axonsDict))]

    batch = []
    for sample in flat:
        rate_enc = spikegen.rate(sample, num_steps).detach().cpu().numpy()
        timesteps = [
            ['a'+str(idx) for idx, spike in enumerate(step) if spike != 0] + bias_axons
            for step in rate_enc
        ]
        batch.append(timesteps)
    return batch

In [None]:
def run_CRI(inputList,output_offset):
    """Classify every sample with the CRI software simulator.

    NOTE(review): relies on a module-level ``softwareNetwork`` whose
    construction is not visible as active code in this file — confirm it is
    created before calling this function.

    Args:
        inputList: one entry per sample, each a sequence of per-time-step
            axon-ID lists.
        output_offset: neuron ID of the first output neuron; spike IDs are
            shifted by it to obtain the class index.

    Returns:
        List with the predicted class index (most frequently spiking output
        neuron, first one on ties) for every sample.
    """
    predictions = []
    total_time_cri = 0
    # each image
    for currInput in inputList:
        # reset the membrane potential to zero
        softwareNetwork.simpleSim.initialize_sim_vars(len(neuronsDict))
        spikeRate = [0]*10
        # each time step
        for axonInputs in currInput:
            start_time = time.perf_counter()
            swSpike = softwareNetwork.step(axonInputs, membranePotential=False)
            total_time_cri += time.perf_counter() - start_time
            for spike in swSpike:
                spikeIdx = int(spike) - output_offset
                if spikeIdx < 0:
                    continue  # spike from a neuron below the output range
                if spikeIdx < len(spikeRate):
                    spikeRate[spikeIdx] += 1
                else:
                    # Explicit range check instead of a bare except around the
                    # increment; the diagnostic output is unchanged.
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate )
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total simulation execution time: {total_time_cri:.5f} s")
    return predictions

In [None]:
def run_CRI_hw(inputList,output_offset):
    """Classify every sample on the CRI hardware network.

    Args:
        inputList: one entry per sample, each a sequence of per-time-step
            axon-ID lists.
        output_offset: neuron ID of the first output neuron; spike IDs are
            shifted by it to obtain the class index.

    Returns:
        List with the predicted class index (most frequently spiking output
        neuron, first one on ties) for every sample.
    """
    predictions = []
    total_time_cri = 0
    # each image
    for currInput in inputList:
        # clear the hardware state before every image
        cri_simulations.FPGA_Execution.fpga_controller.clear(len(neuronsDict), False, 0)  ##Num_neurons, simDump, coreOverride
        spikeRate = [0]*10
        # each time step
        for axonInputs in currInput:
            start_time = time.perf_counter()
            hwSpike = hardwareNetwork.step(axonInputs, membranePotential=False)
            total_time_cri += time.perf_counter() - start_time
            print(hwSpike)
            for spike in hwSpike:
                print(int(spike))
                spikeIdx = int(spike) - output_offset
                if spikeIdx < 0:
                    continue  # spike from a neuron below the output range
                if spikeIdx < len(spikeRate):
                    spikeRate[spikeIdx] += 1
                else:
                    # Same diagnostic as run_CRI instead of an unhandled IndexError.
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate )
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total execution time CRIFPGA: {total_time_cri:.5f} s")
    return predictions

In [None]:
# Compare the CRI hardware predictions with the snnTorch predictions on one
# test batch.
total = 0
correct = 0
cri_correct = 0       # software-simulator hits (simulator path currently disabled)
cri_correct_hw = 0    # hardware hits
batch_size = 10
# drop_last=True keeps every batch at exactly batch_size samples.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
# Output neurons are numbered after all other neurons in the mapping, so the
# first output ID is the offset; a hard-coded value goes stale as soon as the
# architecture changes.
output_offset = int(outputs[0])
with torch.no_grad():
    net.eval()

    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)
        cri_input = input_to_CRI(data)
        criPred_hw = torch.tensor(run_CRI_hw(cri_input,output_offset)).to(device)
        print("CRI Predicted HW: ",criPred_hw)
        print("Target: ",targets)
        test_spk, _ = forward_pass(net, num_steps, data, batch_size)

        # predicted class = output neuron with the most spikes over time
        _, predicted = test_spk.sum(dim=0).max(1)
        print("Torchsnn Predicted: ",predicted)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        cri_correct_hw += (criPred_hw == targets).sum().item()
        break #run for one batch

In [None]:
# Report the hardware results: cri_correct_hw is the counter the evaluation
# loop actually updates, whereas cri_correct stays at 0 while the
# software-simulator path is disabled.
print(f"Total correctly classified test set images for CRI: {cri_correct_hw}/{total}")
print(f"Test Set Accuracy for TorchSNN: {100 * correct / total:.2f}%")
print(f"Test Set Accuracy for CRI: {100 * cri_correct_hw / total:.2f}%")