In [1]:
import json
import os
import numpy as np
import snntorch as snn
from snntorch import functional as SF
from snntorch import surrogate
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Seed the NumPy and PyTorch global random number generators so that repeated
# runs of the notebook draw the same random numbers.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# Ask PyTorch to use deterministic implementations only (operations without
# one raise an error instead of silently giving run-to-run differences).
torch.use_deterministic_algorithms(True)


In [2]:
### DEVICE SETTINGS
# Set to True to run on a CUDA device; the notebook runs on the CPU otherwise.
use_gpu = False

if not use_gpu:
    device = torch.device("cpu")
else:
    gpu_sel = 0  # index of the CUDA device to use
    device = torch.device("cuda:" + str(gpu_sel))
    # Environment variables set only for GPU runs.
    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # setting it here presumably only affects child processes -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"


In [3]:
### SPECIFY THE RESET MECHANISM TO USE
# Both values are copied into the hyperparameter dictionary further down and
# forwarded to the spiking layers; reset_mechanism also selects which
# parameter and checkpoint files are loaded.
reset_delay = False
reset_mechanism = "subtract"  # "zero" or "subtract"


In [4]:
### OPTIMAL HYPERPARAMETERS
# The JSON file to read is selected by the reset mechanism chosen above.
parameters_path = "data/parameters_noDelay_noBias_ref_{}.json".format(reset_mechanism)

with open(parameters_path) as params_file:
    parameters = json.load(params_file)

# Add the notebook-level reset configuration so that model_build() can read it
# from the same dictionary as the tuned hyperparameters.
parameters.update({"reset": reset_mechanism, "reset_delay": reset_delay})

# [L1 coefficient, L2 coefficient] for the hidden-layer spike regularisation.
regularization = [parameters[key] for key in ("reg_l1", "reg_l2")]


In [5]:
### TRAINED WEIGHTS
# Alternative checkpoint naming kept for reference: "data/model_ref_{}.pt"
checkpoint_template = "data/model_noDelay_noBias_ref_{}.pt"
saved_state_dict_path = checkpoint_template.format(reset_mechanism)

# NOTE(review): torch.load unpickles the file, so only point this at
# checkpoints from a trusted source.
best_val_layers = torch.load(saved_state_dict_path, map_location=device)


In [6]:
### LOSS FUNCTION
# Loss object from snnTorch's functional module. It is called further down as
# loss_fn(spk_out, labels), i.e. on the recorded output spikes and the class labels.
loss_fn = SF.ce_count_loss()


In [7]:
### TEST DATA
test_data_path = "data/ds_test.pt"
# map_location mirrors the checkpoint loading above, so a dataset that was
# saved from a GPU run can still be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted datasets.
ds_test = torch.load(test_data_path, map_location=device)

# Human-readable class names; indexed with the integer label / predicted class.
letter_written = ['Space', 'A', 'E', 'I', 'O', 'U', 'Y']


In [8]:
def model_build(settings, input_size, num_steps, device):
    """Build the recurrent spiking network and move it to ``device``.

    Structure: input -> linear (no bias) -> recurrent synaptic layer
    -> linear (no bias) -> synaptic output layer with 7 units.

    Args:
        settings: mapping providing the keys read here: "nb_hidden", "slope",
            "alpha_r", "beta_r", "alpha_out", "beta_out", "reset" and
            "reset_delay".
        input_size: number of input channels (cast to int).
        num_steps: number of time steps the forward pass iterates over; the
            input is indexed as ``x[step]`` for every step.
        device: torch device the network is moved to.

    Returns:
        The network. Calling it returns ``(output_spikes, hidden_spikes)``,
        each stacked along a new leading (time) dimension.
    """
    n_inputs = int(input_size)
    n_hidden = int(settings["nb_hidden"])
    n_outputs = 7

    # Surrogate gradient shared by both spiking layers
    surrogate_fn = surrogate.fast_sigmoid(slope=int(settings["slope"]))

    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            # Input projection; the bias is removed right after construction
            self.fc1 = nn.Linear(n_inputs, n_hidden)
            self.fc1.bias = None

            # Recurrent hidden layer, recurrent connection without bias
            self.lif1 = snn.RSynaptic(
                alpha=settings["alpha_r"],
                beta=settings["beta_r"],
                linear_features=n_hidden,
                spike_grad=surrogate_fn,
                reset_mechanism=settings["reset"],
                reset_delay=settings["reset_delay"],
            )
            self.lif1.recurrent.bias = None

            # Output projection (no bias) and output spiking layer
            self.fc2 = nn.Linear(n_hidden, n_outputs)
            self.fc2.bias = None
            self.lif2 = snn.Synaptic(
                alpha=settings["alpha_out"],
                beta=settings["beta_out"],
                spike_grad=surrogate_fn,
                reset_mechanism=settings["reset"],
                reset_delay=settings["reset_delay"],
            )

        def forward(self, x):
            # Fresh layer states at t=0
            spk_hid, syn_hid, mem_hid = self.lif1.init_rsynaptic()
            syn_out, mem_out = self.lif2.init_synaptic()

            hidden_spikes = []  # hidden-layer spikes, one entry per time step
            output_spikes = []  # output-layer spikes, one entry per time step

            for t in range(num_steps):
                # Recurrent layer: its own previous spikes are fed back in
                hidden_current = self.fc1(x[t])
                spk_hid, syn_hid, mem_hid = self.lif1(hidden_current, spk_hid, syn_hid, mem_hid)

                # Output layer driven by the hidden spikes of this step
                output_current = self.fc2(spk_hid)
                spk_out, syn_out, mem_out = self.lif2(output_current, syn_out, mem_out)

                hidden_spikes.append(spk_hid)
                output_spikes.append(spk_out)

            return torch.stack(output_spikes, dim=0), torch.stack(hidden_spikes, dim=0)

    return Net().to(device)


In [13]:
def val_test_loop(dataset, batch_size, net, loss_fn, device, shuffle=True, saved_state_dict=None, label_probabilities=False, regularization=None, return_spikes=None):
    """Run ``net`` over ``dataset`` in evaluation mode, without gradients.

    Depending on the flags, the function returns either the raw spike
    recordings of the first batch or the loss/accuracy over the whole dataset
    (optionally together with per-sample label probabilities).

    Args:
        dataset: dataset yielding ``(data, labels)`` pairs; each batch of data
            has its first two axes swapped before it is passed to ``net``.
        batch_size: batch size used for the ``DataLoader``.
        net: module returning ``(spk_out, hid_rec)``; both are summed over
            their first axis (time) here.
        loss_fn: callable invoked as ``loss_fn(spk_out, labels)``.
        device: torch device the batches are moved to.
        shuffle: whether the loader shuffles the dataset.
        saved_state_dict: optional state dict loaded into ``net`` first.
        label_probabilities: if True, also return the softmax of the summed
            output spikes.
        regularization: optional ``[l1, l2]`` coefficients applied to the
            hidden-layer spike counts and added to the loss.
        return_spikes: if True, return ``(spk_out, hid_rec)`` for the first
            batch only and skip every metric. ``None`` (default) resolves to
            ``not label_probabilities``: plain calls keep returning the spike
            recordings, while calls asking for label probabilities get the
            metrics. An explicit True takes precedence over
            ``label_probabilities``.

    Returns:
        * ``(spk_out, hid_rec)`` when spikes are requested;
        * ``[mean_batch_loss, accuracy]`` when ``return_spikes`` is False and
          ``label_probabilities`` is False;
        * ``([mean_batch_loss, accuracy], probabilities)`` otherwise, with one
          probability row per sample, in loader order.

    Raises:
        ValueError: if the loader yields no batch at all.
    """
    if return_spikes is None:
        return_spikes = not label_probabilities

    if saved_state_dict is not None:
        net.load_state_dict(saved_state_dict)
    net.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

    batch_loss = []    # one loss value per batch
    batch_acc = []     # one correct/incorrect flag per sample
    spike_counts = []  # summed output spikes, one tensor per batch

    with torch.no_grad():
        for data, labels in loader:
            data = data.to(device).swapaxes(1, 0)
            labels = labels.to(device)

            spk_out, hid_rec = net(data)

            if return_spikes:
                # Raw recordings of the first batch only; no metrics computed.
                return spk_out, hid_rec

            # Validation loss
            loss_val = loss_fn(spk_out, labels)
            if regularization is not None:
                # L1 loss on spikes per neuron from the hidden layer
                reg_loss = regularization[0]*torch.mean(torch.sum(hid_rec, 0))
                # L2 loss on total number of spikes from the hidden layer
                reg_loss = reg_loss + regularization[1]*torch.mean(torch.sum(torch.sum(hid_rec, dim=0), dim=1)**2)
                loss_val = loss_val + reg_loss

            batch_loss.append(loss_val.detach().cpu().item())

            # Accuracy
            act_total_out = torch.sum(spk_out, 0)  # sum over time
            _, neuron_max_act_total_out = torch.max(act_total_out, 1)  # argmax over output units to compare to labels
            batch_acc.extend((neuron_max_act_total_out == labels).detach().cpu().numpy())
            spike_counts.append(act_total_out)

    if not batch_loss:
        raise ValueError("val_test_loop received an empty dataset")

    results = [np.mean(batch_loss), np.mean(batch_acc)]

    if label_probabilities:
        # Probabilities for every evaluated sample, not just the last batch.
        log_softmax_fn = nn.LogSoftmax(dim=-1)
        log_p_y = log_softmax_fn(torch.cat(spike_counts, dim=0))
        return results, torch.exp(log_p_y)

    return results


In [16]:
### INFERENCE ON TEST SET

input_size = 12
batch_size = 1

# Each dataset item is (sample, label); the length of the sample's first axis
# is used as the number of simulation steps.
num_steps = ds_test[0][0].shape[0]

net = model_build(parameters, input_size, num_steps, device)

# Output-layer (spk2) and hidden-layer (spk1) spike recordings; with the call
# below these cover the first batch of the loader, i.e. the first test sample.
spk2, spk1 = val_test_loop(
    ds_test,
    batch_size,
    net,
    loss_fn,
    device,
    shuffle=False,
    saved_state_dict=best_val_layers,
    regularization=regularization,
)
# print("Test accuracy: {}%".format(np.round(test_results[1]*100,2)))
# np.save('snntorch_accuracy_noDelay_noBias_subtract.npy', test_results[1])


In [17]:
# Sanity check: shape and total spike count of the hidden (spk1) and output (spk2) recordings.
spk1.shape, spk1.sum(), spk2.shape, spk2.sum()


(torch.Size([256, 1, 40]), tensor(118.), torch.Size([256, 1, 7]), tensor(132.))

In [20]:
# Store the hidden-layer spike recording with its singleton axis 1 removed.
# .cpu() keeps the conversion working when the tensors live on a GPU, and the
# file name follows the configured reset mechanism like the input files do.
np.save('snntorch_activity_noDelay_noBias_{}.npy'.format(reset_mechanism), spk1.squeeze(1).detach().cpu().numpy())


(256, 40)

In [11]:
### INFERENCE ON INDIVIDUAL TEST SAMPLES

Ns = 10  # number of randomly drawn test samples to classify

for ii in range(Ns):

    # Draw one random sample: first batch of a freshly shuffled loader
    sample_loader = DataLoader(ds_test, batch_size=1, shuffle=True)
    single_sample = next(iter(sample_loader))

    # NOTE(review): this unpacking expects val_test_loop to return
    # (metrics, probabilities) when label_probabilities=True -- confirm that the
    # function definition in use still does so.
    _, lbl_probs = val_test_loop(
        TensorDataset(single_sample[0], single_sample[1]),
        1,
        net,
        loss_fn,
        device,
        shuffle=False,
        saved_state_dict=best_val_layers,
        label_probabilities=True,
        regularization=regularization,
    )

    true_letter = letter_written[single_sample[1]]
    predicted_letter = letter_written[torch.max(lbl_probs.cpu(), 1)[1]]
    probs_percent = np.round(np.array(lbl_probs.detach().cpu().numpy())*100, 2)

    print("Single-sample inference {}/{} from test set:".format(ii+1,Ns))
    print("Sample: {} \tPrediction: {}".format(true_letter, predicted_letter))
    print("Label probabilities (%): {}\n".format(probs_percent))


Single-sample inference 1/10 from test set:
Sample: Space 	Prediction: Space
Label probabilities (%): [[100.   0.   0.   0.   0.   0.   0.]]

Single-sample inference 2/10 from test set:
Sample: I 	Prediction: I
Label probabilities (%): [[  0.   0.   0. 100.   0.   0.   0.]]

Single-sample inference 3/10 from test set:
Sample: Y 	Prediction: Y
Label probabilities (%): [[  0.   0.   0.   0.   0.   0. 100.]]

Single-sample inference 4/10 from test set:
Sample: O 	Prediction: O
Label probabilities (%): [[0.000e+00 0.000e+00 9.000e-02 0.000e+00 9.991e+01 0.000e+00 0.000e+00]]

Single-sample inference 5/10 from test set:
Sample: Y 	Prediction: Y
Label probabilities (%): [[  0.   0.   0.   0.   0.   0. 100.]]

Single-sample inference 6/10 from test set:
Sample: Y 	Prediction: Y
Label probabilities (%): [[  0.   0.   0.   0.   0.   0. 100.]]

Single-sample inference 7/10 from test set:
Sample: Space 	Prediction: Space
Label probabilities (%): [[9.991e+01 0.000e+00 0.000e+00 0.000e+00 0.000e+00

In [None]:
# NOTE: this requires snntorch/nir (PR) and nirtorch/master (unreleased)
# These imports stay in this cell because of the special requirements above.
from snntorch import export_nir
import nir

# Convert the network to a NIR graph, using the first test sample as the example input.
sample_input = ds_test[0][0]
nir_graph = export_nir.to_nir(net, sample_input)

# List the graph's nodes (key and node class) and its edges.
print('nodes:')
for nodekey, node in nir_graph.nodes.items():
    node_type = node.__class__.__name__
    print('\t', nodekey, node_type)

print('edges:')
for edge in nir_graph.edges:
    print('\t', edge)

# Write the graph to disk.
nir.write('braille.nir', nir_graph)
