# MNIST Classification

This tutorial demonstrates how to use the `lava.lib.dl.netx` API to classify MNIST images with a `lava.lib.dl.slayer`-trained network. Classification is performed on two backends: the CPU (via $\texttt{Loihi2SimCfg}$) and Loihi-2 neuro-cores (via $\texttt{Loihi2HwCfg}$).

Note that this inference tutorial is part of the **end-to-end training** and **evaluation** tutorial [mnist-on-loihi](https://github.com/R-Gaurav/mnist-on-loihi), which contains all the inference code explained here, as well as the `slayer` training code used to obtain the trained network weights (used here). In this tutorial, we will _not_ train the `slayer` network; rather, we load the exported (trained) weights via `netx` and run inference on two backends: $\textsf{Loihi-2 Simulation}$ ($\texttt{Loihi2SimCfg}$) and $\textsf{Loihi-2 Hardware}$ ($\texttt{Loihi2HwCfg}$). The `slayer` training procedure is explained in the [accompanying tutorial](https://r-gaurav.github.io/2024/04/13/Lava-Tutorial-MNIST-Training-on-GPU-and-Evaluation-on-Loihi2.html) and is quite straightforward. However, there are some tips and tricks to keep in mind when evaluating the trained `slayer` network via `netx`, and that's precisely the point of this tutorial. Without further ado, let's begin!

# `Process` and `ProcessModel` to encode Images to Spikes

In [1]:

class InpImgToSpk(AbstractProcess):
  """
  Input process to convert flattened images to binary spikes.

  Each image is presented for `n_tsteps` time-steps. The paired ProcessModel
  uses the first step of every `n_tsteps`-long window as a blank (zero-input)
  reset step and loads the next image right after it, so at least two
  time-steps per image are required.

  Args:
    img_shape: Length of a flattened image; sets the size of the `spk_out`
      port and of the `inp_img` / `v` Vars.
    n_tsteps: Number of time-steps each image is presented for (>= 2).
    curr_img_id: Index of the first test image to present.
    v_thr: Voltage threshold; an input neuron spikes when its voltage is
      strictly greater than this value.

  Raises:
    ValueError: If `n_tsteps` is smaller than 2.
  """
  def __init__(self, img_shape, n_tsteps, curr_img_id, v_thr=1):
    # The ProcessModel blanks the input and loads the next image whenever
    # `time_step % n_ts == 1`. For n_tsteps == 1 that never holds (x % 1 == 0),
    # so no image would ever be loaded and the process would silently emit no
    # spikes; n_tsteps == 0 would divide by zero. Fail loudly instead.
    if n_tsteps < 2:
      raise ValueError(
          "n_tsteps must be >= 2 (one blank reset step plus at least one "
          f"image step per window), got {n_tsteps}.")
    super().__init__()
    self.spk_out = OutPort(shape=(img_shape, ))  # Binary spikes, one per pixel.
    self.label_out = OutPort(shape=(1, ))  # Ground-truth label of each image.

    self.curr_img_id = Var(shape=(1, ), init=curr_img_id)
    self.n_ts = Var(shape=(1, ), init=n_tsteps)
    self.inp_img = Var(shape=(img_shape, ))
    self.ground_truth_label = Var(shape=(1, ))
    self.v = Var(shape=(img_shape, ), init=0)  # Membrane voltages.
    self.vth = Var(shape=(1, ), init=v_thr)

@implements(proc=InpImgToSpk, protocol=LoihiProtocol)
@requires(CPU)
class PyInpImgToSpkModel(PyLoihiProcessModel):
  """
  CPU (Python) ProcessModel of `InpImgToSpk`.

  Every pixel drives one non-leaky integrate-and-fire neuron with the constant
  current `gain * pixel + bias`; a neuron emits a spike and resets its voltage
  to 0 whenever the voltage exceeds `vth`. Time is split into windows of
  `n_ts` steps: the first step of each window (time-steps 1, n_ts + 1, ...)
  is a blank step with zero input, after which the next test image is loaded
  and its ground-truth label is sent on `label_out`.
  """
  spk_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
  label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=32)

  curr_img_id: int = LavaPyType(int, int, precision=32)
  n_ts: int = LavaPyType(int, int, precision=32)
  inp_img: np.ndarray = LavaPyType(np.ndarray, float)
  ground_truth_label: int = LavaPyType(int, int, precision=32)
  v: np.ndarray = LavaPyType(np.ndarray, float)
  vth: float = LavaPyType(float, float)

  def __init__(self, proc_params):
    super().__init__(proc_params=proc_params)
    self.mnist_dset = MnistDataset()
    # Input current is J = gain * pixel + bias (fixed here, not configurable).
    self.gain, self.bias = 1, 0

  def _at_window_start(self):
    """Return True on the first time-step of an `n_ts`-step image window."""
    return bool(self.time_step % self.n_ts == 1)

  def post_guard(self):
    """
    Decide whether the post-management phase runs on this time-step.

    It runs at the start of every image window (including time-step 1, before
    any image has been shown), so the next image is loaded right after the
    blank step produced by `run_spk()`. The execution control calls this
    after `run_spk()` and before advancing `self.time_step`.
    """
    return self._at_window_start()

  def run_post_mgmt(self):
    """
    Load the next test image (pixels divided by 255), emit its ground-truth
    label, clear the membrane voltages and advance the image index.
    """
    dset, idx = self.mnist_dset, self.curr_img_id
    self.inp_img = dset.test_images[idx] / 255
    self.ground_truth_label = dset.test_labels[idx]
    self.label_out.send(np.array([self.ground_truth_label]))
    self.v = np.zeros(self.v.shape, dtype=float)
    self.curr_img_id += 1

  def run_spk(self):
    """
    Spiking phase, run unconditionally every time-step before any other
    phase: integrate the input current and emit a binary spike vector.
    """
    if self._at_window_start():
      # Blank step: drop the previous image and reset all voltages.
      self.inp_img = np.zeros(self.inp_img.shape, dtype=float)
      self.v = np.zeros(self.v.shape, dtype=float)

    current = self.gain * self.inp_img + self.bias
    self.v[:] = self.v + current
    fired = self.v > self.vth
    self.v[fired] = 0
    self.spk_out.send(fired)

NameError: name 'AbstractProcess' is not defined

# `Process` and `ProcessModel` to infer Classes from Spikes

In [2]:
class OutSpkToCls(AbstractProcess):
  """
  Output process that collects the output-layer spikes and infers a predicted
  class for every presented image.

  Args:
    n_tsteps: Number of time-steps each image is presented for; stored in the
      `n_ts` Var.
    num_test_imgs: Number of images whose predicted and ground-truth labels
      are recorded (length of `pred_labels` and `true_labels`).
    n_cls_shape: Shape of the incoming spike vector, one entry per class;
      defaults to 10 classes (the MNIST digits).
  """
  def __init__(self, n_tsteps, num_test_imgs, n_cls_shape=(10, )):
    super().__init__()
    scalar_shape = (1, )
    per_image_shape = (num_test_imgs, )

    # Inputs: output-neuron spikes and the ground-truth label of each image.
    self.spikes_in = InPort(shape=n_cls_shape)
    self.label_in = InPort(shape=scalar_shape)

    # Per-class spike counts for the image currently being presented.
    self.spikes_accum = Var(shape=n_cls_shape)
    # Image presentation time.
    self.n_ts = Var(shape=scalar_shape, init=n_tsteps)

    # Results, one slot per test image.
    self.pred_labels = Var(shape=per_image_shape)
    self.true_labels = Var(shape=per_image_shape)

@implements(proc=OutSpkToCls, protocol=LoihiProtocol)
@requires(CPU)
class PyOutSpkToClsModel(PyLoihiProcessModel):
  """
  CPU (Python) ProcessModel of `OutSpkToCls`.

  Output spikes are summed per class in `spikes_accum`. At the end of every
  `n_ts`-step window:
  - the class with the most spikes is written to `pred_labels` (ties are
    resolved to the lowest class index by `np.argmax`);
  - the ground-truth label received on `label_in` is written to
    `true_labels`;
  - the accumulator is cleared for the next image.

  NOTE: `pred_labels` / `true_labels` have one slot per test image, so running
  for more than `num_test_imgs` windows raises IndexError.
  """
  spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
  label_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, precision=32)
  spikes_accum: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=32)
  n_ts: int = LavaPyType(int, int, precision=32)
  pred_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
  true_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)

  def __init__(self, proc_params):
    super().__init__(proc_params=proc_params)
    # Index of the next slot to fill in `pred_labels` / `true_labels`.
    self.curr_idx = 0

  def post_guard(self):
    """
    Guard for the post-management phase: True on the last time-step of every
    `n_ts`-long window (time-steps n_ts, 2 * n_ts, ...).
    """
    return bool(self.time_step % self.n_ts == 0)

  def run_post_mgmt(self):
    """
    Post-management phase, run only when `post_guard()` returns True: record
    the ground-truth and predicted labels of the image just presented, then
    reset the spike accumulator.
    """
    # The label port has shape (1, ), so `recv()` yields a one-element array.
    # Extract the scalar explicitly: assigning a size-1, ndim > 0 array into a
    # single element is deprecated since NumPy 1.25.
    true_label = np.asarray(self.label_in.recv()).item()
    pred_label = np.argmax(self.spikes_accum)
    self.true_labels[self.curr_idx] = true_label
    self.pred_labels[self.curr_idx] = pred_label
    self.curr_idx += 1
    self.spikes_accum = np.zeros_like(self.spikes_accum)

  def run_spk(self):
    """
    Spiking phase: executed unconditionally at every time-step, first in order
    among all the phases. Adds the incoming binary spikes to the per-class
    counts.
    """
    spk_in = self.spikes_in.recv()
    self.spikes_accum = self.spikes_accum + spk_in


NameError: name 'AbstractProcess' is not defined

# Load the `slayer`-trained weights

# Connecting `Process`es and `netx`-obtained Network

# Inference on CPU

# Inference on Loihi-2 neuro-cores