# SNN Weight & Threshold Balancing - Fully Connected

Code for training a fully connected ReLU network, transferring its weights to a spiking neural network (SNN), and normalising those weights.

In [12]:
# relevant imports

import os

import torch
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

import snntorch as snn
import snntorch.functional as SF

from spike_nets import FC_Net, FC_SNN, FC_Count_Net
from norms import model_norm, data_norm

# Seed PyTorch's RNGs so weight initialisation, dropout and data shuffling
# are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [5]:
# ToTensor already scales MNIST pixels into [0, 1], so no further
# normalisation is applied -- the images are only converted to tensors.
transform_data = ToTensor()

# Load both MNIST splits (downloaded into ./mnist/ on first use) and wrap them
# in loaders; only the training loader is shuffled.
batch_size = 100
train_dataset = MNIST(root='./mnist/', train=True, download=True, transform=transform_data)
test_dataset = MNIST(root='./mnist/', train=False, download=True, transform=transform_data)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(len(train_dataset), len(test_dataset))

60000 10000


## FC ReLU Neural Network

This section trains the baseline FC ReLU network and can save its weights to `params/`. It must be run at least once before the weights can be transferred to a spiking network.

In [3]:
# the original paper used a uniform [-0.1, 0.1] initialiser, so we will too
def init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.uniform_(m.weight, -0.1, 0.1)

# Build the 784-1200-1200-10 ReLU network on the chosen device and apply the
# uniform initialiser to every Linear layer (Module.apply returns the module
# itself, so it can be chained).
fc_net = FC_Net(784, [1200, 1200], 10).to(device).apply(init_weights)

# The original paper trained with SGD(lr=0.01, momentum=0.5), but that appeared
# to kill the gradients here, so Adam with its default settings is used instead.
optimiser = torch.optim.Adam(fc_net.parameters())

fc_net.train()

FC_Net(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (activator): ReLU()
)

In [None]:
# Train the ReLU network with cross-entropy loss.
num_epochs = 15

# Re-enable training mode (dropout active) explicitly: if the evaluation cell
# below has already been run, the network would otherwise still be in eval mode.
fc_net.train()

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # move the mini-batch onto the same device as the model
        images = images.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = fc_net(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)

        # backward pass and parameter update
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # log the current mini-batch loss every 100 steps
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

In [5]:
# Evaluate the trained ReLU network on the test set (eval mode disables dropout).
fc_net.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:

        images = images.to(device)
        labels = labels.to(device)

        # the predicted class is the index of the largest output; under no_grad
        # there is no need to go through the legacy `.data` attribute
        outputs = fc_net(images)
        predicted = outputs.argmax(dim=1)

        # compare predictions with the ground truth
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# fraction of test images classified correctly
accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9825


In [10]:
# Ensure the output directory exists before saving. NOTE(review): whether
# save_parameters creates missing directories itself is not visible here;
# creating it up front is harmless either way.
os.makedirs("params", exist_ok=True)
fc_net.save_parameters("params/linear")

## FC Spiking Neural Network

Loads the weights of the trained FC network into a spiking neural network. Parameters are read from `params/`; if they have not been generated yet, run the cells above to train the ordinary FC network first.

In [15]:
# Build the spiking network and load the parameters saved from the ReLU network.
# beta=1 presumably gives non-leaky (integrate-and-fire) neurons -- assumes
# FC_SNN forwards beta to its snn.Leaky layers; confirm in spike_nets.
threshold = 3
rate = 45
simulation_time = 0.5  # number of simulated steps = rate * simulation_time

# Pass the `rate` variable instead of a duplicated literal, so tuning `rate`
# above actually changes the simulation.
fc_snn = FC_SNN(
    784, [1200, 1200], 10,
    beta=1,
    threshold=threshold,
    steps=int(rate * simulation_time),
    rate=rate,
).to(device)
fc_snn.load_parameters("params/linear")
fc_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [16]:
# Evaluate the directly converted SNN (weights copied as-is) on the test set.
# SF.accuracy_rate yields the fraction of the batch classified correctly, so it
# is scaled by the batch size (dim 1 of spk_out) to accumulate a correct-count.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spk_out, _ = fc_snn(images)

        total += spk_out.size(1)
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9806


## FC Neural Network Model Normalisation

Code to run and test model normalisation for a fully connected SNN.

In [4]:
# Normalise from scratch: start from the ReLU network's weights, rescale them
# with model_norm (presumably in place, since the model is saved straight
# afterwards), print the returned scaling factors, and save the result.
fc_mn_snn = FC_SNN(784, [1200, 1200], 10, beta=1).to(device)
fc_mn_snn.load_parameters("params/linear")
fc_mn_snn.eval()

# NOTE(review): the meaning of the positional `False` flag is defined in
# norms.model_norm and is not visible here -- confirm and pass it by keyword.
scaling_factors = model_norm(fc_mn_snn, False)
print(scaling_factors)

fc_mn_snn.save_parameters("params/model_norm")

[32.103599548339844, 52.59414291381836]


In [26]:
# Rebuild the SNN with the model-normalised weights saved above. The network is
# simulated for max_rate * simulation_time steps; threshold and max_rate are
# hand-tuned values.
max_rate = 200
threshold = 0.25
simulation_time = 0.5

fc_mn_snn = FC_SNN(
    784, [1200, 1200], 10,
    beta=1,
    threshold=threshold,
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
fc_mn_snn.load_parameters("params/model_norm")
fc_mn_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [27]:
# Evaluate the model-normalised SNN on the test set.
# SF.accuracy_rate yields the fraction of the batch classified correctly, so it
# is scaled by the batch size (dim 1 of spk_out) to accumulate a correct-count.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spk_out, _ = fc_mn_snn(images)

        total += spk_out.size(1)
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9776


## FC Neural Network Data Normalisation

Code to run and test data normalisation on an FC SNN.

In [5]:
# Record the maximum activation of each layer over the whole training set.
# FC_Count_Net is expected to track these in maxin_act / maxh1_act / maxh2_act
# as forward passes run (TODO confirm against spike_nets).
count_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
# NOTE(review): batch_size=1 means 60000 separate forward passes. A larger batch
# is only safe if FC_Count_Net reduces its maxima over the batch dimension; verify first.
fc_count_net = FC_Count_Net(784, [1200, 1200], 10).to(device)  # activation-counting network
fc_count_net.load_parameters("params/linear")
fc_count_net.eval()

with torch.no_grad():
    # Labels are not used: only the activations recorded during the forward
    # pass matter, so they are no longer copied to the device every step.
    for images, _ in count_loader:
        fc_count_net(images.to(device))

max_activations = [
    torch.max(fc_count_net.maxin_act),
    torch.max(fc_count_net.maxh1_act),
    torch.max(fc_count_net.maxh2_act),
]

In [10]:
# Normalise from scratch using the recorded maximum activations: start from the
# ReLU network's weights, rescale them with data_norm (presumably in place,
# since the model is saved straight afterwards), and save the result.
fc_dn_snn = FC_SNN(784, [1200, 1200], 10, beta=1).to(device)
fc_dn_snn.load_parameters("params/linear")
fc_dn_snn.eval()

scaling_factors = data_norm(fc_dn_snn, max_activations)
print(scaling_factors)

fc_dn_snn.save_parameters("params/data_norm")

[tensor(8.4872), tensor(2.1156), tensor(4.4810)]


In [28]:
# Rebuild the SNN with the data-normalised weights saved above. The network is
# simulated for max_rate * simulation_time steps; threshold and max_rate are
# hand-tuned values.
max_rate = 80
threshold = 0.7
simulation_time = 0.5

fc_dn_snn = FC_SNN(
    784, [1200, 1200], 10,
    beta=1,
    threshold=threshold,
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
fc_dn_snn.load_parameters("params/data_norm")
fc_dn_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [29]:
# Evaluate the data-normalised SNN on the test set.
# SF.accuracy_rate yields the fraction of the batch classified correctly, so it
# is scaled by the batch size (dim 1 of spk_out) to accumulate a correct-count.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spk_out, _ = fc_dn_snn(images)

        total += spk_out.size(1)
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9808
