In [1]:
import nir
from lava_rnn import from_nir
import torch
import numpy as np
import os
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from snntorch import functional as SF

# Fix the RNGs used by the notebook so Restart & Run All reproduces the results.
seed = 42
np.random.seed(seed)  # legacy NumPy global RNG
torch.manual_seed(seed)  # seeds PyTorch's generators on all devices (one call suffices)
# Make PyTorch raise on non-deterministic ops instead of silently varying between runs.
torch.use_deterministic_algorithms(True)

In [2]:
# Path of the serialized NIR graph exported from the trained model.
NIR_GRAPH_PATH = 'braille_v2.nir'
# Parse the NIR file, then build the executable network with lava_rnn's converter.
nirgraph = nir.read(NIR_GRAPH_PATH)
net = from_nir(nirgraph)

found RNN subgraph, trying to parse
detected subgraph! candidates: ['lif1']
w_in: 1.0, dt: 0.0001, tau_syn: 0.0002
w_in: 1.0, dt: 0.0001, tau_syn: 0.0002222222281091009


In [3]:
### DEVICE SETTINGS
use_gpu = False  # flip to True to run on the GPU selected below

if not use_gpu:
    device = torch.device("cpu")
else:
    gpu_sel = 1  # index of the CUDA device to use
    device = torch.device(f"cuda:{gpu_sel}")
    # NOTE(review): setting PYTHONHASHSEED at runtime does not change hashing in
    # this already-running interpreter; it only affects child processes started later.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # cuBLAS needs a fixed workspace config for deterministic results when
    # torch.use_deterministic_algorithms(True) is active on CUDA.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

In [4]:
### TEST DATA
test_data_path = "data/ds_test.pt"
# SECURITY NOTE(review): torch.load unpickles the file and can therefore execute
# arbitrary code; only load ds_test.pt from a trusted source.
# map_location="cpu" lets the file load even on machines without the GPU it may have
# been saved from; the evaluation loop moves every batch to `device` itself.
ds_test = torch.load(test_data_path, map_location="cpu")

# Class names; presumably index i names class label i (7 classes) - TODO confirm.
letter_written = ['Space', 'A', 'E', 'I', 'O', 'U', 'Y']

# snntorch loss on output spike counts (presumably cross-entropy over counts - see snntorch docs).
loss_fn = SF.ce_count_loss()

In [5]:
def _regularization_loss(hid_rec, regularization):
    """Return the activity-regularization term computed from the hidden-layer record.

    ``regularization`` is a pair ``(l1_weight, l2_weight)``. The L1 term is the mean
    of ``hid_rec`` summed over dim 0; the L2 term is the mean of the squared result of
    summing over dim 0 and then dim 1. The reductions are kept exactly as written for
    training; NOTE(review): the layout of ``hid_rec`` is defined by the network - confirm.
    """
    l1_term = regularization[0] * torch.mean(torch.sum(hid_rec, 0))
    l2_term = regularization[1] * torch.mean(torch.sum(torch.sum(hid_rec, dim=0), dim=1) ** 2)
    return l1_term + l2_term


def val_test_loop(dataset, batch_size, net, loss_fn, device, shuffle=True, saved_state_dict=None, label_probabilities=False, regularization=None):
    """Evaluate ``net`` on ``dataset`` without gradient tracking.

    Parameters
    ----------
    dataset : torch Dataset yielding ``(data, label)`` pairs; ``data`` is swapped on
        axes 1 and 2 after batching to obtain the network's (N, C, T) input.
    batch_size : int, batch size for the DataLoader (the last batch may be smaller).
    net : module returning ``(spk_out, hid_rec)``; it is switched to eval mode.
    loss_fn : callable ``loss_fn(spk_out, labels)`` returning a scalar tensor.
    device : torch.device every batch is moved to.
    shuffle : bool, whether the DataLoader shuffles the samples.
    saved_state_dict : optional state dict loaded into ``net`` before evaluating.
    label_probabilities : bool, also return per-sample class probabilities.
    regularization : optional ``(l1_weight, l2_weight)`` added to the loss.

    Returns
    -------
    ``[mean_loss, accuracy]`` where ``mean_loss`` is the per-sample mean loss (each
    batch's loss is weighted by its size) and ``accuracy`` is the fraction of samples
    whose most active output unit matches the label. When ``label_probabilities`` is
    True, returns ``([mean_loss, accuracy], probs)`` where ``probs`` is the softmax of
    the output spike counts for every sample, in loader order (random if ``shuffle``).

    Raises
    ------
    ValueError
        If ``dataset`` yields no samples.
    """
    with torch.no_grad():
        if saved_state_dict is not None:
            net.load_state_dict(saved_state_dict)
        net.eval()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

        batch_loss = []
        batch_sizes = []
        batch_acc = []
        all_act_total_out = []

        for data, labels in loader:
            data = data.to(device)
            labels = labels.to(device)

            data = data.swapaxes(1, 2)  # NCT
            spk_out, hid_rec = net(data)
            spk_out = spk_out.moveaxis(2, 0)  # TNC: (time, batch, output units)

            # Validation loss, optionally with activity regularization.
            loss_val = loss_fn(spk_out, labels)
            if regularization is not None:
                loss_val = loss_val + _regularization_loss(hid_rec, regularization)

            batch_loss.append(loss_val.item())
            # Remember the batch size so the final (possibly shorter) batch is weighted correctly.
            batch_sizes.append(len(labels))

            # Accuracy: argmax over output units of spike counts summed over time.
            act_total_out = torch.sum(spk_out, 0)
            _, neuron_max_act_total_out = torch.max(act_total_out, 1)
            batch_acc.extend((neuron_max_act_total_out == labels).cpu().numpy())

            if label_probabilities:
                # Keep every batch so probabilities cover the whole dataset, not just the last batch.
                all_act_total_out.append(act_total_out)

        if not batch_sizes:
            raise ValueError("val_test_loop: dataset yielded no samples")

        metrics = [np.average(batch_loss, weights=batch_sizes), np.mean(batch_acc)]

        if label_probabilities:
            probabilities = torch.softmax(torch.cat(all_act_total_out, dim=0), dim=-1)
            return metrics, probabilities
        return metrics

In [6]:
### INFERENCE ON TEST SET

batch_size = 64

# NOTE(review): not used in this cell; presumably the number of input channels (12 in the logged shapes).
input_size = 12
# Length of the first sample's leading axis; presumably the number of time steps - confirm.
first_sample = next(iter(ds_test))[0]
num_steps = first_sample.shape[0]

test_results = val_test_loop(
    ds_test,
    batch_size,
    net,
    loss_fn,
    device,
    shuffle=False,
    saved_state_dict=None,
    regularization=None,
)

# NOTE(review): the recorded run printed 14.29%, i.e. 1/7 - chance level for the
# 7 classes in letter_written; worth checking the NIR conversion.
print(f"Test accuracy: {np.round(test_results[1] * 100, 2)}%")

torch.Size([64, 12, 256]) torch.Size([64])
torch.Size([256, 64, 7]) torch.Size([64])
torch.Size([64, 12, 256]) torch.Size([64])
torch.Size([256, 64, 7]) torch.Size([64])
torch.Size([12, 12, 256]) torch.Size([12])
torch.Size([256, 12, 7]) torch.Size([12])
Test accuracy: 14.29%
