# Fast-Classifying, High-Accuracy Spiking Deep Networks

In [1]:
import torch
from torchvision.transforms import ToTensor#, Compose, Normalize
from torchvision.datasets import MNIST

import snntorch as snn
import snntorch.functional as SF

# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [2]:
# The data is already 0-1 normalised, so the only transform needed is the
# conversion to a tensor.
transform_data = ToTensor()

# Datasets are stored under ./mnist/ (downloaded if missing) and served in
# batches of `batch_size`; only the training loader shuffles.
batch_size = 100
train_dataset = MNIST(
    root='./mnist/',
    train=True,
    download=True,
    transform=transform_data,
)
test_dataset = MNIST(
    root='./mnist/',
    train=False,
    download=True,
    transform=transform_data,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## FC ReLU Neural Network

In [3]:
class FC_Net(torch.nn.Module):
    """Bias-free fully connected ReLU network with two hidden layers.

    Args:
        n_x: number of input features (images are flattened to this size).
        n_h: sizes of the two hidden layers; only ``n_h[0]`` and ``n_h[1]``
            are used.
        n_y: number of outputs.

    Dropout (default probability) follows each hidden ReLU. The output layer
    is passed through ReLU as well, but has no dropout.
    """

    def __init__(self, n_x: int, n_h: list, n_y: int):
        super().__init__()

        self.in_layer = torch.nn.Linear(n_x, n_h[0], bias=False)
        self.h1_layer = torch.nn.Linear(n_h[0], n_h[1], bias=False)
        self.h2_layer = torch.nn.Linear(n_h[1], n_y, bias=False)
        self.dropout = torch.nn.Dropout()
        self.activator = torch.nn.ReLU()

    def forward(self, x):
        """Return the ReLU output activations for a batch ``x``.

        Everything after the first (batch) dimension of ``x`` is flattened.
        """
        # Flatten images
        x = x.view(x.size(0), -1)

        inp = self.dropout(self.activator(self.in_layer(x)))
        h1 = self.dropout(self.activator(self.h1_layer(inp)))
        y = self.activator(self.h2_layer(h1))

        return y

    def save_parameters(self, path: str):
        """Save the three linear layers to ``path + "0.pt"``, ``"1.pt"``, ``"2.pt"``.

        Whole layer modules are pickled (not just state dicts), so
        ``load_parameters`` can swap them in directly.
        """
        torch.save(self.in_layer, path + "0.pt")
        torch.save(self.h1_layer, path + "1.pt")
        torch.save(self.h2_layer, path + "2.pt")

    def load_parameters(self, path: str):
        """Replace the three linear layers with those saved under ``path``.

        The loaded layers are mapped onto the device the current layers live
        on, so calling this after ``.to(device)`` keeps the model on that
        device regardless of where the files were written.
        """
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # files from a trusted source.
        target = self.in_layer.weight.device
        self.in_layer = torch.load(path + "0.pt", map_location=target, weights_only=False)
        self.h1_layer = torch.load(path + "1.pt", map_location=target, weights_only=False)
        self.h2_layer = torch.load(path + "2.pt", map_location=target, weights_only=False)

In [7]:
def init_weights(m):
    """Draw the weights of a ``torch.nn.Linear`` uniformly from [-0.1, 0.1].

    Intended for ``Module.apply``; any other module type is left untouched.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.uniform_(m.weight, a=-0.1, b=0.1)

# Build the ReLU network on the selected device and replace the default
# weight initialisation with init_weights.
fc_net = FC_Net(n_x=784, n_h=[1200, 1200], n_y=10)
fc_net = fc_net.to(device)
fc_net.apply(init_weights)

# The optimiser used in the original paper (SGD, lr=.01, momentum=0.5) seems
# to kill the gradients here, so Adam with default settings is used instead.
optimiser = torch.optim.Adam(fc_net.parameters())

fc_net.train()

FC_Net(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (activator): ReLU()
)

In [None]:
# Train the ReLU network on the cross-entropy loss.
num_epochs = 15
# Set train mode explicitly: if this cell is re-run after the evaluation
# cell, the model would otherwise still be in eval mode and train without
# dropout.
fc_net.train()
for epoch in range(num_epochs):
    # Go through all batches of the training set
    for i, (images, labels) in enumerate(train_loader):
        # Get from dataloader and send to device
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = fc_net(images)
        # Compute loss
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Backward and optimize
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # Report the loss of every 100th batch
        if (i+1) % 100 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

In [9]:
# Test-set accuracy of the trained ReLU network.
# Eval mode must be set before measuring.
fc_net.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # Move the batch to the model's device
        images = images.to(device)
        labels = labels.to(device)
        # The predicted class is the index of the largest output
        outputs = fc_net(images)
        predicted = torch.max(outputs.data, dim=1).indices
        # Tally samples seen and samples classified correctly
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# Fraction of correctly classified test samples
accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9823


In [10]:
# Store the trained layers as params/linear0.pt, params/linear1.pt and
# params/linear2.pt.
fc_net.save_parameters(path="params/linear")

## FC Spiking Neural Network

In [9]:
def to_poisson_spikes(data, steps: int, max_rate: int=200):
    """Generate a random spike train from intensities in ``data``.

    For each of the ``steps`` time steps an independent uniform sample in
    [0, 1) is drawn per element of ``data``; the element spikes (1.0) when the
    sample is below ``data * max_rate / steps`` and stays silent (0.0)
    otherwise.

    Returns:
        Float tensor of shape ``(steps, *data.shape)`` on ``data``'s device.
    """
    # Per-step spike threshold for every element
    spike_prob = data * (max_rate / steps)
    noise = torch.rand((steps, *data.shape), device=data.device)

    return torch.lt(noise, spike_prob).float()

class FC_SNN(torch.nn.Module):
    """Bias-free fully connected spiking network with two hidden layers.

    Mirrors the layer layout of the ReLU network so that its saved layers can
    be loaded directly; each linear layer feeds an ``snn.Leaky`` neuron layer.

    Args:
        n_x: number of input features (images are flattened to this size).
        n_h: sizes of the two hidden layers; only ``n_h[0]`` and ``n_h[1]``
            are used.
        n_y: number of outputs.
        beta: forwarded to every ``snn.Leaky`` layer.
        threshold: forwarded to every ``snn.Leaky`` layer.
        steps: number of simulation steps per forward pass.
        rate: ``max_rate`` passed to ``to_poisson_spikes``.
    """

    def __init__(self, n_x: int, n_h: list, n_y: int, beta: float=0, threshold: float=1, steps: int=100, rate: int=200):
        super().__init__()

        self.in_layer = torch.nn.Linear(n_x, n_h[0], bias=False)
        self.h1_layer = torch.nn.Linear(n_h[0], n_h[1], bias=False)
        self.h2_layer = torch.nn.Linear(n_h[1], n_y, bias=False)
        self.in_active = snn.Leaky(beta=beta, threshold=threshold)
        self.h1_active = snn.Leaky(beta=beta, threshold=threshold)
        self.h2_active = snn.Leaky(beta=beta, threshold=threshold)

        self.steps = steps
        self.rate = rate

    def forward(self, x):
        """Simulate the network on a batch and return its output over time.

        Returns:
            Tuple ``(spikes, membrane)``: the output layer's spikes and
            membrane values, each stacked along a new leading time dimension
            (one entry per simulation step).
        """
        # Flatten images and encode them as one spike frame per step
        x = x.view(x.size(0), -1)
        x = to_poisson_spikes(x, self.steps, self.rate)

        # Start every neuron layer from the state returned by reset_mem()
        memin = self.in_active.reset_mem()
        memh1 = self.h1_active.reset_mem()
        memh2 = self.h2_active.reset_mem()

        out_spikes = []
        memh2_mem = []

        for step in x:
            curin = self.in_layer(step)
            spkin, memin = self.in_active(curin, memin)
            curh1 = self.h1_layer(spkin)
            spkh1, memh1 = self.h1_active(curh1, memh1)
            curh2 = self.h2_layer(spkh1)
            spkh2, memh2 = self.h2_active(curh2, memh2)

            out_spikes.append(spkh2)
            memh2_mem.append(memh2)

        return torch.stack(out_spikes), torch.stack(memh2_mem)

    def save_parameters(self, path: str):
        """Save the three linear layers to ``path + "0.pt"``, ``"1.pt"``, ``"2.pt"``.

        Whole layer modules are pickled; the neuron layers are not saved.
        """
        torch.save(self.in_layer, path + "0.pt")
        torch.save(self.h1_layer, path + "1.pt")
        torch.save(self.h2_layer, path + "2.pt")

    def load_parameters(self, path: str):
        """Replace the three linear layers with those saved under ``path``.

        The loaded layers are mapped onto the device the current layers live
        on, so calling this after ``.to(device)`` keeps the model on that
        device regardless of where the files were written.
        """
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # files from a trusted source.
        target = self.in_layer.weight.device
        self.in_layer = torch.load(path + "0.pt", map_location=target, weights_only=False)
        self.h1_layer = torch.load(path + "1.pt", map_location=target, weights_only=False)
        self.h2_layer = torch.load(path + "2.pt", map_location=target, weights_only=False)

In [5]:
# Spiking network that reuses the ReLU network's saved (un-normalised)
# layers, simulated for step_count steps.
step_count = 100
fc_snn = FC_SNN(
    784,
    [1200, 1200],
    10,
    beta=1,
    threshold=4,
    steps=step_count,
    rate=200,
)
fc_snn = fc_snn.to(device)
fc_snn.load_parameters("params/linear")
fc_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [64]:
# Test-set accuracy of the spiking network.
with torch.no_grad():
    total = 0
    correct = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        spk_out, _ = fc_snn(images)

        # spk_out is stacked over time first, so dim 1 is the batch size;
        # accuracy_rate is scaled by it and divided by the total below.
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)
        total += spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9807


## FC Neural Network Model Normalisation

In [5]:
def model_norm(model: torch.nn.Module, include_last: bool=False):
    """Apply model-based weight normalisation to ``model`` in place.

    Each parameter tensor returned by ``model.parameters()`` is treated as
    the weight tensor of one layer, with one neuron per entry along dim 0.
    For every layer the largest possible positive input to any neuron (the
    sum of that neuron's positive weights) is computed, and all weights of
    the layer are divided by it.

    The model is expected to be bias-free: every parameter must have at least
    two dimensions.

    Args:
        model: module whose parameters are rescaled in place.
        include_last: also normalise the last parameter tensor; by default
            the last layer is left untouched.

    Returns:
        List with one 0-d tensor per normalised layer holding the divisor
        that was applied.

    Raises:
        ValueError: if a parameter has fewer than two dimensions or a layer
            has no positive weight (the divisor would be zero). Nothing is
            modified in that case.
    """
    params = list(model.parameters())
    if not include_last:
        params = params[:-1]

    # First pass: compute and validate every divisor before touching any
    # weight, so a failure cannot leave the model half-normalised.
    layer_scales = []
    for index, weights in enumerate(params):
        if weights.dim() < 2:
            raise ValueError(
                f"parameter {index} has shape {tuple(weights.shape)}; "
                "model_norm expects bias-free layers with weight tensors of at least 2 dimensions"
            )
        # Sum of positive weights per neuron, then the maximum over neurons
        positive_input = weights.detach().clamp(min=0).flatten(start_dim=1).sum(dim=1)
        max_pos_in = positive_input.max()
        if max_pos_in <= 0:
            raise ValueError(f"parameter {index} has no positive weight; cannot normalise")
        layer_scales.append(max_pos_in)

    # Second pass: rescale in place without recording the op in autograd
    with torch.no_grad():
        for weights, max_pos_in in zip(params, layer_scales):
            weights.div_(max_pos_in)

    return layer_scales

In [None]:
# Start from the saved ReLU layers, apply model normalisation (last layer
# excluded) and store the rescaled layers under params/model_norm.
fc_snn = FC_SNN(n_x=784, n_h=[1200, 1200], n_y=10, beta=1)
fc_snn = fc_snn.to(device)
fc_snn.load_parameters("params/linear")
fc_snn.eval()

scaling_factors = model_norm(fc_snn, include_last=False)
print(scaling_factors)
fc_snn.save_parameters("params/model_norm")

In [9]:
# Spiking network built from the model-normalised layers on disk.
# The step count is max_rate * simulation_time (200 with these values).
max_rate = 400
threshold = 0.25
simulation_time = 0.5

fc_snn = FC_SNN(
    784,
    [1200, 1200],
    10,
    beta=1,
    threshold=threshold,
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
fc_snn.load_parameters("params/model_norm")
fc_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [10]:
# Test-set accuracy of the model-normalised spiking network.
with torch.no_grad():
    total = correct = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        spk_out, _ = fc_snn(images)

        # Dim 1 of spk_out is the batch size (dim 0 is time); accuracy_rate
        # is scaled by it and divided by the total below.
        total += spk_out.size(1)
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9803


## FC Neural Network Data Normalisation

In [3]:
class FC_Count_Net(torch.nn.Module):
    """Bias-free ReLU network that records per-neuron maximum activations.

    Used for data normalisation: every forward pass updates ``maxin_act``,
    ``maxh1_act`` and ``maxh2_act`` with the largest ReLU activation each
    neuron of the respective layer has produced so far. Each of them is a 1-D
    tensor with one entry per neuron.

    Args:
        n_x: number of input features (images are flattened to this size).
        n_h: sizes of the two hidden layers; only ``n_h[0]`` and ``n_h[1]``
            are used.
        n_y: number of outputs.
    """

    def __init__(self, n_x: int, n_h: list, n_y: int):
        super().__init__()

        self.dims = [n_x, n_h, n_y]

        self.in_layer = torch.nn.Linear(n_x, n_h[0], bias=False)
        self.h1_layer = torch.nn.Linear(n_h[0], n_h[1], bias=False)
        self.h2_layer = torch.nn.Linear(n_h[1], n_y, bias=False)
        self.activator = torch.nn.ReLU()

        # Running maxima are registered as buffers so that .to(device) moves
        # them together with the layers; persistent=False keeps them out of
        # the state dict.
        self.register_buffer("maxin_act", torch.zeros([n_h[0]]), persistent=False)
        self.register_buffer("maxh1_act", torch.zeros([n_h[1]]), persistent=False)
        self.register_buffer("maxh2_act", torch.zeros([n_y]), persistent=False)

    @staticmethod
    def _running_max(current, activations):
        """Return the element-wise max of ``current`` and the batch-wise max of ``activations``."""
        # Reduce over the batch dimension first so the result stays one value
        # per neuron for any batch size; detach so no autograd graph is kept.
        batch_peak = activations.detach().amax(dim=0)
        return torch.maximum(current.to(batch_peak.device), batch_peak)

    def forward(self, x):
        """Return the ReLU output for batch ``x`` and update the running maxima."""
        # Flatten images
        x = x.view(x.size(0), -1)

        inp = self.activator(self.in_layer(x))
        self.maxin_act = self._running_max(self.maxin_act, inp)

        h1 = self.activator(self.h1_layer(inp))
        self.maxh1_act = self._running_max(self.maxh1_act, h1)

        y = self.activator(self.h2_layer(h1))
        self.maxh2_act = self._running_max(self.maxh2_act, y)

        return y

    def reset_max_count(self):
        """Reset the maximum activation memory to zeros."""
        self.maxin_act = torch.zeros_like(self.maxin_act)
        self.maxh1_act = torch.zeros_like(self.maxh1_act)
        self.maxh2_act = torch.zeros_like(self.maxh2_act)

    def save_parameters(self, path: str):
        """Save the three linear layers to ``path + "0.pt"``, ``"1.pt"``, ``"2.pt"``."""
        torch.save(self.in_layer, path + "0.pt")
        torch.save(self.h1_layer, path + "1.pt")
        torch.save(self.h2_layer, path + "2.pt")

    def load_parameters(self, path: str):
        """Replace the three linear layers with those saved under ``path``.

        The loaded layers are mapped onto the device the current layers live
        on, so calling this after ``.to(device)`` keeps the model on that
        device regardless of where the files were written.
        """
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # files from a trusted source.
        target = self.in_layer.weight.device
        self.in_layer = torch.load(path + "0.pt", map_location=target, weights_only=False)
        self.h1_layer = torch.load(path + "1.pt", map_location=target, weights_only=False)
        self.h2_layer = torch.load(path + "2.pt", map_location=target, weights_only=False)

In [6]:
# Feed the training set, one sample at a time, through the counting network
# (loaded with the trained ReLU layers) to record maximum neuron activations.
count_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
)
fc_count_net = FC_Count_Net(n_x=784, n_h=[1200, 1200], n_y=10)
fc_count_net = fc_count_net.to(device)
fc_count_net.load_parameters("params/linear")
fc_count_net.eval()

with torch.no_grad():
    for images, labels in count_loader:
        images, labels = images.to(device), labels.to(device)
        # Only the side effect (updated maxima) matters; the output is unused
        _ = fc_count_net(images)

# Collapse each layer's recorded maxima to a single scalar per layer
max_activations = [
    torch.max(layer_max)
    for layer_max in (fc_count_net.maxin_act, fc_count_net.maxh1_act, fc_count_net.maxh2_act)
]

tensor([[3.7076, 3.6518, 1.7070,  ..., 3.2691, 7.0213, 4.6011]])
tensor([[6.4801, 2.1321, 9.1305,  ..., 8.6561, 7.5289, 8.5343]])
tensor([[68.3983, 42.4817, 80.4577, 64.8406, 60.4998, 58.6631, 74.8351, 70.2156,
         34.5703, 37.1122]])


In [7]:
def data_norm(model: torch.nn.Module, activations: list[torch.Tensor], include_last: bool=True):
    """Apply data-based weight normalisation to ``model`` in place.

    Each parameter tensor returned by ``model.parameters()`` is treated as
    the weight tensor of one layer. For layer ``i`` the scale factor is the
    larger of the layer's maximum weight (floored at 0) and the maximum of
    ``activations[i]``. The weights are divided by that scale factor divided
    by the previous layer's scale factor (1 for the first layer).

    The model is expected to be bias-free: every parameter must have at least
    two dimensions.

    Args:
        model: module whose parameters are rescaled in place.
        activations: recorded maximum activations, one entry per layer. An
            entry may be a scalar or a tensor of any shape (e.g. per-neuron
            maxima); only its overall maximum is used.
        include_last: also normalise the last parameter tensor (default,
            unlike model normalisation).

    Returns:
        List with one 0-d tensor per normalised layer holding the divisor
        that was applied.

    Raises:
        IndexError: if ``activations`` has fewer entries than layers.
        ValueError: if a parameter has fewer than two dimensions or a layer's
            scale factor is not positive. Nothing is modified in these cases.
    """
    params = list(model.parameters())
    if not include_last:
        params = params[:-1]

    # First pass: compute and validate every divisor before touching any
    # weight, so a failure cannot leave the model half-normalised.
    layer_scales = []
    previous_factor = 1
    for i, weights in enumerate(params):
        if weights.dim() < 2:
            raise ValueError(
                f"parameter {i} has shape {tuple(weights.shape)}; "
                "data_norm expects bias-free layers with weight tensors of at least 2 dimensions"
            )
        max_weight = weights.detach().max().clamp(min=0)
        # Reduce the activation entry to a scalar on the weights' device
        max_active = torch.as_tensor(activations[i]).detach().max()
        max_active = max_active.to(device=weights.device, dtype=weights.dtype)

        scale_factor = torch.maximum(max_weight, max_active)
        if scale_factor <= 0:
            raise ValueError(f"scale factor of parameter {i} is not positive; cannot normalise")

        layer_scales.append(scale_factor / previous_factor)
        previous_factor = scale_factor

    # Second pass: rescale in place without recording the op in autograd
    with torch.no_grad():
        for weights, applied_factor in zip(params, layer_scales):
            weights.div_(applied_factor)

    return layer_scales

In [11]:
# to load and rescale from scratch
fc_snn = FC_SNN(784, [1200, 1200], 10, beta=1).to(device)
fc_snn.load_parameters("params/linear")
fc_snn.eval()

# Pass the per-layer scalar maxima collected in max_activations (one value
# per layer) rather than the raw per-neuron tensors: data_norm needs a single
# maximum activation for each layer.
scaling_factors = data_norm(fc_snn, max_activations)
print(scaling_factors)
fc_snn.save_parameters("params/data_norm")

[tensor(8.4872), tensor(2.1156), tensor(4.4810)]


In [17]:
# Spiking network built from the data-normalised layers on disk.
# The step count is max_rate * simulation_time (200 with these values).
max_rate = 400
threshold = 1
simulation_time = 0.5

fc_snn = FC_SNN(
    784,
    [1200, 1200],
    10,
    beta=1,
    threshold=threshold,
    steps=int(max_rate * simulation_time),
    rate=max_rate,
).to(device)
fc_snn.load_parameters("params/data_norm")
fc_snn.eval()

FC_SNN(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (in_active): Leaky()
  (h1_active): Leaky()
  (h2_active): Leaky()
)

In [18]:
# Test-set accuracy of the data-normalised spiking network.
with torch.no_grad():
    total, correct = 0, 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        spk_out, _ = fc_snn(images)

        # accuracy_rate is scaled by the batch size (dim 1 of spk_out) and
        # divided by the total number of samples below.
        correct += SF.accuracy_rate(spk_out, labels) * spk_out.size(1)
        total += spk_out.size(1)

accuracy = correct / total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9815
