# MNIST Classification

This tutorial demonstrates how to use the `lava.lib.dl.netx` API to classify MNIST images using an already trained `lava.lib.dl.slayer` network; i.e., we do _no_ training here, just inference. The classification is done both on CPU (via $\texttt{Loihi2SimCfg}$) and on Loihi-2 neuro-cores (via $\texttt{Loihi2HwCfg}$).

Note that this inference tutorial is part of the **end-to-end training** and **evaluation** tutorial: [mnist-on-loihi](https://github.com/R-Gaurav/mnist-on-loihi). That repository contains all the inference code explained here, as well as the `slayer` training code used to obtain the trained network weights (used here). The `slayer` training procedure is also explained in the [accompanying tutorial](https://r-gaurav.github.io/2024/04/13/Lava-Tutorial-MNIST-Training-on-GPU-and-Evaluation-on-Loihi2.html) and is quite straightforward. However, when it comes to inference, there are some tips-and-tricks to keep in mind while evaluating the trained `slayer`-network via `netx`, and that's precisely the point of this tutorial.

## `slayer` Network Architecture
The architecture of the trained `slayer`-network is as follows: 
$$\texttt{Dense CUBA(128)} \rightarrow \texttt{Dense CUBA(64)} \rightarrow \texttt{Dense CUBA(10)}$$
where `Dense` denotes a fully connected connection, and `CUBA(m)` denotes $\texttt{m}$ Current-Based (CUBA) neurons. Note that the first **Hidden** layer, $\texttt{Dense CUBA(128)}$, accepts $784$-dimensional rate-encoded spikes (of the flattened MNIST images), and the (last) **Output** layer, $\texttt{Dense CUBA(10)}$, consists of $10$ output neurons denoting the classes; the predicted class is the index of the maximally spiking output neuron.
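
As a rough sanity check on the size of this network, the number of synaptic weights per `Dense` layer can be counted directly from the layer widths given above. This is a minimal sketch: it assumes plain fully connected layers with no bias terms, and the helper name `dense_weight_counts` is hypothetical (it is not part of `slayer` or `netx`).

```python
# Layer widths taken from the architecture above: 784 inputs -> 128 -> 64 -> 10.
layer_sizes = [784, 128, 64, 10]

def dense_weight_counts(sizes):
  """Return the number of weights of each fully connected layer."""
  return [n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:])]

counts = dense_weight_counts(layer_sizes)
total_weights = sum(counts)
print(counts, total_weights)  # [100352, 8192, 640] 109184
```

The first layer dominates the count, since it connects all $784$ input pixels to the $128$ hidden neurons.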

## Loihi-2 deployment
To deploy and evaluate the above (trained) `slayer`-network on Loihi-2 boards, we are going to load it via `netx` and connect **Input** and **Output** `Process`es to its ends, which will encode the test-image to input spikes and predict the class from the output spikes, respectively. Note that since the `slayer`-network is loaded via `netx`, we also call it the <ins>`netx`-obtained network</ins> here (and use the terms interchangeably as appropriate); the architecture for Loihi-2 deployment is _conceptually_ going to look like: **Input** `Process` -> `netx`-obtained network -> **Output** `Process`.

Without further ado, let's start by importing the necessary libraries/modules.

In [1]:
import logging
import numpy as np

from lava.lib.dl import netx
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.resources import CPU
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2SimCfg, Loihi2HwCfg
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.utils.dataloader.mnist import MnistDataset

from utils import (InputAdapter, PyInputAdapter, NxInputAdapter,
                   OutputAdapter, PyOutputAdapter, NxOutputAdapter)

# `Process` and `ProcessModel` to encode Images to **Input** Spikes
As mentioned earlier, the first **Hidden** layer (in the `netx`-obtained network) accepts a $784$-dimensional spike vector corresponding to a flattened MNIST image. Therefore, we need to write down the **Input** `Process` which will rate-encode the pixels to binary spikes; note that while training the above `slayer`-network, the pixels were first normalized to $[0, 1]$ and then rate-encoded via the following equation:
$$J = \alpha \langle e, x \rangle + \beta$$
where $J$ is the input current to the encoding neuron, $e$ is the encoder, $x$ is the normalized pixel value, and $\alpha$ and $\beta$ are the neuron's `gain` and `bias` values; their values are $e=1$ (since $x \geq 0$ always), $\alpha=1$ and $\beta=0$. We will use the same equation (for $J$) to rate-encode the normalized pixels to spikes in our **Input** `Process`: $\texttt{InpImgToSpk}$ below.
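
To make the rate-encoding concrete, here is a minimal pure-Python sketch for a single pixel. It assumes the integrate, threshold, and reset-to-zero rule that the `ProcessModel` later in this tutorial applies (strict comparison against a threshold of $1$); the function name `rate_encode` and its parameters are hypothetical and only for illustration.

```python
def rate_encode(x, n_steps, gain=1.0, bias=0.0, v_thr=1.0):
  """Return the binary spike train of one normalized pixel `x`."""
  v = 0.0
  spikes = []
  for _ in range(n_steps):
    v += gain * x + bias   # J = alpha * <e, x> + beta, with e = 1.
    spiked = v > v_thr     # Strict comparison against the threshold.
    if spiked:
      v = 0.0              # Reset the voltage after a spike.
    spikes.append(int(spiked))
  return spikes

for pixel in (0.0, 0.5, 1.0):
  print(pixel, sum(rate_encode(pixel, n_steps=20)))
```

Brighter pixels accumulate voltage faster and therefore spike more often, which is what makes the spike count a rate code for the pixel intensity.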

In [2]:
class InpImgToSpk(AbstractProcess):
  """
  Input process to convert flattened images to binary spikes.
  """
  def __init__(self, img_shape, n_tsteps, curr_img_id, v_thr=1):
    super().__init__()
    self.spk_out = OutPort(shape=(img_shape, ))
    self.label_out = OutPort(shape=(1, ))

    self.curr_img_id = Var(shape=(1, ), init=curr_img_id)
    self.n_ts = Var(shape=(1, ), init=n_tsteps)
    self.inp_img = Var(shape=(img_shape, ))
    self.ground_truth_label = Var(shape=(1, ))
    self.v = Var(shape=(img_shape, ), init=0)
    self.vth = Var(shape=(1, ), init=v_thr)

Now that we have defined the `Process`: $\texttt{InpImgToSpk}$, let's implement its corresponding `ProcessModel`: $\texttt{PyInpImgToSpkModel}$ that will run on CPU.

In [3]:
@implements(proc=InpImgToSpk, protocol=LoihiProtocol)
@requires(CPU)
class PyInpImgToSpkModel(PyLoihiProcessModel):
  """
  Python implementation for the above `InpImgToSpk` process.
  """
  spk_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
  label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=32)

  curr_img_id: int = LavaPyType(int, int, precision=32)
  n_ts: int = LavaPyType(int, int, precision=32)
  inp_img: np.ndarray = LavaPyType(np.ndarray, float)
  ground_truth_label: int = LavaPyType(int, int, precision=32)
  v: np.ndarray = LavaPyType(np.ndarray, float)
  vth: float = LavaPyType(float, float)

  def __init__(self, proc_params):
    super().__init__(proc_params=proc_params)
    self.mnist_dset = MnistDataset()
    self.gain = 1
    self.bias = 0

  def post_guard(self):
    """
    Guard function for post-management phase, necessary to update the next image
    index after the current image is processed.

    Note: The execution control calls `post_guard()` after `run_spk()` every
    time-step, before updating the `self.time_step` variable to next time-step.
    """
    if self.time_step % self.n_ts == 1: # n_ts steps passed, one image processed.
      return True

    return False

  def run_post_mgmt(self):
    """
    Post-management phase executed only when the above `post_guard()` returns
    True -> then, move to the next image, reset the neuron states, etc.
    """
    img = self.mnist_dset.test_images[self.curr_img_id]
    self.inp_img = img/255
    self.ground_truth_label = self.mnist_dset.test_labels[self.curr_img_id]
    self.label_out.send(np.array([self.ground_truth_label]))
    self.v = np.zeros(self.v.shape, dtype=float)
    self.curr_img_id += 1

  def run_spk(self):
    """
    Spiking phase, this is executed every simulation time-step unconditionally,
    and first in order of all the phases.
    """
    if self.time_step % self.n_ts == 1:
      self.inp_img = np.zeros(self.inp_img.shape, dtype=float)
      self.v = np.zeros(self.v.shape, dtype=float)

    J = self.gain*self.inp_img + self.bias
    self.v[:] = self.v[:] + J[:]
    mask = self.v > self.vth
    self.v[mask] = 0
    self.spk_out.send(mask)

There are a few important points to note here:
* Lava's execution/run `time-step` starts from $1$, and
* Whenever the run time-step is one more than a multiple of $\texttt{self.n_ts}$ (the per-image presentation time-steps):
    * The $\texttt{run_spk()}$ phase resets the input image variable, $\texttt{self.inp_img}$, and the encoding neuron's voltage, $\texttt{self.v}$, to all zeros, and
    * The $\texttt{post_guard()}$ phase returns $\texttt{True}$ and the $\texttt{run_post_mgmt()}$ phase gets called, which also resets the necessary variables

Let's look into these phases' operations more closely; note that they are discussed in considerable detail (on a per time-step basis) in the [accompanying tutorial](https://r-gaurav.github.io/2024/04/13/Lava-Tutorial-MNIST-Training-on-GPU-and-Evaluation-on-Loihi2.html).

Recall that in each time-step, $\texttt{run_spk()}$ is executed before the $\texttt{post_guard()}$ and $\texttt{run_post_mgmt()}$ phases. Therefore, when the $\texttt{InpImgToSpk}$ `Process`'s execution starts, i.e., at time-step $=1$, $\texttt{self.inp_img}$ and $\texttt{self.v}$ are both reset to all zeros in $\texttt{run_spk()}$, and since $\texttt{post_guard()}$ returns $\texttt{True}$, the $\texttt{run_post_mgmt()}$ phase updates $\texttt{self.inp_img}$ to the first test-image (assuming $\texttt{self.curr_img_id}$ is set to start from $0$), along with the other related variables. In the subsequent time-steps, $\texttt{run_spk()}$ keeps getting called and the rate-encoding of $\texttt{self.inp_img}$ progresses.

When the per-image presentation time-steps (i.e., $\texttt{self.n_ts}$) are over, i.e., in the $(\texttt{self.n_ts} + 1)^{\text{th}}$ time-step, $\texttt{run_spk()}$ is called again, but note that $\texttt{self.inp_img}$ is still the previous _old_ image. It is therefore important to reset $\texttt{self.inp_img}$ and $\texttt{self.v}$ in $\texttt{run_spk()}$ to ensure that the old image does _not_ corrupt the prediction corresponding to the new (to be updated) image. In the same $(\texttt{self.n_ts} + 1)^{\text{th}}$ time-step, $\texttt{post_guard()}$ returns $\texttt{True}$ and the $\texttt{run_post_mgmt()}$ phase finally updates $\texttt{self.inp_img}$ to the next new image (along with updating the ground truth).

Thus, it is important to reset $\texttt{self.inp_img}$ and $\texttt{self.v}$ in the $\texttt{run_spk()}$ phase in every $(k\times\texttt{self.n_ts} + 1)^{\text{th}}$ time-step, where $k \in \mathbb{W}$.
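
The phase ordering described above can be traced with a small stand-alone simulation. This is only a sketch of the schedule, not Lava code: it assumes that `run_spk()` runs before the post-management phase in every time-step, and the function name `encoded_image_per_step` is hypothetical. A value of `None` marks a time-step in which `run_spk()` sees an all-zero (blank) image.

```python
def encoded_image_per_step(n_ts, total_steps, first_img_id=0):
  """Return, per time-step, the image id seen by `run_spk()` (None = blank)."""
  current = None            # Nothing is loaded before time-step 1.
  next_img_id = first_img_id
  schedule = []
  for t in range(1, total_steps + 1):  # Lava time-steps start at 1.
    boundary = (t % n_ts == 1)
    if boundary:
      current = None        # `run_spk()` zeroes `inp_img` and `v`.
    schedule.append(current)
    if boundary:            # `post_guard()` is True, `run_post_mgmt()` runs.
      current = next_img_id
      next_img_id += 1
  return schedule

print(encoded_image_per_step(n_ts=4, total_steps=9))
# [None, 0, 0, 0, None, 1, 1, 1, None]
```

In this simplified trace, no time-step ever encodes an old image after its presentation window has ended, which is the purpose of the reset in `run_spk()`.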

# `Process` and `ProcessModel` to infer Classes from **Output** Spikes
As mentioned before, the **Output** layer of the `netx`-obtained network, composed of $10$ neurons (each denoting a class), produces spikes, from which we can infer the class by accumulating them over a period of $\texttt{self.n_ts}$ time-steps (for each image) and reporting the index which has the maximum number of accumulated spikes. To do so, we write down the following **Output** `Process`: $\texttt{OutSpkToCls}$.

In [4]:
class OutSpkToCls(AbstractProcess):
  """
  Output process to collect output neuron spikes and infer predicted class.
  """
  def __init__(self, n_tsteps, num_test_imgs, n_cls_shape=(10, )):
    super().__init__()
    self.spikes_in = InPort(shape=n_cls_shape) # Receives output spikes.
    self.label_in = InPort(shape=(1, )) # Receives ground truth labels.
    self.spikes_accum = Var(shape=n_cls_shape) # Accum. spikes for prediction.
    self.n_ts = Var(shape=(1, ), init=n_tsteps) # Image presentation time.
    self.pred_labels = Var(shape=(num_test_imgs, ))
    self.true_labels = Var(shape=(num_test_imgs, ))

Now that we have the $\texttt{OutSpkToCls}$ `Process` ready, let's write down its corresponding `ProcessModel`: $\texttt{PyOutSpkToClsModel}$ that runs on CPU. 

In [5]:
@implements(proc=OutSpkToCls, protocol=LoihiProtocol)
@requires(CPU)
class PyOutSpkToClsModel(PyLoihiProcessModel):
  spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
  label_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, precision=32)
  spikes_accum: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=32)
  n_ts: int = LavaPyType(int, int, precision=32)
  pred_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
  true_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)

  def __init__(self, proc_params):
    super().__init__(proc_params=proc_params)
    self.curr_idx = 0

  def post_guard(self):
    """
    Guard function for Post-Management phase.
    """
    if self.time_step % self.n_ts == 0:
      return True

    return False

  def run_post_mgmt(self):
    """
    Post-Management phase: executed only when the guard function above returns
    True.
    """
    true_label = self.label_in.recv()
    pred_label = np.argmax(self.spikes_accum)
    self.true_labels[self.curr_idx] = true_label[0]
    self.pred_labels[self.curr_idx] = pred_label
    self.curr_idx += 1
    self.spikes_accum = np.zeros_like(self.spikes_accum)

  def run_spk(self):
    """
    Spiking phase: executed unconditionally at every time-step, first in order
    among all the phases.
    """
    spk_in = self.spikes_in.recv()
    self.spikes_accum = self.spikes_accum + spk_in

As can be seen above, the $\texttt{post_guard()}$ phase returns $\texttt{True}$ in every $(k\times\texttt{self.n_ts})^{\text{th}}$ time-step, where $k\in\mathbb{N}$ (i.e., $k \geq 1$, since the time-steps start from $1$), and thus, the $\texttt{run_post_mgmt()}$ phase gets evaluated in the very same time-step. Let's look into the operations of these phases more closely; per time-step operational details can be found in the [accompanying tutorial](https://r-gaurav.github.io/2024/04/13/Lava-Tutorial-MNIST-Training-on-GPU-and-Evaluation-on-Loihi2.html).

As you already know, the Lava run time-step starts with $1$ and $\texttt{run_spk()}$ is the first phase to be called every time-step in a `Process`'s execution. Here, in the time-step $=1$, the $\texttt{run_spk()}$ phase is called first and it accumulates the incoming spikes from the **Output** layer of the `netx`-obtained network; $\texttt{post_guard()}$ returns $\texttt{False}$ and $\texttt{run_post_mgmt()}$ is _not_ called. Such processing continues until the $\texttt{self.n_ts}^{\text{th}}$ time-step arrives. In the time-step $=\texttt{self.n_ts}$, $\texttt{run_spk()}$ still accumulates the output spikes corresponding to the first input image, _post_ which, $\texttt{post_guard()}$ returns $\texttt{True}$ and $\texttt{run_post_mgmt()}$ subsequently computes the index of the maximally spiking neuron as the predicted class (other variables are accordingly reset or updated). 

For the next time-steps, i.e., $(\texttt{self.n_ts} + 1)^{\text{th}}$ onwards, the execution of $\texttt{OutSpkToCls}$ `Process` continues as explained above, but with the updated $\texttt{self.inp_img}$ in the $\texttt{InpImgToSpk}$ `Process`. 
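
The accumulate-then-argmax logic described above can be sketched without Lava as follows. The function name `classify_from_spikes` and the toy spike trains are made up for illustration; only the rule (accumulate every time-step, predict and reset on every multiple of `n_ts`) is taken from the `ProcessModel` above.

```python
def classify_from_spikes(spike_trains, n_ts):
  """Accumulate per-step spike vectors; predict a class every `n_ts` steps."""
  n_cls = len(spike_trains[0])
  accum = [0] * n_cls
  predictions = []
  for t, spikes in enumerate(spike_trains, start=1):
    accum = [a + int(s) for a, s in zip(accum, spikes)]  # As in `run_spk()`.
    if t % n_ts == 0:                                    # As in `post_guard()`.
      predictions.append(accum.index(max(accum)))        # Argmax of counts.
      accum = [0] * n_cls                                # Reset for next image.
  return predictions

toy_spikes = [
    [0, 1, 0], [0, 1, 1], [0, 0, 0],  # First image: neuron 1 spikes most.
    [1, 0, 1], [0, 0, 1], [0, 0, 1],  # Second image: neuron 2 spikes most.
]
print(classify_from_spikes(toy_spikes, n_ts=3))  # [1, 2]
```

As with `np.argmax`, ties are resolved in favor of the lowest index here.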

# Load the `slayer`-trained weights

Now that both the **Input** and **Output** `Process`es are ready, we can proceed with loading the (trained) `slayer`-network via `netx`. However, before we do that, note that the `slayer`-network was trained with each MNIST training image presented for $20$ time-steps. Therefore, during inference, the test-image presentation time-steps, i.e., $\texttt{self.n_ts}$, would ideally also be $20$ (although any reasonable number works for practical purposes). In the code below, $\texttt{n_tsteps}$ denotes the test-image presentation time-steps (i.e., $\texttt{self.n_ts}$); it is set to $32$ rather than $20$, since $\texttt{reset_interval}$ on Loihi-2 hardware has to be a power of $2$.

In [6]:
# `n_tsteps` is the presentation time-steps of each test-image.
#n_tsteps = 20  
n_tsteps = 32 # Since reset_interval on Loihi-2 Hardware has to be a power of 2.

# `num_test_images` is the number of test-images to do inference on.
num_test_imgs = 25 # Set 10000 for Loihi2SimCfg (--takes some time to execute).

net = netx.hdf5.Network(
    net_config="./trained_mnist_network.net", # Trained network path.
    reset_interval=n_tsteps, # Presentation time-steps of each test-image.
    reset_offset=1 # Phase shift / offset time-step to reset this network.
    )

Note the two important nuances above:
* $\texttt{reset_interval}$ is set equal to $\texttt{n_tsteps}$, which implies that the `netx`-obtained network is _reset_ after every $\texttt{n_tsteps}$ time-steps, however
* $\texttt{reset_offset}$ is set equal to $1$, which implies that the network is reset with a _phase shift_ of $1$ time-step (an important detail here)

In other words, $\texttt{reset_offset}=1$ implies that the count of $\texttt{reset_interval}=\texttt{n_tsteps}$ starts _after_ the time-step $1$ is over. That is, in the above cell's code, if $\texttt{n_tsteps}=20$, then the `netx`-obtained network: $\texttt{net}$ is reset after $21^{\text{st}}, 41^{\text{st}}, 61^{\text{st}}, \cdots$ time-steps. 
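
The reset schedule described above, together with the power-of-$2$ constraint noted in the code comment, can be checked with a few lines of plain Python. This sketch only mirrors the description in this tutorial and makes no claim about how `netx` implements resets internally; both helper names are hypothetical.

```python
def is_power_of_two(n):
  """True if `n` is a positive power of two (1, 2, 4, 8, ...)."""
  return n > 0 and (n & (n - 1)) == 0

def reset_steps(reset_interval, reset_offset, total_steps):
  """Time-steps after which the network is reset, per the description above."""
  first = reset_offset + reset_interval
  return list(range(first, total_steps + 1, reset_interval))

print(is_power_of_two(20), is_power_of_two(32))  # False True
print(reset_steps(20, 1, 70))                    # [21, 41, 61]
print(reset_steps(32, 1, 100))                   # [33, 65, 97]
```

With $\texttt{n_tsteps}=32$ as set in the cell above, the same rule places the resets after time-steps $33, 65, 97, \cdots$, which lines up with the time-steps in which the **Input** `Process` switches to a new image.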

The next step now is to instantiate the `Process`es and connect them appropriately, followed by their execution on either $\textsf{Loihi-2 Simulation}$ or $\textsf{Loihi-2 Hardware}$.

# Instantiating and Connecting `Process`es

Before we connect the `Process`es, we still need _adapters_ to transfer spikes back and forth between the CPU and the Loihi-2 neuro-cores. These adapters are fairly straightforward to understand and are already written in the $\texttt{utils.py}$ file in the current directory; we simply import and use them here.

In [7]:
# Instantiate Processes.

# `curr_img_id=0` implies that inference starts from test image at index 0.
img_to_spk = InpImgToSpk(img_shape=784, n_tsteps=n_tsteps, curr_img_id=0)

spk_to_cls = OutSpkToCls(n_tsteps=n_tsteps, num_test_imgs=num_test_imgs)
inp_adp = InputAdapter(shape=net.inp.shape)
out_adp = OutputAdapter(shape=net.out.shape)

# Connect Processes.
img_to_spk.spk_out.connect(inp_adp.inp)
inp_adp.out.connect(net.inp)
net.out.connect(out_adp.inp)
out_adp.out.connect(spk_to_cls.spikes_in)
# Connect ImgToSpk Input directly to SpkToCls Output for ground truths.
img_to_spk.label_out.connect(spk_to_cls.label_in)

The only major task now remaining is to create an appropriate $\texttt{run_config}$ depending upon the $\texttt{backend}$ we intend to deploy our network on. We do that in the function $\texttt{get_run_config()}$ below.

In [8]:
def get_run_config(backend):
  """
  Returns the run-time config corresponding to the `backend`.

  Args:
    backend <str>: Either "L2Sim" or "L2Hw" for Loihi2SimCfg or Loihi2HwCfg.
  """
  assert backend in ["L2Sim", "L2Hw"]

  if backend == "L2Sim": # Run on the Loihi-2 Simulation Hardware on CPU.
    run_config = Loihi2SimCfg(
        select_tag="fixed_pt", # To select fixed point implementation.
        exception_proc_model_map={
            InpImgToSpk: PyInpImgToSpkModel,
            OutSpkToCls: PyOutSpkToClsModel,
            InputAdapter: PyInputAdapter,
            OutputAdapter: PyOutputAdapter
            }
        )
  elif backend == "L2Hw": # Run on the Loihi-2 Physical Hardware on INRC.
    run_config = Loihi2HwCfg(
        select_sub_proc_model=True,
        exception_proc_model_map={
            InpImgToSpk: PyInpImgToSpkModel,
            OutSpkToCls: PyOutSpkToClsModel,
            InputAdapter: NxInputAdapter,
            OutputAdapter: NxOutputAdapter
            }
        )
  return run_config

# Inference on CPU and Loihi-2 Hardware

The function $\texttt{run_inference}$ below assists in evaluating our network on two $\texttt{backend}$s: 
* "$\texttt{L2Sim}$" for $\texttt{Loihi2SimCfg}$ on CPU, and
* "$\texttt{L2Hw}$" for $\texttt{Loihi2HwCfg}$ on Loihi-2 neuro-cores.

In [9]:
def run_inference(backend, is_log=False):
  """
  Args:
    backend <str>: "L2Sim" for deployment on CPU, "L2Hw" for deployment on 
                   Loihi-2 Hardware.
    is_log <bool>: Log the execution steps on Loihi-2 Hardware if True.
  """
  assert backend in ["L2Sim", "L2Hw"]
  run_config = get_run_config(backend=backend)
  if is_log and backend=="L2Hw":
    img_to_spk._log_config.level = logging.INFO
  
  for _ in range(num_test_imgs):
    img_to_spk.run(
      condition=RunSteps(num_steps=n_tsteps), run_cfg=run_config
    )
  ground_truths = spk_to_cls.true_labels.get().astype(np.int32)
  predtd_clsses = spk_to_cls.pred_labels.get().astype(np.int32)

  img_to_spk.stop()
  print("Accuracy on Loihi {0}: ".format(
        "Simulation" if backend=="L2Sim" else "Board"),
        np.mean(np.array(ground_truths) == np.array(predtd_clsses)))

In [10]:
# Execute.
run_inference("L2Hw", is_log=False)

Partitioning converged after iteration=5
Per core utilization:
-------------------------------------------------------------------------
| AxonIn |NeuronGr| Neurons|Synapses| AxonMap| AxonMem|  Total |  Cores |
|-----------------------------------------------------------------------|
|   0.40%|  12.50%|   0.24%|   1.60%|   0.06%|   0.00%|   1.71%|       1|
|   0.80%|  12.50%|   1.56%|  16.80%|   0.40%|   0.00%|  14.72%|       1|
|   4.90%|  12.50%|   1.22%|  78.40%|   0.31%|   0.00%|  67.14%|       3|
|-----------------------------------------------------------------------|
| Total                                                        |       5|
-------------------------------------------------------------------------
Accuracy on Loihi Board:  1.0
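
The reported accuracy is the fraction of test images whose predicted class matches the ground-truth label. A minimal stand-alone sketch of that computation is shown here; the function name `accuracy` and the toy label lists are made up for illustration and are not results from the run above.

```python
def accuracy(true_labels, pred_labels):
  """Fraction of positions where the predicted label equals the true label."""
  if len(true_labels) != len(pred_labels):
    raise ValueError("Label lists must have the same length.")
  matches = sum(int(t == p) for t, p in zip(true_labels, pred_labels))
  return matches / len(true_labels)

toy_true = [7, 2, 1, 0]  # Made-up ground-truth labels.
toy_pred = [7, 2, 1, 4]  # Made-up predictions; the last one is wrong.
print(accuracy(toy_true, toy_pred))  # 0.75
```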
