# SNN Weight & Threshold Balancing - Convolutional

In [1]:
import torch
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

import snntorch as snn
import snntorch.functional as SF

from spike_nets import Conv_Net, Conv_SNN, Conv_Count_Net

# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [2]:
# ToTensor already scales MNIST pixels to [0, 1], so no extra normalisation is applied
transform_data = ToTensor()

# Load the data
batch_size = 100
train_dataset = MNIST(root='./mnist/', train=True, download=True, transform=transform_data)
test_dataset = MNIST(root='./mnist/', train=False, download=True, transform=transform_data)
# training batches are reshuffled every epoch; the test order stays fixed
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Convolutional ReLU Neural Network

In [9]:
def init_weights(m):
    """
    Re-initialise the weights of linear and 2-D conv layers uniformly in [-0.1, 0.1].

    Meant to be passed to ``Module.apply``. Any other module type is ignored,
    and only ``m.weight`` is touched (biases keep their values).
    """
    if not isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        return
    torch.nn.init.uniform_(m.weight, a=-0.1, b=0.1)

# build the ANN on the target device and re-initialise its weights (Module.apply returns the module)
conv_net = Conv_Net(1, [12, 64], [1024, 10]).to(device).apply(init_weights)

# the original paper's optimiser (SGD, lr=0.01, momentum=0.5) seemed to kill the gradients,
# so Adam with its default settings is used instead
optimiser = torch.optim.Adam(conv_net.parameters())

conv_net.train()

Conv_Net(
  (in_layer): Conv2d(1, 12, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h1_layer): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h2_layer): Linear(in_features=1024, out_features=10, bias=False)
  (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (dropout): Dropout(p=0.5, inplace=False)
  (activator): ReLU()
)

In [None]:
# Training model
num_epochs = 15
for epoch in range(num_epochs):
    # one full pass over the training set per epoch
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        # forward pass; cross_entropy applies log-softmax to the outputs itself
        loss = torch.nn.functional.cross_entropy(conv_net(images), labels)
        # backward pass and parameter update
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # progress report every 100 batches
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

In [11]:
# Evaluate model accuracy on the test set after training.
# eval() switches the dropout layer off so predictions are deterministic.
conv_net.eval()
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = conv_net(images)
        # predicted class = index of the largest output per sample
        predicted = torch.max(outputs, 1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# fraction of correctly classified test samples
accuracy = correct/total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9894


In [12]:
# store the trained ANN parameters under "params/conv" (save_parameters comes from spike_nets);
# the SNN cells later in this notebook reload them with load_parameters("params/conv")
conv_net.save_parameters("params/conv")

## Convolutional Spiking Neural Network

In [19]:
# direct conversion: the SNN reuses the trained ANN weights unchanged, with threshold 4
step_count = 200
conv_snn = Conv_SNN(
    1, [12, 64], [1024, 10], 5,
    beta=1,
    threshold=4,
    steps=step_count,
    rate=400,
).to(device)
conv_snn.load_parameters("params/conv")
conv_snn.eval()

Conv_SNN(
  (in_layer): Conv2d(1, 12, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h1_layer): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h2_layer): Linear(in_features=1024, out_features=10, bias=False)
  (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [20]:
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spikes, _ = conv_snn(images)
        # dim 1 of the spike output is taken as the batch size
        # (presumably (steps, batch, classes) — confirm against Conv_SNN)
        n_samples = spikes.size(1)
        # accuracy_rate gives a per-batch fraction; weight it by the batch size
        correct += SF.accuracy_rate(spikes, labels) * n_samples
        total += n_samples

accuracy = correct/total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9885


## Convolutional Neural Network Model Normalisation

In [4]:
def model_norm(model: torch.nn.Module, include_last: bool = False):
    """
    Model-based weight normalisation, applied in place to a PyTorch module.

    Trainable parameters are visited in ``model.parameters()`` order. For each
    one with at least two dimensions (a linear or conv weight), the positive
    entries of every output unit's weights are summed, and the whole tensor is
    divided by the largest of those sums. Other parameters, and weights whose
    largest positive sum is not above zero, are left unchanged.

    Args:
        model: network whose trainable parameters are rescaled in place.
        include_last: also rescale the final trainable parameter. By default
            it is skipped, so the last layer is not normalised.

    Returns:
        One entry per visited parameter: the divisor that was applied, or 0
        where nothing was divided.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    # drop the final parameter unless the caller asked for it
    visit_count = len(trainable) if include_last else len(trainable) - 1

    scales = []
    for weights in trainable[:visit_count]:
        largest_sum = 0
        if weights.dim() >= 2:
            # iterating a linear/conv weight walks its first dim: one output unit at a time
            largest_sum = max(
                [largest_sum] + [torch.clamp(unit, min=0).sum().item() for unit in weights]
            )

        if largest_sum > 0:
            # `.data` lets the Parameter be divided in place outside autograd;
            # a plain in-place op on a leaf tensor that requires grad would raise
            weights.data /= largest_sum

        scales.append(largest_sum)

    return scales

In [5]:
# to load and rescale from scratch
# this instance is only used to rescale and save weights, so no threshold/steps/rate are passed
# (the unused `step_count` assignment that used to be here has been dropped)
conv_mn_snn = Conv_SNN(1, [12, 64], [1024, 10], 5, beta=1).to(device)
conv_mn_snn.load_parameters("params/conv")
conv_mn_snn.eval()

# divide each layer's weights by its largest positive input sum (last layer left as-is)
scaling_factors = model_norm(conv_mn_snn, include_last=False)
print(scaling_factors)
conv_mn_snn.save_parameters("params/conv_model_norm")

[3.0468525886535645, 17.50514030456543]


In [21]:
# to load rescaled weights from file
max_rate = 500
threshold = 1
simulation_time = 0.5

conv_mn_snn = Conv_SNN(
    1, [12, 64], [1024, 10], 5,
    beta=1,
    threshold=threshold,
    # number of simulation steps: max_rate * simulation_time, truncated to an int
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
conv_mn_snn.load_parameters("params/conv_model_norm")
conv_mn_snn.eval()

Conv_SNN(
  (in_layer): Conv2d(1, 12, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h1_layer): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h2_layer): Linear(in_features=1024, out_features=10, bias=False)
  (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [13]:
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spikes, _ = conv_mn_snn(images)
        # dim 1 of the spike output is taken as the batch size
        # (presumably (steps, batch, classes) — confirm against Conv_SNN)
        n_samples = spikes.size(1)
        # accuracy_rate gives a per-batch fraction; weight it by the batch size
        correct += SF.accuracy_rate(spikes, labels) * n_samples
        total += n_samples

accuracy = correct/total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9882


## Convolutional Neural Network Data Normalisation

In [5]:
# get maximum neuron activations
# NOTE(review): batch_size=1 over the whole training set is slow; presumably required by how
# Conv_Count_Net records its maxima — confirm before increasing it
count_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 1, shuffle = False)
conv_count_net = Conv_Count_Net(1, [12, 64], [1024, 10], 5).to(device)
conv_count_net.load_parameters("params/conv")
conv_count_net.eval()

with torch.no_grad():
    # only the images are needed: the labels were previously copied to the device on every
    # iteration but never used. The forward pass is run for its side effect — presumably
    # updating maxin_act / maxh1_act / maxh2_act, which are read below.
    for images, _ in count_loader:
        conv_count_net(images.to(device))

# per-layer maximum activation over the training set
max_activations = [
    torch.max(conv_count_net.maxin_act),
    torch.max(conv_count_net.maxh1_act),
    torch.max(conv_count_net.maxh2_act),
]

In [13]:
# algorithm for data normalisation
def data_norm(model: torch.nn.Module, activations: list[torch.Tensor], include_last: bool=True):
    """
    Data-based weight normalisation of a PyTorch module, applied in place.

    Trainable parameters are visited in ``model.parameters()`` order. For each
    weight tensor (a parameter with >= 2 dims: linear or conv), the layer scale
    is ``max(largest single weight, activations[i])``. The weights are divided
    by that scale relative to the previous rescaled layer's scale, so the
    cumulative division seen by each layer equals its own scale.

    Unlike model normalisation, the last layer is normalised by default.

    Args:
        model: network whose trainable parameters are rescaled in place.
        activations: maximum activation of each layer. ``activations[i]`` is
            paired with the i-th trainable parameter, so one entry per
            parameter is needed (bias-free layers: one entry per layer).
        include_last: whether to also rescale the final trainable parameter.

    Returns:
        The factor each visited parameter was divided by. Parameters that are
        not weight tensors, or whose scale is not positive, are left unchanged
        and reported as 1.0.
    """
    param_list = [param for param in model.parameters() if param.requires_grad]
    layers = len(param_list) if include_last else len(param_list) - 1
    layer_scales = []

    # cumulative scale of the most recently rescaled layer
    previous_factor = 1

    with torch.no_grad():
        for i, neurons in enumerate(param_list[:layers]):
            if neurons.dim() < 2:
                # not a linear/conv weight (e.g. a bias): leave it alone instead of
                # reusing another layer's factor
                layer_scales.append(1.0)
                continue

            # largest single weight across all input connections of the layer
            max_weight = torch.max(neurons)
            scale_factor = max(max_weight, activations[i])
            if scale_factor <= 0:
                # nothing positive to normalise against; dividing would flip signs or blow up
                layer_scales.append(1.0)
                continue

            applied_factor = scale_factor / previous_factor
            # in place on the Parameter itself (allowed under no_grad)
            neurons.div_(applied_factor)
            previous_factor = scale_factor
            layer_scales.append(applied_factor)

    return layer_scales

In [16]:
# to load and rescale from scratch: start from the trained ANN weights,
# rescale them with the recorded maximum activations, then save the result
conv_dn_snn = Conv_SNN(1, [12, 64], [1024, 10], 5, beta=1).to(device)
conv_dn_snn.load_parameters("params/conv")
conv_dn_snn.eval()

scaling_factors = data_norm(conv_dn_snn, activations=max_activations)
print(scaling_factors)
conv_dn_snn.save_parameters("params/conv_data_norm")

tensor(2.0191)
tensor(5.0003)
tensor(27.3178)
[tensor(2.0191), tensor(2.4765), tensor(5.4632)]


In [26]:
# to load rescaled weights from file
max_rate = 400
threshold = 1
simulation_time = 0.5

conv_dn_snn = Conv_SNN(
    1, [12, 64], [1024, 10], 5,
    beta=1,
    threshold=threshold,
    # number of simulation steps: max_rate * simulation_time, truncated to an int
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
conv_dn_snn.load_parameters("params/conv_data_norm")
conv_dn_snn.eval()

Conv_SNN(
  (in_layer): Conv2d(1, 12, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h1_layer): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (h2_layer): Linear(in_features=1024, out_features=10, bias=False)
  (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [18]:
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spikes, _ = conv_dn_snn(images)
        # dim 1 of the spike output is taken as the batch size
        # (presumably (steps, batch, classes) — confirm against Conv_SNN)
        n_samples = spikes.size(1)
        # accuracy_rate gives a per-batch fraction; weight it by the batch size
        correct += SF.accuracy_rate(spikes, labels) * n_samples
        total += n_samples

accuracy = correct/total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9881
