# **Neuromorphic Lab**

<font size="5"><b>"Neuromorphic Computing and Engineering" Ph.D. course</b></font>

Prof. G. Urgese, V. Fra

<font size="3">*Politecnico di Torino, 2024*</font>

<font size="2">The present *draft* notebook is extracted from the material prepared for hands-on coding in a laboratory session of a Ph.D. course at Politecnico di Torino.</font><br>
<font size="2">The aim was to show how NIR allows one to train a network in one framework and then run it in another. Consequently, and since the purpose of the course was to let students play with neuron models and parameters, the proposed architecture was not optimized: after the network design, it only went through short training phases.<br></font>

<font size="2">NOTE that, at the time of running, `reset_delay=False` turned out to produce an error in `snnTorch` that was not present in the NIR experiments (and that was possibly introduced by a later snnTorch update). I did not investigate it, and I kept the resulting inconsistency, knowing that it affects the comparison between the two adopted frameworks. Despite this issue, the goal of showing how NIR, at the end of a laboratory session entirely based on `snnTorch`, allows one to run the network in a completely different framework could still be achieved.</font>

## Environment set-up for Google Colab

In [None]:
!pip3 freeze > colab_default.txt

Upload `env_nce24_colab.txt` and the `figures` folder to `./content/`.

In [None]:
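# Package names already installed in the default Colab environment
# (pip freeze lines look like "name==version" or "name @ url")
with open('colab_default.txt', 'r') as f:
  colab_default = {line.split("==")[0].split(" @ ")[0].strip().casefold() for line in f if line.strip()}

with open('env_nce24_colab.txt', 'r') as f:
  env_nce24_colab = f.readlines()

# Keep only the packages that Colab does not already provide (exact name match, not substring)
with open('nce24_colab.txt', 'w') as f:
  for pkg in env_nce24_colab:
    if pkg.split("==")[0].split(" @ ")[0].strip().casefold() not in colab_default:
      f.write(pkg)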
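A stdlib-only sketch of this filtering step as a reusable function (the names `package_name` and `filter_requirements` are hypothetical). It compares normalized package names exactly rather than testing for substrings, so that, for instance, a `torch` requirement is not skipped just because `torchvision` is preinstalled. The normalization used here (lower-casing and replacing `_` with `-`) is a simplified take on PEP 503 name normalization.

```python
def package_name(line):
    """Normalized package name from a 'name==version' or 'name @ url' requirement line."""
    name = line.split("==")[0].split(" @ ")[0]
    return name.strip().casefold().replace("_", "-")


def filter_requirements(required_lines, installed_lines):
    """Keep the requirement lines whose package name is not among the installed ones."""
    installed = {package_name(line) for line in installed_lines if line.strip()}
    return [line for line in required_lines
            if line.strip() and package_name(line) not in installed]


# Toy inputs: "torch" is NOT installed here, only "torchvision" is
colab_lines = ["numpy==1.26.4\n", "torchvision==0.18.0\n"]
env_lines = ["numpy==1.25.0\n", "torch==2.3.0\n", "nir==1.0.4\n"]
print(filter_requirements(env_lines, colab_lines))
# ['torch==2.3.0\n', 'nir==1.0.4\n']
```

A substring test would have dropped `torch==2.3.0` because `"torch"` occurs inside `"torchvision==0.18.0"`.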

In [None]:
!pip install -r nce24_colab.txt

In [None]:
!git clone https://gitlab.com/spinnaker2/py-spinnaker2.git

In [None]:
!pip install -e ./py-spinnaker2

---
Remember to apply the following change to `snntorch.export_nir.export_to_nir` (path: `/usr/local/lib/python3.10/dist-packages/snntorch/export_nir.py`) for the `Synaptic` case, so that the exported parameters have one entry per neuron. The lines
```
# TODO: assert that size of the current layer is correct
alpha = module.alpha.detach().numpy()
beta = module.beta.detach().numpy()
vthr = module.threshold.detach().numpy()
```

must be replaced with
```
assert "n_neurons" in dir(module), "The n_neurons attribute must be set for the given module: module.__setattr__('n_neurons',VALUE)"
alpha = np.ones(module.n_neurons) * module.alpha.detach().numpy()
beta = np.ones(module.n_neurons) * module.beta.detach().numpy()
vthr = np.ones(module.n_neurons) * module.threshold.detach().numpy()
```
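The patch broadcasts a shared (scalar) `alpha`, `beta` and threshold to a vector with one value per neuron. A stdlib-only sketch of the same idea, which also addresses the `# TODO` by checking the size when per-neuron values are already given; `expand_per_neuron` is a hypothetical helper, not part of snnTorch.

```python
def expand_per_neuron(value, n_neurons):
    """Return a list with one parameter value per neuron.

    A single shared value (a float, or a 0-d array such as the result of
    .detach().numpy() on a scalar tensor) is repeated n_neurons times;
    an iterable must already contain exactly n_neurons values.
    """
    try:
        values = [float(v) for v in value]
    except TypeError:  # not iterable: one value shared by all neurons
        return [float(value)] * n_neurons
    if len(values) != n_neurons:
        raise ValueError(f"expected {n_neurons} values, got {len(values)}")
    return values


print(expand_per_neuron(0.2, 3))     # [0.2, 0.2, 0.2]
print(expand_per_neuron([1, 2], 2))  # [1.0, 2.0]
```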

---

## Basic settings


In [None]:
import gdown
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle as pkl
import random
from sklearn.model_selection import train_test_split
import tensorflow as tf
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [None]:
use_seed = True

if use_seed:
    seed = 42
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
else:
    seed = None

In [None]:
device = "cpu"

### Custom utility functions

In [None]:
def create_directory(
    directory_path
    ):
    """
    Muller-Cleve, Simon F.; Istituto Italiano di Tecnologia - IIT; Event-driven perception in robotics - EDPR; Genova, Italy.
    """
    if os.path.exists(directory_path):
        return None
    else:
        try:
            os.makedirs(directory_path)
        except OSError:
            # in case another machine created the path meanwhile! :(
            return None
        return directory_path
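`create_directory` handles the case of another process creating the folder through a try/except. An alternative sketch relies on `os.makedirs(..., exist_ok=True)` from the standard library; note that, unlike `create_directory`, the hypothetical `ensure_directory` below returns a boolean rather than the path or `None`.

```python
import os
import tempfile


def ensure_directory(directory_path):
    """Create directory_path (and its parents) if needed; return True if it exists afterwards.

    exist_ok=True makes the call safe even if another process creates the
    directory between a check and the creation.
    """
    os.makedirs(directory_path, exist_ok=True)
    return os.path.isdir(directory_path)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "graphs", "nested")
    print(ensure_directory(target))  # True
    print(ensure_directory(target))  # True (the second call is a no-op)
```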


def train_validation_test_split(
    data,
    label,
    split=[70, 20, 10],
    seed=None,
    multiple=False,
    save_dataset=False,
    save_tensor=False,
    labels_type=None,
    labels_mapping=None,
    save_name=None,
    save_path=None
    ):
    """
    Creates train-validation-test splits using the sklearn train_test_split() twice.
    Can be used either to prepare "ready-to-use" splits or to create and store splits.

    If multiple=False, x_train, y_train, x_val, y_val, x_test, y_test are returned (without labels mapping).
    If multiple=True, ten train-validation splits are saved (a saving option is required) and nothing is returned.

    The function accepts lists, arrays, and tensors.
    Default split: [training: 70, validation: 20, test: 10]

    Fra, Vittorio; Politecnico di Torino; EDA Group; Torino, Italy.
    Muller-Cleve, Simon F.; Istituto Italiano di Tecnologia - IIT; Event-driven perception in robotics - EDPR; Genova, Italy.
    """

    if multiple and not (save_dataset or save_tensor):
        raise ValueError("Multiple train-validation splits are requested, but no saving option is enabled.")

    if save_dataset or save_tensor:
        if save_path is None or save_name is None:
            raise ValueError("Please provide both a file name and a path to save the datasets.")
        filename_prefix = os.path.join(save_path, save_name)
        create_directory(save_path)

    # Sanity checks on the split
    if len(split) != 3:
        raise ValueError(
            f"Split dimensions are wrong. Expected 3 but got {len(split)}. Please provide the split in the form [train size, validation size, test size].")
    if min(split) == 0.0:
        raise ValueError(
            "Found entry 0.0. If you only want to perform a two-fold split, please use the sklearn train_test_split function directly.")
    if sum(split) > 99.0:
        split = [x/100 for x in split]
    if sum(split) < 0.99:
        raise ValueError("Please use a split summing up to 1, or 100%.")

    train, val, test = split
    split_1 = test
    split_2 = 1 - train/(train+val)

    x_trainval, x_test, y_trainval, y_test = train_test_split(
        data, label, test_size=split_1, shuffle=True, stratify=label, random_state=seed)


    if save_dataset: # Save the test split
        filename_test = filename_prefix + "_test"
        # xs test
        with open(f"{filename_test}.pkl", 'wb') as handle:
            pkl.dump(np.array(x_test, dtype=object), handle,
                        protocol=pkl.HIGHEST_PROTOCOL)
        # ys test
        with open(f"{filename_test}_label.pkl", 'wb') as handle:
            pkl.dump(np.array(y_test, dtype=object), handle,
                        protocol=pkl.HIGHEST_PROTOCOL)

    if save_tensor: # Save the test split
        filename_test = filename_prefix + "_ds_test"
        x_test = torch.as_tensor(np.array(x_test), dtype=torch.float)
        if labels_type == str:
            labels_test = torch.as_tensor(value2index(
                y_test, labels_mapping), dtype=torch.long)
        else:
            labels_test = torch.as_tensor(y_test, dtype=torch.long)
        ds_test = TensorDataset(x_test, labels_test)
        torch.save(ds_test, "{}.pt".format(filename_test))

    if multiple:

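        for ii in range(10):

            # Derive a different seed for each split: reusing the same random_state would produce ten identical splits
            split_seed = None if seed is None else seed + ii
            x_train, x_val, y_train, y_val = train_test_split(
                x_trainval, y_trainval, test_size=split_2, shuffle=True, stratify=y_trainval, random_state=split_seed)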

            if save_dataset:

                filename_train = filename_prefix + "_train"
                filename_val = filename_prefix + "_val"

                # xs training
                with open(f"{filename_train}_{ii}.pkl", 'wb') as handle:
                    pkl.dump(np.array(x_train, dtype=object), handle,
                                protocol=pkl.HIGHEST_PROTOCOL)
                # ys training
                with open(f"{filename_train}_{ii}_label.pkl", 'wb') as handle:
                    pkl.dump(np.array(y_train, dtype=object), handle,
                                protocol=pkl.HIGHEST_PROTOCOL)

                # xs validation
                with open(f"{filename_val}_{ii}.pkl", 'wb') as handle:
                    pkl.dump(np.array(x_val, dtype=object), handle,
                                protocol=pkl.HIGHEST_PROTOCOL)
                # ys validation
                with open(f"{filename_val}_{ii}_label.pkl", 'wb') as handle:
                    pkl.dump(np.array(y_val, dtype=object), handle,
                                protocol=pkl.HIGHEST_PROTOCOL)

            if save_tensor:

                filename_train = filename_prefix + "_ds_train"
                filename_val = filename_prefix + "_ds_val"

                x_train = torch.as_tensor(np.array(x_train), dtype=torch.float)
                if labels_type == str:
                    labels_train = torch.as_tensor(value2index(
                        y_train, labels_mapping), dtype=torch.long)
                else:
                    labels_train = torch.as_tensor(y_train, dtype=torch.long)

                x_validation = torch.as_tensor(
                    np.array(x_val), dtype=torch.float)
                if labels_type == str:
                    labels_validation = torch.as_tensor(value2index(
                        y_val, labels_mapping), dtype=torch.long)
                else:
                    labels_validation = torch.as_tensor(y_val, dtype=torch.long)

                ds_train = TensorDataset(x_train, labels_train)
                ds_val = TensorDataset(x_validation, labels_validation)

                torch.save(ds_train, "{}_{}.pt".format(filename_train,ii))
                torch.save(ds_val, "{}_{}.pt".format(filename_val,ii))

    else:

        x_train, x_val, y_train, y_val = train_test_split(
            x_trainval, y_trainval, test_size=split_2, shuffle=True, stratify=y_trainval, random_state=seed)

        if save_dataset:

            filename_train = filename_prefix + "_train"
            filename_val = filename_prefix + "_val"

            # xs training
            with open(f"{filename_train}.pkl", 'wb') as handle:
                pkl.dump(np.array(x_train, dtype=object), handle,
                            protocol=pkl.HIGHEST_PROTOCOL)
            # ys training
            with open(f"{filename_train}_label.pkl", 'wb') as handle:
                pkl.dump(np.array(y_train, dtype=object), handle,
                            protocol=pkl.HIGHEST_PROTOCOL)

            # xs validation
            with open(f"{filename_val}.pkl", 'wb') as handle:
                pkl.dump(np.array(x_val, dtype=object), handle,
                            protocol=pkl.HIGHEST_PROTOCOL)
            # ys validation
            with open(f"{filename_val}_label.pkl", 'wb') as handle:
                pkl.dump(np.array(y_val, dtype=object), handle,
                            protocol=pkl.HIGHEST_PROTOCOL)

        if save_tensor:

            filename_train = filename_prefix + "_ds_train"
            filename_val = filename_prefix + "_ds_val"

            x_train = torch.as_tensor(np.array(x_train), dtype=torch.float)
            if labels_type == str:
                labels_train = torch.as_tensor(value2index(
                    y_train, labels_mapping), dtype=torch.long)
            else:
                labels_train = torch.as_tensor(y_train, dtype=torch.long)

            x_validation = torch.as_tensor(
                np.array(x_val), dtype=torch.float)
            if labels_type == str:
                labels_validation = torch.as_tensor(value2index(
                    y_val, labels_mapping), dtype=torch.long)
            else:
                labels_validation = torch.as_tensor(y_val, dtype=torch.long)

            ds_train = TensorDataset(x_train, labels_train)
            ds_val = TensorDataset(x_validation, labels_validation)

            torch.save(ds_train, f"{filename_train}.pt")
            torch.save(ds_val, f"{filename_val}.pt")

        return x_train, y_train, x_val, y_val, x_test, y_test
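`train_validation_test_split` calls `train_test_split` twice, so the second `test_size` (`split_2` in the code) must be rescaled to the remaining train+validation part. A small, dependency-free sketch of that arithmetic (`two_stage_fractions` is a hypothetical name):

```python
def two_stage_fractions(split):
    """Return (test_size, val_size_within_trainval) for a [train, val, test] split.

    First call: hold out test_size of the whole dataset.
    Second call: hold out val / (train + val) of the remaining train+validation
    part, so that the validation set is val / (train + val + test) of the whole.
    """
    if sum(split) > 99.0:  # percentages -> fractions
        split = [x / 100 for x in split]
    train, val, test = split
    return test, val / (train + val)


test_size, val_size = two_stage_fractions([70, 20, 10])
print(test_size, round(val_size, 4))         # 0.1 0.2222
print(round((1 - test_size) * val_size, 4))  # 0.2, i.e. 20% of the whole dataset
```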


def value2index(
    entry,
    dictionary
    ):
    """
    Returns the position(s) of the given label(s) among the values of dictionary.

    Fra, Vittorio; Politecnico di Torino; EDA Group; Torino, Italy.
    """

    values = list(dictionary.values())
    if not isinstance(entry, (list, np.ndarray)):
        # single label
        idx = [values.index(entry)]
    else:
        # list or array of labels
        idx = [values.index(e) for e in entry]

    return idx
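`value2index` scans the list of dictionary values once per label (`list.index`), so converting many labels costs roughly (number of classes) × (number of labels) comparisons. A common alternative is a precomputed reverse mapping. The sketch below uses hypothetical helper names and only the standard library, and assumes the dictionary values are unique (with duplicates, the last occurrence would win instead of the first).

```python
def build_value_to_index(dictionary):
    """Map each value of `dictionary` to its position among the dictionary values."""
    return {value: position for position, value in enumerate(dictionary.values())}


def labels_to_indices(entry, dictionary):
    """Like value2index: list of positions for a single label or a list/tuple of labels."""
    lookup = build_value_to_index(dictionary)
    if isinstance(entry, (list, tuple)):
        return [lookup[e] for e in entry]
    return [lookup[entry]]


labels_mapping = {0: "Space", 1: "A", 2: "E"}
print(labels_to_indices(["E", "A"], labels_mapping))  # [2, 1]
print(labels_to_indices("Space", labels_mapping))     # [0]
```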


### Loops

In [None]:
def training_loop(
    dataset,
    batch_size,
    net,
    optimizer,
    loss_fn):
    """
    Fra, Vittorio; Politecnico di Torino; EDA Group; Torino, Italy.
    """

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    batch_loss = []
    batch_acc = []

    net.train()
    for data, labels in tqdm(train_loader):

        data = data.to(device)
        labels = labels.to(device)

        spk_rec, _, _ = net(data)

        # Training loss
        loss_val = loss_fn(spk_rec, labels)
        batch_loss.append(loss_val.detach().cpu().item())

        # Training accuracy
        act_total_out = torch.sum(spk_rec, 0)  # sum over time
        _, neuron_max_act_total_out = torch.max(act_total_out, 1)  # argmax over output units to compare to labels
        batch_acc.append(np.mean((neuron_max_act_total_out == labels).detach().cpu().numpy()))

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    epoch_loss = np.mean(batch_loss)
    epoch_acc = np.mean(batch_acc)

    return [epoch_loss, epoch_acc]


def val_test_loop(
    dataset,
    batch_size,
    net,
    loss_fn,
    shuffle=True,
    label_probabilities=False,
    return_spikes=False):
    """
    Fra, Vittorio; Politecnico di Torino; EDA Group; Torino, Italy.

    NOTE: label probabilities and output spikes, when requested, refer to the last batch only.
    """

    with torch.no_grad():
        net.eval()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

        batch_loss = []
        batch_acc = []

        for data, labels in tqdm(loader):
            data = data.to(device)
            labels = labels.to(device)

            spk_out, _, _ = net(data)

            # Loss
            loss_val = loss_fn(spk_out, labels)
            batch_loss.append(loss_val.detach().cpu().item())

            # Accuracy
            act_total_out = torch.sum(spk_out, 0)  # sum over time
            _, neuron_max_act_total_out = torch.max(act_total_out, 1)  # argmax over output units to compare to labels
            batch_acc.append(np.mean((neuron_max_act_total_out == labels).detach().cpu().numpy()))

        metrics = [np.mean(batch_loss), np.mean(batch_acc)]

        if label_probabilities:
            log_softmax_fn = nn.LogSoftmax(dim=-1)
            log_p_y = log_softmax_fn(act_total_out)
            if return_spikes:
                return metrics, torch.exp(log_p_y), spk_out.detach().cpu().numpy()
            return metrics, torch.exp(log_p_y)
        if return_spikes:
            return metrics, spk_out.detach().cpu().numpy()
        return metrics
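Both loops classify a sample by summing its output spikes over time and taking the most active output neuron. A framework-free sketch of that readout on nested lists (hypothetical helper names; the notebook itself does this with `torch.sum` and `torch.max`):

```python
def spike_count_predictions(spk_rec):
    """Predict one class per sample from a [time][batch][neuron] spike record.

    Spikes are summed over time and the most active output neuron is the
    prediction (ties go to the lowest index, as with np.argmax).
    """
    n_batch = len(spk_rec[0])
    n_out = len(spk_rec[0][0])
    counts = [[sum(step[b][n] for step in spk_rec) for n in range(n_out)]
              for b in range(n_batch)]
    return [max(range(n_out), key=lambda n: row[n]) for row in counts]


def accuracy(predictions, labels):
    """Fraction of predictions equal to the corresponding label."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


# 3 timesteps, 2 samples, 3 output neurons
spk_rec = [
    [[0, 1, 0], [1, 0, 0]],
    [[0, 1, 1], [1, 0, 0]],
    [[0, 0, 1], [0, 0, 1]],
]
predictions = spike_count_predictions(spk_rec)
print(predictions)                     # [1, 0]
print(accuracy(predictions, [1, 2]))   # 0.5
```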

## Retrieve data

In [None]:
### Google Drive links to the Braille dataset splits (training, validation, test)

braille_nir_train_link = "https://drive.google.com/file/d/1xjT8I-Y70yMV_HcKbWmalqXu1YaJMZEg/view?usp=sharing"
braille_nir_val_link = "https://drive.google.com/file/d/1m6XPpmddEmp0HrO4WUBwQQZn0jnO1MkA/view?usp=sharing"
braille_nir_test_link = "https://drive.google.com/file/d/1KwqB4U5LnPhFvmn1GJoHYSwrX9Ccs5_8/view?usp=sharing"

In [None]:
data_folder = "./data/Braille" 
create_directory(data_folder)
files_in_data_folder = os.listdir(data_folder)

In [None]:
def load_split(filename, link):
    """Load a TensorDataset split from data_folder, downloading it first if it is missing."""
    path = os.path.join(data_folder, filename)
    if filename not in files_in_data_folder:
        gdown.download(link, output=path, fuzzy=True)
    # weights_only=False is needed to unpickle a full TensorDataset with recent PyTorch versions
    return torch.load(path, map_location=device, weights_only=False)

ds_train = load_split("ds_train_nir.pt", braille_nir_train_link)
ds_val = load_split("ds_val_nir.pt", braille_nir_val_link)
ds_test = load_split("ds_test_nir.pt", braille_nir_test_link)

num_steps = next(iter(ds_test))[0].shape[0]
letter_written_nir = ['Space', 'A', 'E', 'I', 'O', 'U', 'Y']


In [None]:
from itertools import islice

import nir

import torch.nn as nn
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import export_nir
from snntorch import functional as SF

from spinnaker2 import brian2_sim, s2_nir

## Neuromorphic Intermediate Representation (NIR)

<center>
<img src="./figures/logo_NIR_dark.png" alt="drawing" width="700"/>
</center><br>

<div style="text-align:center">    
  From <a href="https://arxiv.org/abs/2311.14641">"Neuromorphic Intermediate Representation: A Unified Instruction Set for Interoperable Brain-Inspired Computing"</a>
</div><br>

<center>
<img src="./figures/NIR_arXiv_reduced.png" alt="drawing" width="1400"/>
</center><br>

<div style="text-align:center">    
  <a href="https://github.com/neuromorphs/NIR/tree/main">GitHub repository</a>
</div><br>

<img src="./figures/NIR_taxonomy.png" alt="drawing" width="700" style="float: right; margin-left: 10px; margin-right: 20px"/>

*Spiking neural networks and neuromorphic hardware platforms that emulate neural dynamics are slowly gaining momentum and entering main-stream usage. Despite a well-established mathematical foundation for neural dynamics,* **the implementation details vary greatly across different platforms.** *Correspondingly, there are a plethora of software and hardware implementations with their own unique technology stacks. Consequently, neuromorphic systems typically diverge from the expected computational model, which challenges the reproducibility and reliability across platforms. Additionally,* **most neuromorphic hardware is limited by its access via a single software frameworks with a limited set of training procedures.** *Here, we establish a common reference-frame for computations in neuromorphic systems, dubbed the Neuromorphic Intermediate Representation (NIR). NIR defines a set of computational primitives as idealized continuous-time hybrid systems that can be composed into graphs and mapped to and from various neuromorphic technology stacks. By abstracting away assumptions around discretization and hardware constraints, NIR faithfully captures the fundamental computation, while simultaneously exposing the exact differences between the evaluated implementation and the idealized mathematical formalism.<br>* 

<img src="./figures/NIR_at-a-glance.png" alt="drawing" width="500" style="float: right; margin-left: 10px; margin-right: 20px"/>

*In the paper, three NIR graphs are reproduced across* **7 neuromorphic simulators and 4 hardware platforms,** *demonstrating support for an unprecedented number of neuromorphic systems.*

**With NIR, we decouple the evolution of neuromorphic hardware and software, ultimately increasing the interoperability between platforms and improving accessibility to neuromorphic technologies.** 

<img src="./figures/NIR_primitives_table.png" alt="drawing" width="450" style="float: left; margin-right: 10px"/>

<img src="./figures/NIR_higher-order.png" alt="drawing" width="450" style="float: right; margin-left: 10px; margin-right: 20px"/>

NIR defines **11 computational primitives and 3 higher-order primitives**.<br> 
We define common neuromorphic components like the linear map, leaky integrator, and spike threshold function, but we further included mathematical primitives such as the affine map and convolution. The input and output nodes serve to disambiguate the entries and exits of a graph.<br>
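For reference, the dynamics of some of these primitives can be written as continuous-time equations, where $S$ denotes incoming spikes and $z$ emitted spikes. The first block follows the leaky integrator (LI), threshold, and current-based LIF (CubaLIF) primitives as we read them in the NIR paper. The second block is an illustrative forward-Euler discretization with time step $\Delta t$ and $v_{\text{leak}} = 0$, showing how continuous time constants relate to discrete per-step decay factors such as snnTorch's `alpha` and `beta`; it is a derivation under these assumptions, not necessarily the exact mapping implemented by every framework.

$$
\begin{aligned}
\text{LI:} \quad & \tau \, \dot{v} = (v_{\text{leak}} - v) + R \, I \\
\text{Threshold:} \quad & z = \Theta(v - v_{\text{thr}}) \\
\text{CubaLIF:} \quad & \tau_{\text{syn}} \, \dot{I} = -I + w_{\text{in}} \, S, \qquad \tau_{\text{mem}} \, \dot{v} = (v_{\text{leak}} - v) + R \, I, \qquad z = \Theta(v - v_{\text{thr}})
\end{aligned}
$$

$$
\begin{aligned}
I[t+1] &= \left(1 - \frac{\Delta t}{\tau_{\text{syn}}}\right) I[t] + \frac{\Delta t}{\tau_{\text{syn}}} \, w_{\text{in}} \, S[t]
&& \Longrightarrow \quad \alpha = 1 - \frac{\Delta t}{\tau_{\text{syn}}} \\
v[t+1] &= \left(1 - \frac{\Delta t}{\tau_{\text{mem}}}\right) v[t] + \frac{\Delta t}{\tau_{\text{mem}}} \, R \, I[t]
&& \Longrightarrow \quad \beta = 1 - \frac{\Delta t}{\tau_{\text{mem}}}
\end{aligned}
$$

Requiring unit input coefficients, as in a discrete update of the form $I[t+1] = \alpha \, I[t] + \text{input}$, gives $w_{\text{in}} = \tau_{\text{syn}} / \Delta t$ and $R = \tau_{\text{mem}} / \Delta t$.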

The 11 primitives in NIR are “fundamental” in the sense that the backends implementing the primitives are required to approximate the computation of the idealized description as closely as possible within the limitations of the platform. Any given platform is not expected to implement the full specification. This is particularly true for functionally specialized hardware, where hardware restrictions render certain functional primitives impossible.
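A NIR graph is, in essence, a set of named primitive nodes plus directed edges between them. The sketch below uses a deliberately simplified, hypothetical representation (plain dictionaries and tuples, not the `nir` package API) to illustrate the role of the input and output nodes: entries and exits can be identified unambiguously, while cycles such as a recurrent connection remain allowed. The node names of the toy graph are illustrative only.

```python
def check_graph(nodes, edges):
    """Validate a graph given as {name: primitive} and [(source, target), ...].

    Checks that every edge connects known nodes, that 'Input' nodes are only
    entries (no incoming edges) and that 'Output' nodes are only exits
    (no outgoing edges). Cycles (e.g. recurrent connections) are allowed.
    Returns the sorted names of the entry and exit nodes.
    """
    for source, target in edges:
        for name in (source, target):
            if name not in nodes:
                raise ValueError(f"edge {(source, target)} references unknown node {name!r}")
        if nodes[target] == "Input":
            raise ValueError(f"Input node {target!r} has an incoming edge")
        if nodes[source] == "Output":
            raise ValueError(f"Output node {source!r} has an outgoing edge")
    entries = sorted(n for n, kind in nodes.items() if kind == "Input")
    exits = sorted(n for n, kind in nodes.items() if kind == "Output")
    return entries, exits


# Toy graph: input -> affine -> CubaLIF (with recurrent affine) -> affine -> CubaLIF -> output
nodes = {
    "input": "Input",
    "fc1": "Affine",
    "lif1": "CubaLIF",
    "w_rec": "Affine",
    "fc2": "Affine",
    "lif2": "CubaLIF",
    "output": "Output",
}
edges = [("input", "fc1"), ("fc1", "lif1"), ("lif1", "w_rec"), ("w_rec", "lif1"),
         ("lif1", "fc2"), ("fc2", "lif2"), ("lif2", "output")]
print(check_graph(nodes, edges))  # (['input'], ['output'])
```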




<p style="text-align:center">    
  <font size="5"><b>The NIR experiments</b></font>
</p>

<center>
<img src="./figures/NIR_experiments.png" alt="drawing" width="700"/>
</center>

### Build and train an SNN in snnTorch for Braille classification and test it in a different framework

<img src="./figures/2L_FC_R-a2a.png" alt="drawing" width="400"/>

In [None]:
settings = {
    "input_size"     :    12,
    "nb_hidden"      :    100,
    "alpha_r"        :    0.3,
    "beta_r"         :    0.5,
    "thr_r"          :    0.8,
    "alpha_out"      :    0.2,
    "beta_out"       :    0.8,
    "thr_out"        :    1.0,
    "lr"             :    0.0005,
    "reset"          :    "zero",
    "reset_delay"    :    True
}

model_name = "nce24_braille_rsnn"

In [None]:
def model_build(settings, num_steps, device):

    ### Network structure (input data --> recurrent -> output)
    input_channels = int(settings["input_size"])
    num_hidden = int(settings["nb_hidden"])
    num_outputs = int(len(letter_written_nir))

    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            ##### Initialize layers #####
            ### Recurrent layer
            self.fc1 = nn.Linear(input_channels, num_hidden)
            self.lif1 = snn.RSynaptic(alpha=settings["alpha_r"], beta=settings["beta_r"], threshold=settings["thr_r"], linear_features=num_hidden, reset_mechanism=settings["reset"], reset_delay=settings["reset_delay"])
            ### Output layer
            self.fc2 = nn.Linear(num_hidden, num_outputs)
            self.lif2 = snn.Synaptic(alpha=settings["alpha_out"], beta=settings["beta_out"], threshold=settings["thr_out"], reset_mechanism=settings["reset"], reset_delay=settings["reset_delay"])
            self.lif2.__setattr__("n_neurons",num_outputs)  # required by the patched export_to_nir (Synaptic case)

        def forward(self, x):

            ##### Initialize hidden states at t=0 #####
            spk1, syn1, mem1 = self.lif1.init_rsynaptic()
            syn2, mem2 = self.lif2.init_synaptic()

            # Record the final layer
            spk2_rec = []
            syn2_rec = []
            mem2_rec = []

            for step in range(num_steps):
                ### Recurrent layer
                cur1 = self.fc1(x[:,step,:])
                spk1, syn1, mem1 = self.lif1(cur1, spk1, syn1, mem1)
                ### Output layer
                cur2 = self.fc2(spk1)
                spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)

                spk2_rec.append(spk2)
                syn2_rec.append(syn2)
                mem2_rec.append(mem2)

            return torch.stack(spk2_rec, dim=0), torch.stack(syn2_rec, dim=0), torch.stack(mem2_rec, dim=0)

    return Net().to(device)
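Each `Synaptic` layer keeps two state variables per neuron: a synaptic current decaying with `alpha` and a membrane potential decaying with `beta`; `RSynaptic` additionally feeds its own spikes back through a recurrent all-to-all weight matrix. A simplified single-neuron sketch of the feed-forward update with reset to zero follows. It is an approximation written for illustration (hypothetical helper `simulate_synaptic`) and ignores snnTorch details such as `reset_delay`, the exact reset timing, and the surrogate gradient.

```python
def simulate_synaptic(inputs, alpha, beta, threshold):
    """Simplified single-neuron current-based LIF with reset to zero.

    syn[t] = alpha * syn[t-1] + input[t]
    mem[t] = beta * mem[t-1] + syn[t]
    spike  = mem[t] > threshold, after which mem is set to 0
    """
    syn, mem, spikes = 0.0, 0.0, []
    for x in inputs:
        syn = alpha * syn + x
        mem = beta * mem + syn
        spike = mem > threshold
        spikes.append(int(spike))
        if spike:
            mem = 0.0
    return spikes


# Output-layer settings of this notebook: alpha_out=0.2, beta_out=0.8, thr_out=1.0
print(simulate_synaptic([1.5, 0.0, 0.0], alpha=0.2, beta=0.8, threshold=1.0))  # [1, 0, 0]
print(simulate_synaptic([0.6, 0.6, 0.6], alpha=0.2, beta=0.8, threshold=1.0))  # [0, 1, 0]
```

The second example shows the effect of the synaptic state: a constant input below threshold still produces a spike once current and membrane potential have accumulated.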

In [None]:
net = model_build(settings, num_steps, device)

loss_fn = SF.ce_count_loss()

log_softmax_fn = nn.LogSoftmax(dim=-1)

optimizer = torch.optim.Adam(net.parameters(), lr=settings["lr"])

batch_size = 128
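`SF.ce_count_loss()` is snnTorch's cross-entropy spike count loss: as we understand it, the output spikes are summed over time and the counts are used as the logits of a softmax cross-entropy. A single-sample, dependency-free sketch of that computation (a simplification assumed here; the snnTorch documentation describes the exact batching and reduction):

```python
import math


def ce_count_loss(spike_counts, label):
    """Cross-entropy on output spike counts for one sample.

    The counts act as logits: loss = -log(softmax(counts)[label]).
    """
    m = max(spike_counts)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(c - m) for c in spike_counts))
    return log_sum_exp - spike_counts[label]


print(round(ce_count_loss([0, 0], 0), 4))  # 0.6931, i.e. log(2): no preference between 2 classes
print(ce_count_loss([5, 0, 0], 0) < ce_count_loss([5, 0, 0], 1))  # True
```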

#### Train the network (with validation)

In [None]:
num_epochs = 10

training_results = []
validation_results = []

for epoch in range(num_epochs):

  train_loss, train_acc = training_loop(ds_train, batch_size, net, optimizer, loss_fn)
  val_loss, val_acc = val_test_loop(ds_val, batch_size, net, loss_fn)

  training_results.append([train_loss, train_acc])
  validation_results.append([val_loss, val_acc])

  print("Epoch {}/{}: \n\ttraining loss: {} \n\tvalidation loss: {} \n\ttraining accuracy: {}% \n\tvalidation accuracy: {}%".format(epoch+1, num_epochs, training_results[-1][0], validation_results[-1][0], np.round(training_results[-1][1]*100,4), np.round(validation_results[-1][1]*100,4)))

#### Test the network

In [None]:
test_results = val_test_loop(ds_test, batch_size, net, loss_fn)

print("Test accuracy: {}%".format(np.round(test_results[1]*100,4)))

_, counts = np.unique(ds_test[:][1].cpu(), return_counts=True)
most_frequent = counts[np.argmax(counts)]
chance_level = most_frequent/len(ds_test)

print("\tChance level is {}%".format(np.round(chance_level*100,4)))
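The chance level computed above is the accuracy of a classifier that always predicts the most frequent class in the test set; for perfectly balanced classes it would reduce to 1/7 ≈ 14.29% here, with 7 characters. A stdlib sketch with a hypothetical helper name:

```python
from collections import Counter


def majority_chance_level(labels):
    """Accuracy of always predicting the most frequent class."""
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)


print(majority_chance_level([0, 0, 1, 2]))  # 0.5
```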

#### Single-sample inference

In [None]:
single_sample = next(iter(DataLoader(ds_test, batch_size=1, shuffle=True)))
print("Randomly selected sample: {}".format(letter_written_nir[single_sample[1].cpu()[0]]))

In [None]:
_, spk_out = val_test_loop(TensorDataset(single_sample[0],single_sample[1]), 1, net, loss_fn, return_spikes=True)
pred = np.argmax(np.sum(spk_out, axis=0))

### Plot output spiking activity
spk_out = spk_out[:, 0, :]  # (timesteps, 1, neurons) -> (timesteps, neurons)
aer = np.argwhere(spk_out)  # (timestep, neuron) pair of each output spike

plt.figure(figsize=(6,4.5))
plt.scatter(aer[:,0], aer[:,1], s=1)
plt.plot(range(0,spk_out.shape[0]),np.ones(spk_out.shape[0])*pred, '--', color="tab:red", alpha=0.5)
plt.xlabel("Timestep (a.u.)")
plt.ylabel("Output neuron")
plt.title("Output spiking activity (character: '{}', prediction: '{}')".format(letter_written_nir[single_sample[1].cpu()[0]],letter_written_nir[pred]))
plt.ylim(-0.5,int(len(letter_written_nir))-0.5)
plt.yticks(range(int(len(letter_written_nir))))
plt.show()
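The plotting cell turns the output raster into a list of (timestep, neuron) pairs, a minimal form of the address-event representation (AER) of spikes. The same conversion as a stdlib-only sketch (hypothetical helper `raster_to_aer`; with NumPy, `np.argwhere` on the raster returns the same pairs in the same order):

```python
def raster_to_aer(raster):
    """Convert a [time][neuron] binary raster into (timestep, neuron) events."""
    return [(t, n) for t, row in enumerate(raster) for n, value in enumerate(row) if value]


raster = [
    [0, 1, 0],
    [0, 0, 0],
    [1, 1, 0],
]
print(raster_to_aer(raster))  # [(0, 1), (2, 0), (2, 1)]
```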

### Export the network to NIR

In [None]:
nir_graph = export_nir.export_to_nir(net.to("cpu"), ds_test[0][0].unsqueeze(dim=0).to(device))

print('Nodes:')
for nodekey, node in nir_graph.nodes.items():
    print('\t', nodekey, node.__class__.__name__)
print('Edges:')
for edge in nir_graph.edges:
    print('\t', edge)

create_directory('./graphs')
nir.write(f'./graphs/{model_name}.nir', nir_graph)

### [snnTorch -->] NIR --> Brian 2

In [None]:
def input_array_to_spike_list(input_array):
    """Convert a (channels, timesteps) binary array into {channel: [spike timesteps]}."""
    input_spikes = {}
    for i, row in enumerate(input_array):
        input_spikes[i] = np.where(row == 1)[0].astype(int).tolist()
    return input_spikes
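`input_array_to_spike_list` expects one row per input channel (hence the `.T` applied below to each `(timesteps, channels)` sample) and returns, for each channel, the timesteps at which it spikes; this dictionary is what gets assigned to the input population's `params`. The single-sample cell at the end goes the other way for the output spikes. A stdlib-only sketch of both conversions (hypothetical helper names):

```python
def raster_to_spike_dict(raster):
    """{row index: [column indices equal to 1]}, e.g. {channel: [spike timesteps]}."""
    return {i: [t for t, v in enumerate(row) if v == 1] for i, row in enumerate(raster)}


def spike_dict_to_raster(spike_times, n_steps):
    """Inverse direction: a [time][neuron] raster from {neuron: [spike timesteps]}."""
    n_neurons = max(spike_times, default=-1) + 1
    raster = [[0] * n_neurons for _ in range(n_steps)]
    for neuron, times in spike_times.items():
        for t in times:
            raster[t][neuron] = 1
    return raster


channels_by_time = [[0, 1, 1],
                    [1, 0, 0]]  # 2 channels, 3 timesteps
spikes = raster_to_spike_dict(channels_by_time)
print(spikes)                           # {0: [1, 2], 1: [0]}
print(spike_dict_to_raster(spikes, 3))  # [[0, 1], [1, 0], [1, 0]] (time x neuron)
```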

In [None]:
backend = "brian2"
brian2_quantize_weights = True

reset_method = s2_nir.ResetMethod.ZERO

In [None]:
nir_model = nir.read(f'./graphs/{model_name}.nir')

print("nodes:")
for nodekey, node in nir_model.nodes.items():
    print("\t", nodekey, node.__class__.__name__, node.input_type["input"].dtype)
print("edges:")
for edge in nir_model.edges:
    print("\t", edge)

s2_nir.add_output_to_node("lif1.lif", nir_model, "output_lif1")

cfg = s2_nir.ConversionConfig(
    output_record=["spikes"],
    dt=0.0001,
    conn_delay=0,
    scale_weights=True,
    reset=reset_method,
    integrator=s2_nir.IntegratorMethod.FORWARD,
)
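The conversion uses a time step `dt=0.0001` and forward-Euler integration (`IntegratorMethod.FORWARD`). Assuming the forward-Euler relation decay = 1 - dt/τ between per-step decay factors and continuous time constants (an assumption for illustration; the actual parameter mapping is done by `export_to_nir` and `s2_nir`), the decay factors in `settings` would correspond to the time constants computed below with the hypothetical helper `decay_to_tau`:

```python
def decay_to_tau(decay, dt):
    """Time constant whose forward-Euler step dt yields the given decay factor.

    From x[t+1] = (1 - dt/tau) * x[t] + ...  it follows that tau = dt / (1 - decay).
    """
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay factor must be in [0, 1)")
    return dt / (1.0 - decay)


dt = 0.0001  # same value as in the ConversionConfig above
for name, decay in [("alpha_r", 0.3), ("beta_r", 0.5), ("alpha_out", 0.2), ("beta_out", 0.8)]:
    print(f"{name} = {decay} -> tau = {decay_to_tau(decay, dt) * 1e3:.4f} ms")
```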

In [None]:
net_brian, inp, outp = s2_nir.from_nir(nir_model, cfg)
assert len(inp) == 1  # make sure there is only one input pop

for pop in net_brian.populations:
    if pop.name == "lif1.lif":
        pop.set_max_atoms_per_core(10)
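`set_max_atoms_per_core(10)` caps the number of neurons (atoms) of the `lif1.lif` population placed on a single core, which matters when the network is mapped onto SpiNNaker2. Assuming the population is simply partitioned into chunks of at most that size (a simplification; the actual placement is handled by py-spinnaker2), the hidden layer with `nb_hidden = 100` neurons would occupy ten cores. A sketch of that arithmetic with a hypothetical helper:

```python
import math


def cores_needed(n_neurons, max_atoms_per_core):
    """Number of cores if a population is split into chunks of at most max_atoms_per_core neurons."""
    return math.ceil(n_neurons / max_atoms_per_core)


print(cores_needed(100, 10))  # 10
print(cores_needed(100, 32))  # 4
```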

#### Test

In [None]:
n_samples = 50  # number of test samples to simulate; use len(ds_test) for the whole test set
predicted_labels = []
actual_labels = []

my_loader = DataLoader(ds_test, batch_size=1, shuffle=True)

for iteration, single_sample in enumerate(islice(my_loader, n_samples)):

    print(f"Sample {iteration+1}/{n_samples}:")

    sample = single_sample[0].numpy()

    spike_times = input_array_to_spike_list(sample[0, :, :].T)

    timesteps = sample.shape[1]
    net_brian.reset()  # clear previous spikes and voltages

    net_brian.populations[0].params = spike_times

    hw = brian2_sim.Brian2Backend()
    hw.run(net_brian, timesteps, quantize_weights=brian2_quantize_weights)

    output_pop = next(p for p in outp if p.name == "lif2")
    hidden_pop = next(p for p in outp if p.name == "lif1.lif")
    spike_times = output_pop.get_spikes()

    n_output_spikes = np.zeros(len(spike_times))
    for nrn, spikes in spike_times.items():
        n_output_spikes[nrn] = len(spikes)

    print(f"Total number of spikes per neuron: \n\t{n_output_spikes}")
    predicted_label = int(np.argmax(n_output_spikes))
    actual_label = int(single_sample[1])
    print(f"Character: '{letter_written_nir[actual_label]}'")
    print(f"Prediction: '{letter_written_nir[predicted_label]}'\n")
    predicted_labels.append(predicted_label)
    actual_labels.append(actual_label)

predicted_labels = np.array(predicted_labels)
actual_labels = np.array(actual_labels)
n_correct = np.count_nonzero(predicted_labels == actual_labels)
print(f"{n_correct} correct predictions out of {n_samples}")
print(f"Test accuracy: {np.round(n_correct/n_samples*100,2)}%")

#### Single-sample inference

In [None]:
single_sample = next(iter(DataLoader(ds_test, batch_size=1, shuffle=True)))
print("Randomly selected sample: {}".format(letter_written_nir[single_sample[1].cpu()[0]]))

sample = single_sample[0].numpy()

spike_times = input_array_to_spike_list(sample.squeeze(axis=0).T)

timesteps = sample.shape[1]
net_brian.reset()

net_brian.populations[0].params = spike_times

hw = brian2_sim.Brian2Backend()
hw.run(net_brian, timesteps, quantize_weights=brian2_quantize_weights)

output_pop = next(p for p in outp if p.name == "lif2")
hidden_pop = next(p for p in outp if p.name == "lif1.lif")
spike_times = output_pop.get_spikes()

n_output_spikes = np.zeros(len(spike_times))
for nrn, spikes in spike_times.items():
    n_output_spikes[nrn] = len(spikes)

print(f"Total number of spikes per neuron: \n\t{n_output_spikes}")
predicted_label = int(np.argmax(n_output_spikes))
actual_label = int(single_sample[1])

spk_out = np.zeros((num_steps,int(len(letter_written_nir))))
for ii in spike_times.keys():
    spikes = np.zeros(num_steps)
    spikes[spike_times[ii]] = 1.
    spk_out[:,ii] = spikes
aer = np.argwhere(spk_out)  # (timestep, neuron) pair of each output spike

plt.figure(figsize=(6,4.5))
plt.scatter(aer[:,0], aer[:,1], s=1)
plt.plot(range(0,spk_out.shape[0]),np.ones(spk_out.shape[0])*predicted_label, '--', color="tab:red", alpha=0.5)
plt.xlabel("Timestep (a.u.)")
plt.ylabel("Output neuron")
plt.title(f"Output spiking activity (character: '{letter_written_nir[actual_label]}', prediction: '{letter_written_nir[predicted_label]}')")
plt.ylim(-0.5,int(len(letter_written_nir))-0.5)
plt.yticks(range(int(len(letter_written_nir))))
plt.show()