# Installation
Start by creating a new Conda environment:
```
conda create -n frameworks pip
conda activate frameworks
```
Then install PyTorch, adjusting `pytorch-cuda` to match your CUDA version. Full instructions are available [here](https://pytorch.org/get-started/locally/).
```
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```
Install the benchmarked frameworks from PyPI:
```
pip install -r requirements.txt
```

This benchmark code is an adaptation of Rockpool's [benchmark script](https://gitlab.com/synsense/rockpool/-/blob/develop/rockpool/utilities/benchmarking/benchmark_utils.py?ref_type=heads). 

In [None]:
import torch
import torch.nn as nn
import numpy as np
from utils import timeit, benchmark_framework

In [None]:
def rockpool_torch():
    """Build the benchmark callables for Rockpool's torch backend.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds the installed Rockpool version.
    """
    # Imported here, so Rockpool is only needed once this factory is called.
    import rockpool
    from rockpool.nn.combinators import Sequential
    from rockpool.nn.modules import LIFTorch, LinearTorch

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> LIF network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """
        net = Sequential(
            LinearTorch(shape=(n_neurons, n_neurons)),
            LIFTorch(n_neurons),
        )
        net = net.to(device)
        # Input is laid out as (batch_size, n_steps, n_neurons).
        stimulus = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            net(stimulus)
        return {"model": net, "input": stimulus, "n_neurons": n_neurons}

    def forward_fn(bench_dict):
        """Run the model and store the first element of its result as "output"."""
        result = bench_dict["model"](bench_dict["input"])
        bench_dict["output"] = result[0]
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        bench_dict["output"].sum().backward(retain_graph=True)

    benchmark_title = f"Rockpool<br>v{rockpool.__version__}"
    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
def sinabs():
    """Build the benchmark callables for Sinabs.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds the installed Sinabs version.
    """
    # Imported here, so Sinabs is only needed once this factory is called.
    import sinabs
    from sinabs.layers import LIF

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> LIF network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """
        net = nn.Sequential(
            nn.Linear(n_neurons, n_neurons),
            LIF(tau_mem=torch.tensor(10.0)),
        )
        net = net.to(device)
        # Input is laid out as (batch_size, n_steps, n_neurons).
        stimulus = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            net(stimulus)
        return {"model": net, "input": stimulus, "n_neurons": n_neurons}

    def forward_fn(bench_dict):
        """Reset the model's states, run it, and store the result as "output"."""
        net = bench_dict["model"]
        # States are cleared before every forward pass.
        sinabs.reset_states(net)
        bench_dict["output"] = net(bench_dict["input"])
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        bench_dict["output"].sum().backward(retain_graph=True)

    benchmark_title = f"Sinabs<br>v{sinabs.__version__}"
    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
def sinabs_exodus():
    """Build the benchmark callables for the Sinabs EXODUS LIF layer.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds ``sinabs.exodus.__version__``.
    """
    # Imported here, so EXODUS is only needed once this factory is called.
    # The LIF layer comes from sinabs.exodus; ``sinabs`` itself is used for
    # the state reset and for reaching ``sinabs.exodus.__version__``.
    from sinabs.exodus.layers import LIF
    import sinabs

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> LIF network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """
        net = nn.Sequential(
            nn.Linear(n_neurons, n_neurons),
            LIF(tau_mem=torch.tensor(10.0)),
        )
        net = net.to(device)
        # Input is laid out as (batch_size, n_steps, n_neurons).
        stimulus = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            net(stimulus)
        return {"model": net, "input": stimulus, "n_neurons": n_neurons}

    def forward_fn(bench_dict):
        """Reset the model's states, run it, and store the result as "output"."""
        net = bench_dict["model"]
        # States are cleared before every forward pass.
        sinabs.reset_states(net)
        bench_dict["output"] = net(bench_dict["input"])
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        bench_dict["output"].sum().backward(retain_graph=True)

    benchmark_title = f"Sinabs EXODUS<br>v{sinabs.exodus.__version__}"
    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
def norse():
    """Build the benchmark callables for Norse.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds the installed Norse version.
    """
    # Imported here, so Norse is only needed once this factory is called.
    import norse
    from norse.torch import SequentialState
    from norse.torch.module.lif import LIF

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> LIF network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """
        net = SequentialState(
            nn.Linear(n_neurons, n_neurons),
            LIF(),
        )
        net = net.to(device)
        # Input is laid out time-first: (n_steps, batch_size, n_neurons).
        stimulus = torch.randn(n_steps, batch_size, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            net(stimulus)
        # NOTE: a warm-up backward pass (`output.sum().backward()`, meant to
        # JIT compile everything) is deliberately left disabled here.
        return {"model": net, "input": stimulus, "n_neurons": n_neurons}

    def forward_fn(bench_dict):
        """Run the model and store the first element of its result as "output"."""
        result = bench_dict["model"](bench_dict["input"])
        bench_dict["output"] = result[0]
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        bench_dict["output"].sum().backward(retain_graph=True)

    benchmark_title = f"Norse<br>v{norse.__version__}"
    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
def snntorch():
    """Build the benchmark callables for snnTorch.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds the installed snnTorch version.
    """
    # Imported here, so snnTorch is only needed once this factory is called.
    import snntorch

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> Leaky network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """

        class Model(nn.Module):
            """Linear layer feeding a Leaky neuron, stepped over the first axis."""

            def __init__(self, beta: float = 0.95):
                super().__init__()
                self.fc = nn.Linear(n_neurons, n_neurons)
                self.lif = snntorch.Leaky(beta=beta)
                # Initial membrane state; every forward call starts from it.
                self.mem = self.lif.init_leaky()

            def forward(self, x):
                membrane = self.mem
                spikes = []
                # Step through x along its first axis, carrying the membrane
                # state from one step to the next.
                for step_input in x:
                    spike, membrane = self.lif(self.fc(step_input), membrane)
                    spikes.append(spike)
                return torch.stack(spikes)

        net = Model().to(device)
        # Input is laid out time-first: (n_steps, batch_size, n_neurons).
        stimulus = torch.randn(n_steps, batch_size, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            net(stimulus)
        return {"model": net, "input": stimulus, "n_neurons": n_neurons}

    def forward_fn(bench_dict):
        """Run the model and store the stacked spikes as "output"."""
        bench_dict["output"] = bench_dict["model"](bench_dict["input"])
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        bench_dict["output"].sum().backward(retain_graph=True)

    benchmark_title = f"snnTorch<br>v{snntorch.__version__}"
    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
def spikingjelly():
    """Build the benchmark callables for SpikingJelly.

    Returns:
        A ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` tuple.
        The title embeds the installed SpikingJelly version, falling back to
        the previously hard-coded ``0.0.0.0.14`` when no distribution
        metadata is available.
    """
    from importlib.metadata import PackageNotFoundError, version

    # Imported here, so SpikingJelly is only needed once this factory is called.
    from spikingjelly.activation_based import neuron, surrogate, functional, layer

    # Report the version that is actually installed, like the other
    # benchmarks do, instead of a literal that can drift from reality.
    try:
        sj_version = version("spikingjelly")
    except PackageNotFoundError:
        # E.g. a source checkout without installed metadata.
        sj_version = "0.0.0.0.14"
    benchmark_title = f"SpikingJelly<br>v{sj_version}"

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Create a Linear -> LIFNode network and a random input on ``device``.

        ``n_layers`` is accepted but not used by this function.
        """

        class Model(nn.Module):
            """Linear layer feeding a LIF node, stepped over the first axis."""

            def __init__(self, tau=5.0):
                super().__init__()
                self.layer = nn.Sequential(
                    layer.Linear(n_neurons, n_neurons),
                    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
                )

            def forward(self, x):
                # The network is reset at the start of every forward call.
                functional.reset_net(self.layer)
                # Per-step outputs are accumulated into a running sum rather
                # than being stacked along a time axis.
                output = 0
                for inp in x:
                    output += self.layer(inp)
                return output

        model = Model().to(device)
        # Input is laid out time-first: (n_steps, batch_size, n_neurons).
        input_static = torch.randn(n_steps, batch_size, n_neurons).to(device)
        # One forward pass without gradient tracking before anything is
        # handed back (presumably a warm-up run).
        with torch.no_grad():
            model(input_static)
        return dict(model=model, input=input_static, n_neurons=n_neurons)

    def forward_fn(bench_dict):
        """Run the model and store its result as "output"."""
        model, input_static = bench_dict["model"], bench_dict["input"]
        bench_dict["output"] = model(input_static)
        return bench_dict

    def backward_fn(bench_dict):
        """Backpropagate the sum of the stored output, keeping the graph."""
        output = bench_dict["output"]
        loss = output.sum()
        loss.backward(retain_graph=True)

    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
# Settings shared by every benchmark run below.
device = "cuda"
n_layers = 1  # passed through to each prepare_fn, which currently ignores it
batch_size, n_steps = 10, 500

In [None]:
# Sweep every framework over each layer width, collecting one row of mean
# forward/backward timings per (framework, n_neurons) combination.
data = []
for benchmark in (
    rockpool_torch,
    sinabs,
    sinabs_exodus,
    norse,
    snntorch,
    spikingjelly,
):
    for n_neurons in (512, 1024, 2048):
        # Each factory call returns fresh callables plus the plot label.
        prepare_fn, forward_fn, backward_fn, bench_desc = benchmark()
        print("Benchmarking", bench_desc, "with n_neurons =", n_neurons)
        forward_times, backward_times = benchmark_framework(
            benchmark_desc=bench_desc,
            prepare_fn=prepare_fn,
            forward_fn=forward_fn,
            backward_fn=backward_fn,
            batch_size=batch_size,
            n_steps=n_steps,
            n_neurons=n_neurons,
            n_layers=n_layers,
            device=device,
        )
        # Row order: framework label, mean forward time, mean backward time,
        # layer width.
        data.append(
            [
                bench_desc,
                np.mean(np.array(forward_times)),
                np.mean(np.array(backward_times)),
                n_neurons,
            ]
        )

In [None]:
import pandas as pd

# Load the collected rows and reshape them to long form in one step: each
# (framework, neurons) pair yields one "forward" row and one "backward" row,
# with the timing stored under "time [s]" and the kind under "pass".
df = pd.DataFrame(
    data,
    columns=["framework", "forward", "backward", "neurons"],
).melt(
    id_vars=["framework", "neurons"],
    value_vars=["forward", "backward"],
    var_name="pass",
    value_name="time [s]",
)

In [None]:
import plotly.express as px

# Horizontal scatter: time on the x axis, layer width (log scale) on the
# y axis, one facet row per pass, one colour/symbol per framework.
fig = px.scatter(
    df,
    x="time [s]",
    y="neurons",
    color="framework",
    symbol="framework",
    facet_row="pass",
    orientation="h",
    log_y=True,
)

# Marker size and the white template are applied after construction.
fig.update_traces(marker={"size": 8}).update_layout(template="plotly_white")

# Save a static copy, then display the figure.
fig.write_image("framework-benchmarking.png", width=800, height=600)
fig.show()

In [None]:
# Scatter with layer width (log scale) on the x axis and time on the y axis,
# one facet column per pass, one colour/symbol per framework.
fig = px.scatter(
    df,
    x="neurons",
    y="time [s]",
    color="framework",
    symbol="framework",
    facet_col="pass",
    log_x=True,
)

# Marker size and the white template are applied after construction.
fig.update_traces(marker={"size": 8}).update_layout(template="plotly_white")

# Save a static copy, then display the figure.
fig.write_image("framework-benchmarking2.png", width=800, height=600)
fig.show()

In [None]:
# Pie chart of the recorded times per framework, one pie per pass.
fig = px.pie(
    df,
    names="framework",
    values="time [s]",
    color="framework",
    facet_col="pass",
    title="Time spent on each framework",
    width=600,
    height=400,
)
# The export is requested at 800x600, unlike the 600x400 set on the figure.
fig.write_image("framework-benchmarking-pie.png", width=800, height=600)