*Copyright (C) 2021 Intel Corporation*<br>
*SPDX-License-Identifier: BSD-3-Clause*<br>
*See: https://spdx.org/licenses/*

---

# MNIST Digit Classification with Lava

_**Motivation**: In this tutorial, we will build a Lava Process for an MNIST
classifier, using the Lava Processes for LIF neurons and Dense connectivity.
Of readers leaning towards neuroscience and those partial to computer science,
this tutorial is aimed more at the former. It should get you started with
Lava in a few minutes._

### This tutorial assumes that you:
- have the [Lava framework installed](../in_depth/tutorial01_installing_lava.ipynb "Tutorial on Installing Lava")
- are familiar with the [Process concept in Lava](../in_depth/tutorial02_processes.ipynb "Tutorial on Processes")

### This tutorial gives a bird's-eye view of
- how Lava Process(es) can perform the MNIST digit classification task using
[Leaky Integrate-and-Fire (LIF)](https://github.com/lava-nc/lava/tree/main/lava/proc/lif "Lava's LIF neuron") neurons and [Dense
(fully connected)](https://github.com/lava-nc/lava/tree/main/lava/proc/dense "Lava's Dense Connectivity") connectivity.
- how to create a Process 
- how to create Python ProcessModels 
- how to connect Processes
- how to execute them

### Follow the links below for deep-dive tutorials on
- [Processes](../in_depth/tutorial02_processes.ipynb "Tutorial on Processes")
- [ProcessModel](../in_depth/tutorial03_process_models.ipynb "Tutorial on ProcessModels")
- [Execution](../in_depth/tutorial04_execution.ipynb "Tutorial on Executing Processes")

### Our MNIST Classifier
In this tutorial, we will build a multi-layer feed-forward classifier without
 any convolutional layers. The architecture is shown below.

> **Important Note**:
>
> Right now, this model uses arbitrary _untrained_ network parameters (weights and biases)! We will update this model and fix this shortcoming in the next few days after release.
> Thus the MNIST classifier is not expected to produce any meaningful output at this point in time. 
> Nevertheless, this example illustrates how to build, compile and run an otherwise functional model in Lava.

<center><img src="https://raw.githubusercontent.com/lava-nc/lava-nc.github.io/main/_static/images/tutorial01/mnist_process_arch.png" alt="Training
flow"
style="width: 800px;"/></center>

The three Processes shown above are:
1. Spike Input Process: generates spikes from the image input via integrate-and-fire dynamics
2. MNIST Feed-forward Process: encapsulates the feed-forward architecture of Dense connectivity and LIF neurons
3. Output Process: accumulates output spikes from the feed-forward Process, infers the class label, and compares the predicted label with the ground truth
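The layer sizes can be read off the `Dense` and `LIF` shapes used in `PyMnistClassifierModel` further down: 784 input pixels, two hidden layers of 64 LIF neurons, and 10 output neurons. The plain-Python sketch below (no Lava required) walks through that chain and counts the parameters. `LAYER_SIZES` and `describe_network` are our own illustrative names, not Lava objects.

```python
# Layer sizes as used by the Dense/LIF Processes in this tutorial:
# 784 input pixels -> 64 LIF -> 64 LIF -> 10 output LIF.
LAYER_SIZES = [784, 64, 64, 10]


def describe_network(sizes):
    """Return (weight_shapes, n_weights, n_biases) for a fully connected chain.

    Each Dense weight matrix is shaped (n_out, n_in), matching the
    shape=(64, 784) convention passed to Dense in this tutorial.
    """
    weight_shapes = [(n_out, n_in) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    n_weights = sum(n_out * n_in for n_out, n_in in weight_shapes)
    # One bias per LIF neuron; the input layer has no bias.
    n_biases = sum(sizes[1:])
    return weight_shapes, n_weights, n_biases


shapes, n_w, n_b = describe_network(LAYER_SIZES)
for i, s in enumerate(shapes):
    print(f"dense{i}: weights {s}")
print(f"{n_w} weights, {n_b} biases")
# dense0: weights (64, 784)
# dense1: weights (64, 64)
# dense2: weights (10, 64)
# 54912 weights, 138 biases
```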

#### General Imports

In [1]:
# Assumes: $PYTHONPATH contains lava repository root
import os
import numpy as np

### Create the Process class

Below we create the Lava Process classes. We need to define only the structure of the process here. The details about how the Process will be executed are specified in the [ProcessModels](../in_depth/tutorial03_process_models.ipynb "Tutorial on ProcessModels") below.

As mentioned above, we define Processes for
- converting input images to binary spikes (_SpikeInput_),
- the 4-layer fully connected feed-forward network (_MnistClassifier_), and
- accumulating the output spikes and inferring the class of an input image
(_OutputProcess_).

In [2]:
# Import Process-level primitives
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort

In [3]:
class SpikeInput(AbstractProcess):
    """Reads image data from the MNIST dataset and converts it to spikes.
    The resulting spike rate is proportional to the pixel value"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        n_img = kwargs.pop('num_images', 25)
        n_steps_img = kwargs.pop('num_steps_per_image', 128)
        shape = (784,)
        self.spikes_out = OutPort(shape=shape)
        self.label_out = OutPort(shape=(1,))
        self.num_images = Var(shape=(1,), init=n_img)
        self.num_steps_per_image = Var(shape=(1,), init=n_steps_img)
        self.input_img = Var(shape=shape)
        self.ground_truth_label = Var(shape=(1,))
        self.v = Var(shape=shape, init=0)
        self.vth = Var(shape=(1,), init=kwargs['vth'])
        
        
class MnistClassifier(AbstractProcess):
    """A 4 layer feed-forward network with LIF and Dense Processes."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # As mentioned before, the weights and biases saved on the disk are
        # arbitrary numbers. These will not produce any meaningful output
        # classification.
        trained_weights_path = kwargs.pop(
            'trained_weights_path', os.path.join('.', 'mnist_pretrained.npy'))
        real_path_trained_wgts = os.path.realpath(trained_weights_path)

        wb_list = np.load(real_path_trained_wgts, allow_pickle=True)
        w0 = wb_list[0].transpose().astype(np.int32)
        w1 = wb_list[2].transpose().astype(np.int32)
        w2 = wb_list[4].transpose().astype(np.int32)
        b1 = wb_list[1].astype(np.int32)
        b2 = wb_list[3].astype(np.int32)
        b3 = wb_list[5].astype(np.int32)

        self.spikes_in = InPort(shape=(w0.shape[1],))
        self.spikes_out = OutPort(shape=(w2.shape[0],))
        self.w_dense0 = Var(shape=w0.shape, init=w0)
        self.b_lif1 = Var(shape=(w0.shape[0],), init=b1)
        self.w_dense1 = Var(shape=w1.shape, init=w1)
        self.b_lif2 = Var(shape=(w1.shape[0],), init=b2)
        self.w_dense2 = Var(shape=w2.shape, init=w2)
        self.b_output_lif = Var(shape=(w2.shape[0],), init=b3)
        
        
class OutputProcess(AbstractProcess):
    """Process to gather spikes from 10 output LIF neurons and interpret the
    highest spiking rate as the classifier output"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        shape = (10,)
        n_img = kwargs.pop('num_images', 25)
        self.num_images = Var(shape=(1,), init=n_img)
        self.spikes_in = InPort(shape=shape)
        self.label_in = InPort(shape=(1,))
        self.spikes_accum = Var(shape=shape)
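        # Read the window length instead of hard-coding it, so it stays in
        # sync with SpikeInput's num_steps_per_image (both default to 128).
        n_steps_img = kwargs.pop('num_steps_per_image', 128)
        self.num_steps_per_image = Var(shape=(1,), init=n_steps_img)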
        self.pred_labels = Var(shape=(n_img,))
        self.gt_labels = Var(shape=(n_img,))
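`MnistClassifier` expects `mnist_pretrained.npy` to hold a pickled sequence of six arrays in the order `[W0, b1, W1, b2, W2, b3]`, with each weight matrix stored as `(n_in, n_out)` (it is transposed on load). If you do not have that file, the sketch below writes a random placeholder with matching shapes. `make_placeholder_weights` is a hypothetical helper, and its value ranges are our own assumption, chosen to stay well inside the 8-bit weight and 13-bit bias precisions declared later in `PyMnistClassifierModel`. Like the weights referenced in this tutorial, these values are arbitrary and untrained.

```python
import os
import tempfile

import numpy as np


def make_placeholder_weights(path, sizes=(784, 64, 64, 10), seed=0):
    """Hypothetical helper: save random integer weights/biases in the layout
    MnistClassifier loads, i.e. [W0, b1, W1, b2, W2, b3] with W as (n_in, n_out).
    """
    rng = np.random.default_rng(seed)
    arrays = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        # Small signed integers, well inside an 8-bit weight range (assumption).
        arrays.append(rng.integers(-8, 8, size=(n_in, n_out)))
        # Small signed biases, well inside a 13-bit range (assumption).
        arrays.append(rng.integers(-16, 16, size=(n_out,)))
    # Fill a 1-D object array element by element so NumPy does not try to
    # stack the differently shaped arrays into one block.
    packed = np.empty(len(arrays), dtype=object)
    for i, a in enumerate(arrays):
        packed[i] = a
    np.save(path, packed, allow_pickle=True)


# Write to a temporary directory and reload it the way MnistClassifier does.
tmp_path = os.path.join(tempfile.mkdtemp(), 'mnist_pretrained.npy')
make_placeholder_weights(tmp_path)
wb_list = np.load(tmp_path, allow_pickle=True)
w0 = wb_list[0].transpose().astype(np.int32)
print(w0.shape)  # (64, 784), matching Dense(shape=(64, 784))
```

To use such a file, pass its path via `trained_weights_path`, as the run cell does.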

### Create ProcessModels for Python execution
The ProcessModels contain the code that actually gets executed; the Processes
above are, in a way, only declarations of structure and interfaces.
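To make the split between declaration and implementation concrete, here is a toy analogy in plain Python. It is not how Lava works internally: `toy_implements`, `MODEL_REGISTRY` and `toy_run` are hypothetical stand-ins that only mimic the idea of a Process class declaring state while a separately registered model class supplies the per-time-step behavior.

```python
MODEL_REGISTRY = {}


def toy_implements(proc_cls):
    """Toy decorator: record the decorated class as the model for proc_cls."""
    def register(model_cls):
        MODEL_REGISTRY[proc_cls] = model_cls
        return model_cls
    return register


class Counter:
    """'Process': declares state, but no behavior."""
    def __init__(self, start=0):
        self.count = start


@toy_implements(Counter)
class PyCounterModel:
    """'ProcessModel': supplies the code executed on every time step."""
    def run_spk(self, proc):
        proc.count += 1


def toy_run(proc, num_steps):
    # Look up the registered model for this Process type and execute it.
    model = MODEL_REGISTRY[type(proc)]()
    for _ in range(num_steps):
        model.run_spk(proc)
    return proc


print(toy_run(Counter(), 3).count)  # 3
```

The decorator names are deliberately prefixed with `toy_` so they cannot shadow Lava's real `implements` decorator imported below.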

In [4]:
# Import parent classes for ProcessModels
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.model.py.model import PyLoihiProcessModel

# Import ProcessModel ports, data-types
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType

# Import execution protocol and hardware resources
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.resources import CPU

# Import decorators
from lava.magma.core.decorator import implements, requires

# Import MNIST dataset
from lava.utils.dataloader.mnist import MnistDataset
np.set_printoptions(linewidth=np.inf)

#### ProcessModel for producing spiking input

In [None]:
@implements(proc=SpikeInput, protocol=LoihiProtocol)
@requires(CPU)
class PySpikeInputModel(PyLoihiProcessModel):
    num_images: int = LavaPyType(int, int, precision=32)
    spikes_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
    label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32,
                                      precision=32)
    num_steps_per_image: int = LavaPyType(int, int, precision=32)
    input_img: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    ground_truth_label: int = LavaPyType(int, int, precision=32)
    v: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    vth: int = LavaPyType(int, int, precision=32)
    mnist_dataset = MnistDataset()
    curr_img_id = -1

    def post_guard(self):
        if self.current_ts % self.num_steps_per_image == 1:
            self.curr_img_id += 1
            return True
        return False

    def run_post_mgmt(self):
        img = self.mnist_dataset.images[self.curr_img_id]
        self.ground_truth_label = self.mnist_dataset.labels[self.curr_img_id]
        self.input_img = img.astype(np.int32) - 127
        self.v = np.zeros(self.v.shape)
        self.label_out.send(np.array([self.ground_truth_label]))

    def run_spk(self):
        self.v[:] = self.v + self.input_img
        s_out = self.v > self.vth
        self.v[s_out] = 0  # reset voltage to 0 after a spike
        self.spikes_out.send(s_out)

#### ProcessModel for the feed-forward network
Notice that the following ProcessModel is further decomposed into
sub-Processes, which implement LIF neural dynamics and Dense connectivity. We
will not go into the details of how these are implemented in this tutorial.
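For intuition only, the NumPy sketch below implements one simplified Dense + LIF stage. It rests on assumptions and is not Lava's LIF ProcessModel: we assume `Dense` produces `a_in = W @ s_in`, and that the LIF is a Loihi-style current-based neuron in which `du` and `dv` are decay constants on a 12-bit scale (4096 meaning full decay), spiking where `v > vth` and then resetting `v` to 0. Under those assumptions, the values `du=4095, dv=0` used below make the synaptic current track the current input almost immediately, while the membrane voltage integrates without leak. `dense_lif_step` is a hypothetical helper name.

```python
import numpy as np


def dense_lif_step(s_in, w, u, v, b, vth, du, dv):
    """One time step of a simplified Dense -> LIF stage (integer arithmetic).

    Assumed dynamics (a sketch, not Lava's exact implementation):
        a_in = w @ s_in
        u    = u * (4096 - du) // 4096 + a_in     # current decay
        v    = v * (4096 - dv) // 4096 + u + b    # voltage decay plus bias
        spike where v > vth, then reset v to 0
    """
    a_in = w @ s_in.astype(np.int64)
    u = u * (4096 - du) // 4096 + a_in
    v = v * (4096 - dv) // 4096 + u + b
    s_out = v > vth
    v = np.where(s_out, 0, v)
    return s_out, u, v


# One LIF neuron driven by a single always-on input through weight 100.
w = np.array([[100]])
s_in = np.array([True])
u = np.zeros(1, dtype=np.int64)
v = np.zeros(1, dtype=np.int64)
spike_steps = []
for t in range(1, 21):
    s_out, u, v = dense_lif_step(s_in, w, u, v, b=0, vth=400, du=4095, dv=0)
    if s_out[0]:
        spike_steps.append(t)
print(spike_steps)  # [5, 10, 15, 20]
```

With a constant drive of 100 per step and `vth=400`, the voltage reaches 500 on the fifth step, so the neuron fires every fifth step.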

In [None]:
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense            

@implements(MnistClassifier)
@requires(CPU)
class PyMnistClassifierModel(AbstractSubProcessModel):
    spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    spikes_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
    w_dense0: np.ndarray = LavaPyType(np.ndarray, int, precision=8)
    b_lif1: np.ndarray = LavaPyType(np.ndarray, int, precision=13)
    w_dense1: np.ndarray = LavaPyType(np.ndarray, int, precision=8)
    b_lif2: np.ndarray = LavaPyType(np.ndarray, int, precision=13)
    w_dense2: np.ndarray = LavaPyType(np.ndarray, int, precision=8)
    b_output_lif: np.ndarray = LavaPyType(np.ndarray, int, precision=13)

    def __init__(self, proc):
        self.dense0 = Dense(shape=(64, 784), weights=proc.w_dense0.init)
        self.lif1 = LIF(shape=(64,), b=proc.b_lif1.init, vth=400,
                        dv=0, du=4095)
        self.dense1 = Dense(shape=(64, 64), weights=proc.w_dense1.init)
        self.lif2 = LIF(shape=(64,), b=proc.b_lif2.init, vth=350,
                        dv=0, du=4095)
        self.dense2 = Dense(shape=(10, 64), weights=proc.w_dense2.init)
        self.output_lif = LIF(shape=(10,), b=proc.b_output_lif.init,
                              vth=2**17-1, dv=0, du=4095)

        proc.in_ports.spikes_in.connect(self.dense0.in_ports.s_in)
        self.dense0.out_ports.a_out.connect(self.lif1.in_ports.a_in)
        self.lif1.out_ports.s_out.connect(self.dense1.in_ports.s_in)
        self.dense1.out_ports.a_out.connect(self.lif2.in_ports.a_in)
        self.lif2.out_ports.s_out.connect(self.dense2.in_ports.s_in)
        self.dense2.out_ports.a_out.connect(self.output_lif.in_ports.a_in)
        self.output_lif.out_ports.s_out.connect(proc.out_ports.spikes_out)

#### Finally, ProcessModel for inference output

In [None]:
@implements(proc=OutputProcess, protocol=LoihiProtocol)
@requires(CPU)
class PyOutputProcessModel(PyLoihiProcessModel):
    spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    label_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, precision=32)
    num_images: int = LavaPyType(int, int, precision=32)
    spikes_accum: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=32)
    num_steps_per_image: int = LavaPyType(int, int, precision=32)
    pred_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    gt_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    current_img_id = -1

    # This is needed for Loihi synchronization protocol
    def post_guard(self):
        if (self.current_ts % self.num_steps_per_image == 1
                and self.current_ts > 1):
            self.current_img_id += 1
            return True
        return False

    def run_post_mgmt(self):
        print(f'Curr Img: {self.current_img_id}')
        pred_label = np.argmax(self.spikes_accum)
        self.pred_labels[self.current_img_id] = pred_label
        self.spikes_accum = np.zeros(self.spikes_accum.shape)
        gt_label = self.label_in.recv()
        self.gt_labels[self.current_img_id] = gt_label
        print(f'Pred Label: {pred_label}', end='\t')
        print(f'Ground Truth: {gt_label}')

    def run_spk(self):
        spikes_buffer = self.spikes_in.recv()
        self.spikes_accum += spikes_buffer
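`PyOutputProcessModel` decodes rates by summing the output spikes of each neuron over an image's time window and taking `np.argmax`. The sketch below reproduces that step on a made-up spike raster; `decode` is our own helper name. Note that `np.argmax` returns the first index among ties, so if no output neuron spikes at all, the accumulator is all zeros and the predicted label is 0.

```python
import numpy as np


def decode(spike_raster):
    """spike_raster: (num_steps, 10) boolean array of output spikes.
    Returns (predicted_label, per-neuron spike counts)."""
    spikes_accum = np.zeros(spike_raster.shape[1], dtype=np.int32)
    for spikes in spike_raster:      # one row per time step, as in run_spk
        spikes_accum += spikes
    return int(np.argmax(spikes_accum)), spikes_accum


# Made-up raster: neuron 3 fires every 4th step, neuron 7 every 8th step.
raster = np.zeros((128, 10), dtype=bool)
raster[::4, 3] = True
raster[::8, 7] = True
label, counts = decode(raster)
print(label, counts[3], counts[7])    # 3 32 16

# With no output spikes at all, argmax falls back to index 0.
print(decode(np.zeros((128, 10), dtype=bool))[0])  # 0
```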

### Run the Process

In [16]:
num_images = 5
num_steps_per_image = 128

# Create instances
spike_input = SpikeInput(num_images=num_images,
                         num_steps_per_image=num_steps_per_image,
                         vth=1)
mnist_clf = MnistClassifier(
    trained_weights_path=os.path.join('.', 'mnist_pretrained.npy'))
output_proc = OutputProcess(num_images=num_images)

# Connect instances
spike_input.out_ports.spikes_out.connect(mnist_clf.in_ports.spikes_in)
mnist_clf.out_ports.spikes_out.connect(output_proc.in_ports.spikes_in)
spike_input.out_ports.label_out.connect(output_proc.in_ports.label_in)

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg

mnist_clf.run(
    condition=RunSteps(num_steps=(num_images+1) * num_steps_per_image),
    run_cfg=Loihi1SimCfg(select_sub_proc_model=True))
mnist_clf.stop()

Curr Img: 0
Pred Label: 0	Ground Truth: [5]
Curr Img: 1
Pred Label: 0	Ground Truth: [0]
Curr Img: 2
Pred Label: 0	Ground Truth: [4]
Curr Img: 3
Pred Label: 0	Ground Truth: [1]
Curr Img: 4
Pred Label: 0	Ground Truth: [9]
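Why run for `(num_images + 1) * num_steps_per_image` steps? Both Python ProcessModels use `post_guard` to act when `current_ts % num_steps_per_image == 1`: `SpikeInput` loads image k at the start of window k, while `OutputProcess`, which additionally requires `current_ts > 1`, finalizes image k one window later. Assuming `current_ts` counts from 1, the plain-Python sketch below lists those time steps for the run above (`post_mgmt_steps` is our own helper name). The last image is finalized at step 641, just beyond `num_images * num_steps_per_image = 640`, so the extra window is needed; as a harmless side effect, `SpikeInput` also loads a sixth image at step 641 that is never scored.

```python
def post_mgmt_steps(total_steps, steps_per_image, skip_first=False):
    """Time steps (assumed to count from 1) where ts % steps_per_image == 1."""
    return [ts for ts in range(1, total_steps + 1)
            if ts % steps_per_image == 1 and (ts > 1 or not skip_first)]


num_images, num_steps_per_image = 5, 128
total = (num_images + 1) * num_steps_per_image        # 768, as in RunSteps above
load_steps = post_mgmt_steps(total, num_steps_per_image)
finalize_steps = post_mgmt_steps(total, num_steps_per_image, skip_first=True)
print(load_steps)      # [1, 129, 257, 385, 513, 641]
print(finalize_steps)  # [129, 257, 385, 513, 641]

# With only num_images * num_steps_per_image steps, one image is never finalized.
print(len(post_mgmt_steps(num_images * num_steps_per_image,
                          num_steps_per_image, skip_first=True)))  # 4
```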


> **Important Note**:
>
> Right now, this model uses arbitrary _untrained_ network parameters (weights and biases)! We will update this model and fix this shortcoming in the next few days after release.
> Thus the MNIST classifier is not expected to produce any meaningful output at this point in time. 
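To put a number on the output of the run, the short sketch below computes classification accuracy from the printed labels (copied by hand; inside Lava they are stored in the `pred_labels` and `gt_labels` Vars of `OutputProcess`). With untrained parameters the single match is a coincidence: predicting 0 everywhere happens to be correct for the one image whose label is 0.

```python
import numpy as np

# Labels copied from the printed output of the run cell above.
pred_labels = np.array([0, 0, 0, 0, 0])
gt_labels = np.array([5, 0, 4, 1, 9])

accuracy = np.mean(pred_labels == gt_labels)
print(f"accuracy: {accuracy:.0%}")  # accuracy: 20%
```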

## How to learn more?

If you want to find out more about Lava, have a look at the [Lava documentation](https://lava-nc.org/ "Lava Documentation") or dive into the [source code](https://github.com/lava-nc/lava/ "Lava Source Code").

To receive regular updates on the latest developments and releases of the Lava Software Framework, please subscribe to the [INRC newsletter](http://eepurl.com/hJCyhb "INRC Newsletter").