In [1]:
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import snntorch
import nirtorch
from snntorch import import_nirtorch
import matplotlib.pyplot as plt
import nir


In [2]:
test_data_path = "data/ds_test.pt"
ds_test = torch.load(test_data_path)
device = "cpu"


## save activities

In [3]:
nir_graph = nir.read("braille_noDelay_bias_zero.nir")
net = import_nirtorch.from_nir(nir_graph)
h_state = nirtorch.from_nir.GraphExecutorState(
    state={
        'lif1': net._modules['lif1'].init_rsynaptic(),  # 3-tuple: spk, syn, mem
        'lif2': net._modules['lif2'].init_synaptic(),  # 2-tuple: syn, mem
    }
)
spk_out_arr = []
h_state_arr = []
data_snn = ds_test[0][0]
print(data_snn.shape)
for t in range(data_snn.shape[0]):
    spk_out_snn, h_state = net(data_snn[t], h_state)
    spk_out_arr.append(spk_out_snn)
    h_state_arr.append(h_state)
spk_out_arr = torch.stack(spk_out_arr, dim=0)
print(spk_out_arr.shape)
spk_lif1_arr = [h_state.cache['lif1'] for h_state in h_state_arr]
spk_lif1_arr = torch.stack(spk_lif1_arr, dim=0).detach().numpy()
print(spk_lif1_arr.shape)
np.save('snntorch_activity_noDelay_bias_zero.npy', spk_lif1_arr)


replace rnn subgraph with nirgraph
HAS BIAS
HAS BIAS
HAS BIAS
torch.Size([256, 12])
torch.Size([256, 7])
(256, 38)


In [4]:
nir_graph = nir.read("braille_noDelay_noBias_subtract.nir")
net = import_nirtorch.from_nir(nir_graph)
h_state = nirtorch.from_nir.GraphExecutorState(
    state={
        'lif1': net._modules['lif1'].init_rsynaptic(),  # 3-tuple: spk, syn, mem
        'lif2': net._modules['lif2'].init_synaptic(),  # 2-tuple: syn, mem
    }
)
spk_out_arr = []
h_state_arr = []
data_snn = ds_test[0][0]
print(data_snn.shape)
for t in range(data_snn.shape[0]):
    spk_out_snn, h_state = net(data_snn[t], h_state)
    spk_out_arr.append(spk_out_snn)
    h_state_arr.append(h_state)
spk_out_arr = torch.stack(spk_out_arr, dim=0)
print(spk_out_arr.shape)
spk_lif1_arr = [h_state.cache['lif1'] for h_state in h_state_arr]
spk_lif1_arr = torch.stack(spk_lif1_arr, dim=0).detach().numpy()
print(spk_lif1_arr.shape)
np.save('snntorch_activity_noDelay_noBias_subtract.npy', spk_lif1_arr)


replace rnn subgraph with nirgraph
torch.Size([256, 12])
torch.Size([256, 7])
(256, 40)


## save accuracies

In [5]:
def save_accuracy(nir_graph_file, save_to_npy_file):
    nir_graph = nir.read(nir_graph_file)
    net = import_nirtorch.from_nir(nir_graph)

    batch_size = 64
    shuffle = False
    loader = DataLoader(ds_test, batch_size=batch_size, shuffle=shuffle)

    with torch.no_grad():
        net.eval()
        batch_acc = []
        for data, labels in loader:  # data comes as: NTC
            data_snn = data.swapaxes(1, 0)  # TNC
            h_state = nirtorch.from_nir.GraphExecutorState(
                state={
                    'lif1': net._modules['lif1'].init_rsynaptic(),  # 3-tuple: spk, syn, mem
                    'lif2': net._modules['lif2'].init_synaptic(),  # 2-tuple: syn, mem
                }
            )
            spk_out_arr = []
            for t in range(data_snn.shape[0]):
                spk_out_snn, h_state = net(data_snn[t], h_state)
                spk_out_arr.append(spk_out_snn)
            spk_out_arr = torch.stack(spk_out_arr, dim=0)
            print(spk_out_arr.shape)

            act_total_out = torch.sum(spk_out_arr, 0)  # sum over time
            _, neuron_max_act_total_out = torch.max(act_total_out, 1)  # argmax output > labels
            batch_acc.extend((neuron_max_act_total_out == labels).detach().cpu().numpy())

    print(np.mean(batch_acc))
    np.save(save_to_npy_file, np.mean(batch_acc))


In [6]:
model = 'noDelay_bias_zero'
save_accuracy(f"braille_{model}.nir", f"snntorch_accuracy_{model}.npy")


replace rnn subgraph with nirgraph
HAS BIAS
HAS BIAS
HAS BIAS
torch.Size([256, 64, 7])
torch.Size([256, 64, 7])
torch.Size([256, 12, 7])
0.95


In [7]:
model = 'noDelay_noBias_subtract'
save_accuracy(f"braille_{model}.nir", f"snntorch_accuracy_{model}.npy")


replace rnn subgraph with nirgraph
torch.Size([256, 64, 7])
torch.Size([256, 64, 7])
torch.Size([256, 12, 7])
0.9
