In [None]:
import matplotlib.pyplot as plt
import nir
import numpy as np
import sinabs
import torch
import torch.nn as nn
from sinabs.nir import from_nir

In [None]:
# - Load the NIR graph from disk and list its nodes, one "<name> <node>" line each
graph = nir.read('lif_norse.nir')
for nkey in graph.nodes:
    node = graph.nodes[nkey]
    print(f"{nkey:7} {node}")

In [None]:
# - Inter-spike intervals: each entry is the number of zero steps before the next 1
isis = [
    6, 15, 4, 3, 0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 14,
]

# - Expand every interval into `isi` zeros followed by a single 1,
#   then keep only the first 100 steps of the resulting train
spike_chunks = []
for isi in isis:
    spike_chunks.append(isi * [0] + [1])
d1 = list(np.concatenate(spike_chunks))[:100]

# - Follow each step with 9 zeros and flatten everything into one column
padded_steps = [[step] + 9 * [0] for step in d1]
d = torch.tensor(padded_steps, dtype=torch.float).reshape(-1, 1)

# - Scale the input amplitude in place
d *= 0.04

# - Time step of the data
dt = 1e-4

In [None]:
# - Generate sinabs model
#   Convert the NIR graph loaded earlier into a sinabs model via `from_nir`,
#   requesting a batch size of 1.
sinabs_model = from_nir(graph, batch_size=1)

In [None]:
# - Grab a handle to the LIF layer: the element of the third entry in the
#   model's execution order
lif_layer = sinabs_model.execution_order[2].elem

# - Have the layer record its states so the membrane potential can be read back
lif_layer.record_states = True

# - Rescale the membrane time constant by the dt of the data.
#   NOTE: the division happens in place, so running this cell a second time
#   divides tau_mem by dt again.
lif_layer.tau_mem.data.div_(dt)

In [None]:
# - Run the input through the model and keep column 0 of the result as a numpy array
model_out = sinabs_model(d)
output = model_out.detach().numpy()[:, 0]

# - Pull the recorded membrane potential (first index 0, last index 0) as a numpy array
recorded_v_mem = lif_layer.recordings["v_mem"]
v_mem = recorded_v_mem[0, :, 0].detach().numpy()

In [None]:
# - Plot the model output (top) and the recorded membrane potential (bottom)
#   on a shared x-axis. `plt` is imported in the notebook's import cell.
fig, axs = plt.subplots(2, 1, sharex=True)
ax_spikes, ax_vmem = axs

# Indices at which the model output is positive are drawn as events
spike_steps = np.where(output > 0)[0]
ax_spikes.eventplot(spike_steps)
ax_spikes.set_title("Output spikes")
ax_spikes.set_ylabel("Output")

ax_vmem.plot(v_mem)
ax_vmem.set_title("LIF membrane potential")
ax_vmem.set_xlabel("Time step")
ax_vmem.set_ylabel("v_mem")

fig.tight_layout()
plt.show()

In [None]:
# - Export one CSV row per step: input value, membrane potential, model output
n_steps = d.shape[0]
with open('lif_sinabs.csv', 'w') as csv_file:
    csv_file.writelines(
        f'{d[step, 0]},{v_mem[step]},{output[step]}\n'
        for step in range(n_steps)
    )