# Custom integer quantization for SNN

In [1]:
import torch
import pickle
import os
import snntorch as snn
import torch.nn as nn
import numpy as np
from collections import OrderedDict

import os

In [2]:
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as torchvision_transforms
import tonic
import tonic.transforms as transforms

# NMNIST sensor geometry as exposed by Tonic.
sensor_size = tonic.datasets.NMNIST.sensor_size

# Event preprocessing: denoise the raw events, then bin them into frames.
# Both stages use a value of 10000 (presumably microseconds -- confirm against the Tonic docs).
# No rotation augmentation here: it would have to act on the frames produced by ToFrame,
# not on the raw event stream.
event_pipeline = [
    transforms.Denoise(filter_time=10000),
    transforms.ToFrame(sensor_size=sensor_size, time_window=10000),
]
transform = tonic.transforms.Compose(event_pipeline)

# NMNIST train / test splits, stored under tmp/data and read without a caching layer.
nmnist_args = dict(save_to='tmp/data', transform=transform)
trainset = tonic.datasets.NMNIST(train=True, **nmnist_args)
testset = tonic.datasets.NMNIST(train=False, **nmnist_args)

# Hold out 20% of the training split for validation.
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

# One collate function shared by all loaders. batch_first=False keeps time on
# dim 0, so batches come out as [time, batch, channels, height, width].
batch_size = 128
pad_time_major = tonic.collation.PadTensors(batch_first=False)
loader_args = dict(batch_size=batch_size, collate_fn=pad_time_major)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
val_loader = DataLoader(val_dataset, **loader_args)
# NOTE(review): the test loader is shuffled too, so evaluating only its first
# few batches looks at a different subset on every run.
test_loader = DataLoader(testset, shuffle=True, **loader_args)

# Pull one training batch to check the tensor layout.
data, targets = next(iter(train_loader))
print(f"Data shape: {data.shape}")  # time-major, e.g. torch.Size([31, 128, 2, 34, 34])
print(f"Targets shape: {targets.shape}")  # one label per sample, e.g. torch.Size([128])

  from .autonotebook import tqdm as notebook_tqdm


Data shape: torch.Size([31, 128, 2, 34, 34])
Targets shape: torch.Size([128])


In [3]:
# Settings shared by the SNN model and the quantized re-implementation.
config = {
    # --- Spiking neuron (LIF) settings ---
    "threshold1": 2.5,  # firing threshold of the first Leaky layer
    "threshold2": 8.0,  # firing threshold of the second Leaky layer
    "threshold3": 4.0,  # firing threshold of the output Leaky layer
    "beta": 0.5,  # membrane decay factor passed to every Leaky layer
    "num_steps": 10,  # stored on the model; the forward pass iterates over the input's own length

    # --- Dense layer sizes ---
    "dense1_input": 2312,  # flattened frame size: 2 * 34 * 34
    "num_classes": 10,  # width of the output layer

    # --- Optimizer (not referenced in the cells shown here) ---
    "lr": 0.007,

    # --- Early stopping (not referenced in the cells shown here) ---
    "min_delta": 1e-6,
    "patience_es": 20,

    # --- Training (not referenced in the cells shown here) ---
    "epochs": 1,
}

### Last model trained with FC layers and LIF


In [4]:
class SNN(nn.Module):
    """Three fully connected layers, each followed by an snnTorch Leaky neuron layer.

    Layer widths are ``dense1_input -> dense1_input // 4 -> dense1_input // 8
    -> num_classes``.

    Args:
        config: mapping providing ``threshold1``, ``threshold2``,
            ``threshold3``, ``beta``, ``num_steps``, ``dense1_input`` and
            ``num_classes``.
        gather_mem_stats: when truthy, every forward step appends the
            flattened membrane potentials of the three Leaky layers to
            ``mem1_val`` / ``mem2_val`` / ``mem3_val``. Any falsy value
            (``None``, ``False``) disables gathering.
    """

    def __init__(self, config, gather_mem_stats=None):
        super().__init__()

        # LIF neuron parameters
        self.thresh1 = config["threshold1"]
        self.thresh2 = config["threshold2"]
        self.thresh3 = config["threshold3"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]

        # Dense layer sizes
        self.dense1_input = config["dense1_input"]
        self.num_classes = config["num_classes"]

        # Network layers (creation order is kept so random initialisation is unchanged)
        self.fc1 = nn.Linear(self.dense1_input, self.dense1_input // 4)
        self.lif1 = snn.Leaky(beta=self.beta, threshold=self.thresh1)

        self.fc2 = nn.Linear(self.dense1_input // 4, self.dense1_input // 8)
        self.lif2 = snn.Leaky(beta=self.beta, threshold=self.thresh2)

        self.fc3 = nn.Linear(self.dense1_input // 8, self.num_classes)
        self.lif3 = snn.Leaky(beta=self.beta, threshold=self.thresh3)

        self.flatten = nn.Flatten()

        # The flag is evaluated for truthiness: callers that pass False mean
        # "off", which an `is not None` check would treat as "on".
        self.gather_mem_stats = gather_mem_stats
        # Always present so readers never hit an AttributeError; they stay
        # None until statistics are actually gathered.
        self.mem1_val = None
        self.mem2_val = None
        self.mem3_val = None

    def _append_mem(self, collected, mem):
        """Return ``collected`` extended with the flattened, detached values of ``mem``."""
        # detach(): these are statistics only; keeping the autograd graph of
        # every recorded step alive would leak memory during training.
        flat = mem.detach().flatten()
        if collected is None:
            return flat.clone()
        return torch.cat((collected, flat))

    def forward(self, inpt):
        """Run the network over every step along dim 0 of ``inpt``.

        Each ``inpt[step]`` is flattened with ``nn.Flatten`` before the first
        linear layer (assumes dim 0 of a step is the batch -- TODO confirm).

        Returns:
            A pair ``(spikes, membranes)`` holding the output layer's spikes
            and membrane potentials for every step, stacked along a new dim 0.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spike3_rec = []
        mem3_rec = []
        gather = bool(self.gather_mem_stats)

        for step in range(inpt.shape[0]):
            current_input = self.flatten(inpt[step])

            current1 = self.fc1(current_input)
            spike1, mem1 = self.lif1(current1, mem1)
            if gather:
                self.mem1_val = self._append_mem(self.mem1_val, mem1)

            current2 = self.fc2(spike1)
            spike2, mem2 = self.lif2(current2, mem2)
            if gather:
                self.mem2_val = self._append_mem(self.mem2_val, mem2)

            current3 = self.fc3(spike2)
            spike3, mem3 = self.lif3(current3, mem3)
            if gather:
                self.mem3_val = self._append_mem(self.mem3_val, mem3)

            spike3_rec.append(spike3)
            mem3_rec.append(mem3)

        return torch.stack(spike3_rec, dim=0), torch.stack(mem3_rec, dim=0)

In [5]:
def quantizeFixedPoint(width, frac, X):
    """Snap ``X`` onto a signed fixed-point grid and return it as floats.

    Args:
        width: total number of bits of the signed representation.
        frac: number of fractional bits; the grid spacing is ``2 ** -frac``.
        X: tensor of values to quantize.

    Returns:
        A tensor shaped like ``X`` whose entries are multiples of
        ``2 ** -frac``, saturated to the range representable by a
        ``width``-bit two's-complement integer code.
    """
    lsb = 2 ** (-frac)
    half_range = 2 ** (width - 1)

    # Integer codes: round to the nearest grid point, then saturate.
    codes = torch.clamp(torch.round(X / lsb), -half_range, half_range - 1)
    return codes * lsb

### Saving quantized weights and biases

In [12]:
import pandas as pd

# Load the trained parameters onto the CPU.
model_path = 'best_SNN_model.pth'
# Membrane-statistics gathering is left at its default (off); nothing in this
# cell needs it.
model = SNN(config)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Quantize every linear layer in place (8-bit words, 7 fractional bits) and
# export it. Files written per layer: <layer>_weights.csv and <layer>_bias.csv,
# without index or header rows.
for layer_name in ("fc1", "fc2", "fc3"):
    layer = getattr(model, layer_name)
    layer.weight.data = quantizeFixedPoint(8, 7, layer.weight.data)
    layer.bias.data = quantizeFixedPoint(8, 7, layer.bias.data)

    # nn.Linear keeps weights as (out_features, in_features); the CSV stores
    # the transpose, i.e. one row per input feature.
    weights_frame = pd.DataFrame(layer.weight.data.t().numpy())
    bias_frame = pd.DataFrame(layer.bias.data.numpy())

    weights_frame.to_csv(f'{layer_name}_weights.csv', index=False, header=False)
    bias_frame.to_csv(f'{layer_name}_bias.csv', index=False, header=False)


In [49]:
# Load the trained parameters onto the CPU.
model_path = 'best_SNN_model.pth'
# Leave gather_mem_stats at its default so the test-set evaluation further
# down does not accumulate membrane-potential statistics on every time step.
model = SNN(config)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Quantize weights and biases of every linear layer in place
# (8-bit words, 7 fractional bits).
for layer in (model.fc1, model.fc2, model.fc3):
    layer.weight.data = quantizeFixedPoint(8, 7, layer.weight.data)
    layer.bias.data = quantizeFixedPoint(8, 7, layer.bias.data)

def test(model, test_loader, criterion, device, model_path="best_SNN_model.pth"):
    """Evaluate ``model`` on ``test_loader`` and return ``(loss, accuracy)``.

    The model is expected to return a pair whose first element holds the
    output spikes with the time steps on dim 0; these are averaged over dim 0
    to obtain per-class scores.

    Args:
        model: module returning ``(spikes, _)`` when called on a batch.
        test_loader: iterable of ``(data, targets)`` batches. It does not need
            to support ``len()``; batches are counted while iterating.
        criterion: loss callable applied to ``(scores, targets)``.
        device: device the batches are moved to before the forward pass.
        model_path: unused; kept so existing callers keep working.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    num_batches = 0

    model.eval()

    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)

            # Rate decoding: average the output spikes over the time axis.
            spikes, _ = model(data)
            outputs = spikes.mean(dim=0)

            test_loss += criterion(outputs, targets).item()

            _, predicted = torch.max(outputs, 1)
            total_test += targets.size(0)
            correct_test += (predicted == targets).sum().item()
            num_batches += 1

    if num_batches == 0 or total_test == 0:
        raise ValueError("test_loader yielded no samples; cannot compute loss or accuracy")

    return test_loss / num_batches, 100 * correct_test / total_test


# The name starts with "test", so pytest would try to collect this helper as a
# test case if it is ever imported into a test module.
test.__test__ = False

# Evaluate the model on the CPU over the test loader and report loss / accuracy.
device = torch.device("cpu")
criterion = nn.CrossEntropyLoss()
test_loss, test_accuracy = test(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device,
)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

KeyboardInterrupt: 

In [6]:
# Load the trained parameters onto the CPU.
model_path = 'best_SNN_model.pth'
# Membrane-statistics gathering is left at its default (off).
model = SNN(config)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# Quantized copies of the weights and biases (8-bit words, 7 fractional bits),
# ordered fc1, fc2, fc3. The model's own parameters are left untouched here.
linear_layers = (model.fc1, model.fc2, model.fc3)
Wq = [quantizeFixedPoint(8, 7, layer.weight.data) for layer in linear_layers]
Bq = [quantizeFixedPoint(8, 7, layer.bias.data) for layer in linear_layers]

# Neuron parameters in 16-bit words with 8 fractional bits.
# Lq is threshold * beta, computed from the unquantized thresholds and then
# quantized on its own.
Threshq = torch.tensor([model.thresh1, model.thresh2, model.thresh3])
Lq = Threshq * model.beta
Threshq = quantizeFixedPoint(16, 8, Threshq)
Lq = quantizeFixedPoint(16, 8, Lq)

In [57]:
def matMACQuantized(W, B, x):
    """Compute ``W @ x + B`` and quantize the result to 16-bit words with 8 fractional bits.

    Args:
        W: weight matrix.
        B: bias added to the matrix product.
        x: input vector.

    Returns:
        The quantized result of the multiply-accumulate.
    """
    accumulated = torch.matmul(W, x) + B
    return quantizeFixedPoint(16, 8, accumulated)

def LifQuantized(Isyn, mem, thresh, L, beta, spike):
    """Advance one layer of fixed-point LIF neurons by a single time step.

    Every intermediate result is passed through the quantizer (16-bit words,
    8 fractional bits) so that overflow and underflow saturate.

    Args:
        Isyn: synaptic input current, one value per neuron.
        mem: membrane potentials from the previous step (not modified).
        thresh: firing threshold.
        L: amount subtracted from the membrane of a neuron that spiked in the
            previous step.
        beta: decay factor applied to the membrane.
        spike: spikes emitted in the previous step (1 where a neuron fired).

    Returns:
        ``(spike_out, mem)``: the new spikes (1.0 where the membrane exceeds
        ``thresh``, else 0.0) and the updated membrane potentials.
    """
    # Leak
    mem = quantizeFixedPoint(16, 8, mem * beta)

    # Reset: subtract L from every neuron that spiked in the previous step.
    mem = torch.where(spike == 1, mem - L, mem)
    mem = quantizeFixedPoint(16, 8, mem)

    # Saturate the integer part of Isyn to 8 bits; the fractional part is kept as is.
    int_part = torch.trunc(Isyn)
    frac_part = Isyn - int_part
    Isyn = frac_part + quantizeFixedPoint(8, 0, int_part)

    # Integrate (mem is a fresh tensor at this point, so in-place is safe).
    mem += Isyn
    mem = quantizeFixedPoint(16, 8, mem)

    # Fire where the membrane is strictly above the threshold.
    spike_out = torch.zeros(mem.shape[0])
    spike_out[mem > thresh] = 1

    return spike_out, mem

def runNetworkQuantized(Wq, Bq, Threshq, Lq, config, inpt, targets):
    """Classify a batch with the fixed-point network and count the correct predictions.

    Each sample is processed on its own, starting from zeroed membranes and
    spikes. The predicted class is the output neuron that spiked most often
    over all time steps.

    Args:
        Wq: quantized weight matrices of the three layers.
        Bq: quantized biases of the three layers.
        Threshq: quantized firing thresholds, one per layer.
        Lq: quantized reset amounts, one per layer.
        config: mapping with ``dense1_input`` and ``num_classes``; ``beta`` is
            used as the leak factor when present (0.5 otherwise).
        inpt: batch indexed as ``[time step, sample, ...]``.
        targets: tensor with the expected class of every sample.

    Returns:
        Number of samples whose predicted class equals the target.
    """
    # Take the leak factor from the configuration so it cannot drift from the
    # value used elsewhere; 0.5 is the former hard-coded value.
    beta = config.get("beta", 0.5)
    hidden1 = config["dense1_input"] // 4
    hidden2 = config["dense1_input"] // 8
    num_classes = config["num_classes"]

    predictions = []

    for i in range(inpt.shape[1]):
        x = inpt[:, i]

        # Fresh neuron state for every sample.
        mem1, spike1 = torch.zeros(hidden1), torch.zeros(hidden1)
        mem2, spike2 = torch.zeros(hidden2), torch.zeros(hidden2)
        mem3, spike3 = torch.zeros(num_classes), torch.zeros(num_classes)

        pred = torch.zeros(num_classes)

        for step in range(x.shape[0]):
            current_input = x[step].flatten()

            current1 = matMACQuantized(Wq[0], Bq[0], current_input)
            spike1, mem1 = LifQuantized(current1, mem1, Threshq[0], Lq[0], beta, spike1)

            current2 = matMACQuantized(Wq[1], Bq[1], spike1)
            spike2, mem2 = LifQuantized(current2, mem2, Threshq[1], Lq[1], beta, spike2)

            current3 = matMACQuantized(Wq[2], Bq[2], spike2)
            spike3, mem3 = LifQuantized(current3, mem3, Threshq[2], Lq[2], beta, spike3)

            # Accumulate output spikes for rate decoding.
            pred += spike3

        predictions.append(pred.argmax().item())

    predictions = torch.tensor(predictions)
    return int((predictions == targets).sum().item())

In [58]:
# Evaluate the fixed-point network on the first `max_batches` batches of the
# test loader and report the fraction of correctly classified samples.
max_batches = 16
input_cnt = 0
correct_cnt = 0
for cnt, (data, targets) in enumerate(test_loader):
    print(cnt)
    input_cnt += targets.shape[0]
    correct_cnt += runNetworkQuantized(Wq, Bq, Threshq, Lq, config, data, targets)
    # Stop right after the last wanted batch so no further batch is loaded.
    if cnt + 1 == max_batches:
        break
print('Accuracy: ', correct_cnt / input_cnt)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Accuracy:  0.9423828125
