<a href="https://colab.research.google.com/github/neworderofjamie/riscv_ise/blob/compiler/mnist_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

There are lots of rough edges here: error checking is lacking in places, the compiler supports an even smaller subset of C than it should, and the wrapping of various parts of the API is not very Pythonic.

# Installation
The current prototype FeNN toolchain is a little tricky to build because it reuses parts of GeNN (mostly the type system and the GeNNCode scanner, parser and type checker), so, on Colab, we install a prebuilt wheel from my Google Drive:

In [None]:
if "google.colab" in str(get_ipython()):
    !gdown 1hEx5nI2ITfmrrjfidr5y1SyWnsjFI8Qq
    !pip install pyfenn-0.0.1-cp311-cp311-linux_x86_64.whl



In [None]:
!wget https://github.com/neworderofjamie/riscv_ise/raw/refs/heads/compiler/bin/mnist_bias.bin
!wget https://github.com/neworderofjamie/riscv_ise/raw/refs/heads/compiler/bin/mnist_in_hid.bin
!wget https://github.com/neworderofjamie/riscv_ise/raw/refs/heads/compiler/bin/mnist_hid_out.bin


Install the trusty mnist package so we can easily access the MNIST dataset:

In [None]:
!pip install mnist



# Imports
Import NumPy, the mnist package and everything we need from PyFeNN:

In [None]:
import mnist
import numpy as np

from pyfenn import (BackendFeNNSim, EventContainer, EventPropagationProcess,
                    Model, NeuronUpdateProcess, NumericValue, Parameter,
                    ProcessGroup, Runtime, Shape, UnresolvedType, Variable,
                    init_logging)
from pyfenn.utils import (get_array_view, get_latency_spikes, load_and_push,
                          zero_and_push)

# Layer classes
FeNN is programmed using a small number of primitive objects:
*   ``Processes`` perform computation
*   ``Variables`` are used to hold model state e.g. neuron variables and weights
*   ``EventContainers`` are the primary means of communication between neuron processes

The FeNN tools don't really enforce any particular style of modelling, but you can easily use these primitives to create PyTorch-esque layer objects. We start by creating a leaky integrator (LI) for the output layer. This integrates an input current plus a bias into a membrane voltage, which is averaged over the trial. The update to be performed each timestep is implemented in a ``NeuronUpdateProcess``, which applies the same update to every neuron (the number of neurons is dictated by the shape of the variables).

In future, these processes might be Just-in-Time compiled from Python but, right now, they are implemented in [GeNNCode](https://genn-team.github.io/genn/documentation/5/custom_models.html#genncode). This is basically a subset of C with extensions for fixed-point types inspired by the [ISO standard extension](https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1005.pdf). In the ``LI`` model, this is most obvious in the ``0.0h6`` literal, whose ``h6`` suffix marks it as a fixed-point literal with 6 fractional bits (type promotion doesn't work 100% right now...):

In [None]:
class LI:
    def __init__(self, shape, tau_m: float, num_timesteps: int):
        self.shape = Shape(shape)
        dtype = UnresolvedType("s9_6_sat_t")

        self.v = Variable(self.shape, dtype)
        self.i = Variable(self.shape, dtype)
        self.v_avg = Variable(self.shape, dtype)
        self.bias = Variable(self.shape, dtype)
        self.process = NeuronUpdateProcess(
            """
            V = (Alpha * V) + I + Bias;
            I = 0.0h6;
            VAvg += (VAvgScale * V);
            """,
            {"Alpha": Parameter(NumericValue(np.exp(-1.0 / tau_m)), dtype),
             "VAvgScale": Parameter(NumericValue(1.0 / (num_timesteps / 2)), dtype)},
            {"V": self.v, "VAvg": self.v_avg, "I": self.i, "Bias": self.bias})
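
To make the GeNNCode above concrete, here is a minimal floating-point sketch of one LI timestep in plain NumPy. It is only a reading aid, assuming real-valued arithmetic: it ignores the fixed-point rounding and saturation FeNN actually applies, and ``li_step`` is a hypothetical helper, not part of PyFeNN.

```python
import numpy as np

def li_step(v, i, v_avg, bias, alpha, v_avg_scale):
    """One floating-point timestep of the LI update (hypothetical reference)."""
    v = alpha * v + i + bias         # V = (Alpha * V) + I + Bias;
    i = np.zeros_like(i)             # I = 0.0h6;  input is consumed every step
    v_avg = v_avg + v_avg_scale * v  # VAvg += (VAvgScale * V);
    return v, i, v_avg

# Same parameter values the LI class passes to its process
tau_m, num_timesteps = 20.0, 79
alpha = np.exp(-1.0 / tau_m)
v_avg_scale = 1.0 / (num_timesteps / 2)

v, v_avg, bias = np.zeros(10), np.zeros(10), np.zeros(10)
i = np.zeros(10)
i[3] = 1.0                                                    # one unit of input into neuron 3
v, i, v_avg = li_step(v, i, v_avg, bias, alpha, v_avg_scale)  # V[3] jumps to 1.0
v, i, v_avg = li_step(v, i, v_avg, bias, alpha, v_avg_scale)  # no new input, so V[3] decays to Alpha
```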


The Leaky Integrate-and-Fire (LIF) model we use for the hidden layer is slightly more complex, adding a threshold, reset by subtraction and a refractory period, but is defined in basically the same way. Because the LIF neuron emits spikes as well as updating variables, it has an ``EventContainer`` to manage the emitted spikes. In the process code, events are emitted by calling the name assigned to the event container, i.e. ``Spike()``:

In [None]:
class LIF:
    def __init__(self, shape, tau_m: float, tau_refrac: int, v_thresh: float):
        self.shape = Shape(shape)
        dtype = UnresolvedType("s10_5_sat_t")
        self.v = Variable(self.shape, dtype)
        self.i = Variable(self.shape, dtype)
        self.refrac_time = Variable(self.shape, UnresolvedType("int16_t"))
        self.out_spikes = EventContainer(self.shape)
        self.process = NeuronUpdateProcess(
            """
            V = (Alpha * V) + I;
            I = 0.0h5;
            if (RefracTime > 0) {
               RefracTime -= 1;
            }
            else if(V >= VThresh) {
               Spike();
               V -= VThresh;
               RefracTime = TauRefrac;
            }
            """,
            {"Alpha": Parameter(NumericValue(np.exp(-1.0 / tau_m)), dtype),
             "VThresh": Parameter(NumericValue(v_thresh), dtype),
             "TauRefrac": Parameter(NumericValue(tau_refrac), UnresolvedType("int16_t"))},
            {"V": self.v, "I": self.i, "RefracTime": self.refrac_time},
            {"Spike": self.out_spikes})
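
The same kind of floating-point reading aid for the LIF update follows; again it is a hypothetical NumPy sketch (``lif_step`` is not a PyFeNN function) that ignores fixed-point effects. The ``np.where`` calls reproduce the ``if``/``else if`` structure: a neuron that is still refractory only counts down, and only the others may cross the threshold, in which case they spike, are reset by subtracting ``VThresh`` and become refractory for ``TauRefrac`` timesteps.

```python
import numpy as np

def lif_step(v, i, refrac_time, alpha, v_thresh, tau_refrac):
    """One floating-point timestep of the LIF update; also returns a boolean spike mask."""
    v = alpha * v + i                    # V = (Alpha * V) + I;
    i = np.zeros_like(i)                 # I = 0.0h5;
    refractory = refrac_time > 0
    refrac_time = np.where(refractory, refrac_time - 1, refrac_time)  # RefracTime -= 1;
    spike = ~refractory & (v >= v_thresh)                   # else if (V >= VThresh)
    v = np.where(spike, v - v_thresh, v)                    # V -= VThresh;
    refrac_time = np.where(spike, tau_refrac, refrac_time)  # RefracTime = TauRefrac;
    return v, i, refrac_time, spike

alpha = np.exp(-1.0 / 20.0)
v = np.zeros(3)
refrac_time = np.zeros(3, dtype=np.int16)
i = np.array([0.0, 0.7, 0.3])
v, i, refrac_time, spike = lif_step(v, i, refrac_time, alpha, 0.61, 5)
print(spike, refrac_time)  # only neuron 1 crosses 0.61, so it spikes and becomes refractory
```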

Synapse updates are also fully programmable, but this is not currently exposed. All that is exposed for now is an event-driven spike propagation process, which takes an ``EventContainer`` of events, propagates them through a ``Variable`` of weights and accumulates the result into a target variable:

In [None]:
class Linear:
    def __init__(self, source_events: EventContainer, target_var: Variable,
                 weight_dtype: str):
        self.shape = Shape([source_events.shape.num_neurons,
                            target_var.shape.num_neurons])
        weight_dtype = UnresolvedType(weight_dtype)

        self.weight = Variable(self.shape, weight_dtype)
        self.process = EventPropagationProcess(source_events, self.weight,
                                               target_var)
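
The propagation loop at the start of the disassembly further down appears to scan each 32-bit word of spikes from its most significant set bit (using ``CLZ``) and add one weight row per spike with a saturating vector add (``VADD_S``). The sketch below mimics that reading of the listing in Python. It assumes neuron ``n`` is stored as bit ``n % 32`` of word ``n // 32`` and that weights are int16 rows; this is our interpretation rather than documented behaviour, and ``propagate`` is a hypothetical name.

```python
import numpy as np

INT16_MIN, INT16_MAX = -32768, 32767

def propagate(spike_words, weights, target):
    """Add row weights[n] to target for every set bit n in spike_words, saturating at int16."""
    acc = target.astype(np.int32)
    for w, word in enumerate(int(x) for x in spike_words):
        while word:
            bit = word.bit_length() - 1   # highest set bit, i.e. 31 - CLZ(word)
            neuron = 32 * w + bit         # assumed layout: neuron n is bit n % 32 of word n // 32
            acc = np.clip(acc + weights[neuron], INT16_MIN, INT16_MAX)  # like VADD_S
            word &= ~(1 << bit)           # clear the bit we just handled
    return acc.astype(np.int16)

weights = np.arange(40 * 3, dtype=np.int16).reshape(40, 3)      # 40 source neurons, 3 targets
spikes = np.array([1 << 5, 1 << 1], dtype=np.uint32)            # neurons 5 and 33 spiked
print(propagate(spikes, weights, np.zeros(3, dtype=np.int16)))  # [114 116 118]
```

The generated FeNN code removes each handled bit by shifting the word left rather than masking it, but neurons are visited in the same highest-bit-first order.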


# Parameters

In [None]:
num_examples = 10000
num_timesteps = 79
input_shape = [28 * 28]
hidden_shape = [128]
output_shape = [10]
input_hidden_shape = [28 * 28, 128]
hidden_output_shape = [128, 10]

# Dataset
Convert the MNIST test set into latency-coded spikes. Yann LeCun's original site has been down for some time (or blocks Colab), so we override the download URL with a mirror:

In [None]:
mnist.datasets_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
mnist_spikes = get_latency_spikes(mnist.test_images())
mnist_labels = mnist.test_labels().astype(np.int16)
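
``get_latency_spikes`` hides the encoding. As a rough, hypothetical sketch of what such an encoder might do, the code below turns pixel intensities into time-to-first-spike latencies with ``t = tau * ln(i / (i - threshold))`` (the formula PyFeNN's helper appears to use), skips pixels that never cross threshold (which would otherwise trigger divide-by-zero or log-of-negative warnings) and packs one bit per input neuron into 32-bit words for each timestep. The ``tau`` and ``threshold`` values and the bit layout are assumptions, not PyFeNN's actual settings.

```python
import numpy as np

def latency_times(images, tau=20.0, threshold=51.0, num_timesteps=79):
    """Spike time per pixel (-1 = no spike) using t = tau * ln(i / (i - threshold))."""
    i = images.reshape(len(images), -1).astype(np.float64)
    times = np.full(i.shape, -1, dtype=np.int64)
    active = i > threshold  # dimmer pixels never reach threshold
    times[active] = np.round(tau * np.log(i[active] / (i[active] - threshold))).astype(np.int64)
    times[times >= num_timesteps] = -1  # too late to fit in the trial
    return times

def pack_spikes(times, num_timesteps=79):
    """Pack one image's spike times into a [num_timesteps, ceil(N / 32)] array of uint32 words."""
    num_words = (times.shape[-1] + 31) // 32
    words = np.zeros((num_timesteps, num_words), dtype=np.uint32)
    for n in np.flatnonzero(times >= 0):
        # Assumed layout: neuron n is bit n % 32 of word n // 32
        words[times[n], n // 32] |= np.uint32(1 << (n % 32))
    return words

image = np.zeros((1, 28, 28), dtype=np.uint8)
image[0, 0, 2] = 255  # flattened pixel 2
image[0, 1, 5] = 255  # flattened pixel 33
times = latency_times(image)
words = pack_spikes(times[0])
print(times[0, 2], words.shape)  # 4 (79, 25)
```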




# Model definition
The FeNN tools can produce lots of helpful logging information so we initialise this system before we do anything else (if you use ``from pyfenn import PlogSeverity`` to import the enum you can then use e.g. ``PlogSeverity.DEBUG`` to control the logging level):

In [None]:
init_logging()

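Input spikes can be injected directly into FeNN without needing any sort of layer, so we just define an ``EventContainer`` to hold them. Unlike the hidden layer's container, it is sized for ``num_timesteps`` timesteps so a whole trial of input can be uploaded at once: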

In [None]:
input_spikes = EventContainer(Shape(input_shape), num_timesteps)

Then create the hidden and output layers using the classes we defined above. The fixed-point types are specified as strings; for example, ``s10_5_sat_t`` is a signed 16-bit fixed-point type (16 bits is all FeNN currently supports) with 10 integer bits, 5 fractional bits and saturation (currently only applied when adding and subtracting):

In [None]:
hidden = LIF(hidden_shape, tau_m=20.0, tau_refrac=5, v_thresh=0.61)
output = LI(output_shape, tau_m=20.0, num_timesteps=num_timesteps)
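
The type names can be decoded by hand. As a minimal sketch, assuming round-to-nearest conversion from floats (PyFeNN's own conversion may truncate instead), here is how values map onto the signed 16-bit integers FeNN stores, and what saturating addition does. ``to_fixed``, ``to_float`` and ``sat_add`` are hypothetical helpers, not PyFeNN functions.

```python
import numpy as np

INT16_MIN, INT16_MAX = -32768, 32767

def to_fixed(x, frac_bits):
    """Float -> signed 16-bit fixed point with `frac_bits` fractional bits (round to nearest, saturate)."""
    scaled = np.round(np.asarray(x, dtype=np.float64) * (1 << frac_bits))
    return np.clip(scaled, INT16_MIN, INT16_MAX).astype(np.int16)

def to_float(q, frac_bits):
    """Signed 16-bit fixed point -> float."""
    return np.asarray(q, dtype=np.float64) / (1 << frac_bits)

def sat_add(a, b):
    """Saturating 16-bit addition, as the *_sat_t types apply to + and -."""
    total = np.asarray(a, dtype=np.int32) + np.asarray(b, dtype=np.int32)
    return np.clip(total, INT16_MIN, INT16_MAX).astype(np.int16)

# s10_5: 1 sign bit + 10 integer bits + 5 fractional bits = 16 bits, resolution 1/32
v_thresh = to_fixed(0.61, 5)
print(v_thresh, to_float(v_thresh, 5))  # 20 0.625
print(sat_add(32000, 1000))             # 32767 rather than wrapping round to a negative value
```

Under the same rounding assumption, the hidden layer's ``Alpha`` of exp(-1/20) ≈ 0.951 becomes 30/32 = 0.9375 in ``s10_5_sat_t``, whereas the output layer's ``s9_6_sat_t`` represents it as 61/64 ≈ 0.953.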

Now we connect spiking outputs to input variables using the linear layer class we defined earlier:

In [None]:
input_hidden = Linear(input_spikes, hidden.i, "s10_5_sat_t")
hidden_output = Linear(hidden.out_spikes, output.i, "s9_6_sat_t")

Process groups define computation that can be performed in parallel (on FeNN, the processes in a group are actually run one after another, but this won't be the case with, e.g., GPU backends), so we put our neuron update processes and event propagation processes into separate groups:

In [None]:
neuron_update_processes = ProcessGroup([hidden.process, output.process])
synapse_update_processes = ProcessGroup([input_hidden.process, hidden_output.process])

Now we define a model which groups together all parts of our simulation:

In [None]:
model = Model([neuron_update_processes, synapse_update_processes])

# Simulation
Sadly, Google has yet to install FeNN nodes in its cloud, so for now we create a simulation backend (if you are lucky enough to be running on a Kria KV260 with the bitstream loaded, ``BackendFeNNHW`` is what you need) and use it to generate a simulation kernel. The control flow of these kernels *will* be fully programmable but, for now, you can either create a really simple kernel which just runs a list of process groups, or a 'simulation' kernel which offloads a loop over time to FeNN. The simulation kernel takes a list of process groups to run in the loop body every timestep and another list to run once at the end (for example, to copy data off FeNN). Here, each timestep runs the synapse updates followed by the neuron updates, and nothing runs at the end:

In [None]:
backend = BackendFeNNSim()
code = backend.generate_simulation_kernel([synapse_update_processes, neuron_update_processes],
                                          [],
                                          num_timesteps, model)
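
Conceptually, the simulation kernel behaves like the loop below. This is an illustrative Python sketch only: FeNN runs the loop itself in the generated code, and ``run_simulation_kernel`` and treating process groups as plain callables are hypothetical stand-ins. It does show why the order of the first list matters: with synapse updates before neuron updates, spikes emitted by the neuron updates in one timestep would be propagated at the start of the next.

```python
from typing import Callable, List, Sequence, Tuple

ProcessGroupFn = Callable[[int], None]  # hypothetical: a process group run for timestep t

def run_simulation_kernel(body: Sequence[ProcessGroupFn],
                          end: Sequence[ProcessGroupFn],
                          num_timesteps: int) -> None:
    for t in range(num_timesteps):
        for group in body:  # e.g. [synapse_update_processes, neuron_update_processes]
            group(t)
    for group in end:       # e.g. copy results off FeNN (empty in this notebook)
        group(num_timesteps)

# Record the call order using two dummy groups
log: List[Tuple[str, int]] = []
run_simulation_kernel([lambda t: log.append(("synapse", t)),
                       lambda t: log.append(("neuron", t))],
                      [], num_timesteps=3)
print(log[:4])  # [('synapse', 0), ('neuron', 0), ('synapse', 1), ('neuron', 1)]
```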

Now that we have some code, we create a ``Runtime`` object to interact with FeNN. We first use it to allocate the memory our model requires on FeNN:

In [None]:
runtime = Runtime(model, backend)
runtime.allocate()

Now we use some helper functions from ``pyfenn.utils`` to load the downloaded weights and biases into the appropriate variables and push them to FeNN:

In [None]:
load_and_push("mnist_in_hid.bin", input_hidden.weight, runtime)
load_and_push("mnist_hid_out.bin", hidden_output.weight, runtime)
load_and_push("mnist_bias.bin", output.bias, runtime)

and set the remaining variables to zero:

In [None]:
zero_and_push(hidden.v, runtime)
zero_and_push(hidden.i, runtime)
zero_and_push(hidden.refrac_time, runtime)
zero_and_push(output.v, runtime)
zero_and_push(output.i, runtime)
zero_and_push(output.v_avg, runtime)
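
If you want to sanity-check the files loaded above, you can decode them with NumPy. This sketch assumes they are raw, headerless little-endian int16 arrays in the fixed-point format of the variable they are loaded into (``s9_6_sat_t`` for the output bias); that is an assumption about the file layout, and ``decode_fixed`` is a hypothetical helper. The bias file is 64 bytes, i.e. 32 int16 values for 10 output neurons, which hints that it is padded, possibly to FeNN's vector width.

```python
import os

import numpy as np

def decode_fixed(data: bytes, frac_bits: int) -> np.ndarray:
    """Interpret raw little-endian int16 bytes as fixed point with `frac_bits` fractional bits."""
    raw = np.frombuffer(data, dtype="<i2")
    return raw.astype(np.float64) / (1 << frac_bits)

if os.path.exists("mnist_bias.bin"):
    with open("mnist_bias.bin", "rb") as f:
        bias = decode_fixed(f.read(), frac_bits=6)  # output bias is s9_6_sat_t
    print(bias.size, bias[:10])  # first 10 entries presumably belong to the 10 output neurons
```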

Next, we upload the code generated by the backend to FeNN:

In [None]:
runtime.set_instructions(code)

The ``Runtime`` object creates a bunch of 'Array' objects which are used to interact with model state at runtime. To save typing later on, we look these up now; ``get_array_view`` returns both an array and a NumPy view of its host memory with the given dtype:

In [None]:
input_spike_array, input_spike_view = get_array_view(runtime, input_spikes,
                                                     np.uint32)
hidden_spike_array = runtime.get_array(hidden.out_spikes)

output_v_avg_array, output_v_avg_view = get_array_view(runtime, output.v_avg,
                                                       np.int16)

Finally, we're ready to go! Now we can loop through the MNIST digits and:
1.   Copy each digit into the input spike array
2.   Run the kernel
3.   Copy the averaged output voltage back from FeNN
4.   Check whether this matches the correct label
5.   Zero the averaged output voltage ready for the next digit

In [None]:
num_correct = 0
for i in range(num_examples):
    # Copy data to array host pointer and push to FeNN
    input_spike_view[:] = mnist_spikes[i]
    input_spike_array.push_to_device()

    # Classify
    runtime.run()

    # Copy averaged output voltage from device
    output_v_avg_array.pull_from_device()

    # Determine if output is correct
    classification = np.argmax(output_v_avg_view)
    if classification == mnist_labels[i]:
        num_correct += 1

    # Zero output and push
    output_v_avg_view[:] = 0
    output_v_avg_array.push_to_device()

print(f"{num_correct} / {num_examples} correct {100.0 * (num_correct / num_examples)}%")

9586 / 10000 correct 95.86%


# Disassembling 😥
Sometimes it's cool to know what's happening under the hood, so the ``disassemble`` function lets you turn the code produced by the backend into a slightly friendlier form. Each instruction is 32 bits wide, so the number on the left is the byte address of each instruction, and branch offsets are relative to the address of the branch itself. A slightly outdated description of the instruction set is provided at https://github.com/neworderofjamie/riscv_ise/blob/master/docs/instruction_set.pdf

In [None]:
from pyfenn import disassemble
for i, c in enumerate(code):
    print(f"{i * 4} : {disassemble(c)}")

0 : ADDI X1, X0, 0
4 : ADDI X2, X0, 79
8 : LW X7, 44(X0)
12 : ADDI X8, X0, 16
16 : ADD X3, X8, X7
20 : LW X8, 52(X0)
24 : ADDI X9, X0, 64
28 : ADDI X5, X0, 1
32 : ADDI X4, X0, 31
36 : LW X6, 0(X7)
40 : ADDI X7, X7, 4
44 : BEQ X6, X0, 80
48 : ADDI X10, X4, 0
52 : CLZ X11, X6, 1536
56 : BEQ X6, X5, 80
60 : ADDI X12, X11, 1
64 : SLL X6, X6, X12
68 : SUB X10, X10, X11
72 : LW X8, 52(X0)
76 : LW X13, 48(X0)
80 : MUL X14, X10, X9
84 : ADD X13, X13, X14
88 : VLOAD V1, 0(X8)
92 : ADDI X14, X0, 1023
96 : VLOAD V0, 0(X13)
100 : VLOAD V2, 64(X8)
104 : VADD_S V3, V1, V0
108 : VSEL V1, X14, V3
112 : VSTORE V1, 0(X8)
116 : ADDI X10, X10, -1
120 : BNE X6, X0, -68
124 : ADDI X4, X4, 32
128 : BNE X7, X3, -92
132 : BEQ X0, X0, 12
136 : ADDI X6, X0, 0
140 : BEQ X0, X0, -72
144 : LW X7, 32(X0)
148 : ADDI X8, X0, 100
152 : MUL X9, X1, X8
156 : ADD X7, X7, X9
160 : ADD X3, X8, X7
164 : LW X8, 40(X0)
168 : ADDI X9, X0, 256
172 : ADDI X5, X0, 1
176 : ADDI X4, X0, 31
180 : LW X6, 0(X7)
184 : ADDI X7, X7, 4
188