# Installation
Start by creating a new Conda environment:
```
conda create -n frameworks pip
conda activate frameworks
```
Then install PyTorch, adjusting `pytorch-cuda=11.7` below to match your CUDA version. Instructions are available [here](https://pytorch.org/get-started/locally/).
```
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```
Install the benchmarked frameworks from PyPI:
```
pip install -r requirements.txt
```

This benchmark code is adapted from Rockpool's [benchmark script](https://gitlab.com/synsense/rockpool/-/blob/develop/rockpool/utilities/benchmarking/benchmark_utils.py?ref_type=heads).

In [None]:
import torch
import torch.nn as nn
import numpy as np
from utils import timeit, benchmark_framework

In [None]:
def rockpool_torch():
    """Assemble the Rockpool (torch backend) benchmark hooks.

    Returns:
        A 4-tuple ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``,
        unpacked in that order by the benchmark loop.
    """
    import rockpool
    from rockpool.nn.combinators import Sequential
    from rockpool.nn.modules import LIFTorch, LinearTorch

    title = f"Rockpool v{rockpool.__version__}"

    def prepare(batch_size, n_steps, n_neurons, n_layers, device):
        # NOTE(review): n_layers is accepted for the shared signature but is
        # unused -- a single Linear -> LIF stack is always built.
        net = Sequential(
            LinearTorch(shape=(n_neurons, n_neurons)),
            LIFTorch(n_neurons),
        )
        net = net.to(device)
        # Batch-major input: (batch_size, n_steps, n_neurons).
        x = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One gradient-free pass before handing the model to the harness.
        with torch.no_grad():
            net(x)
        return {"model": net, "input": x, "n_neurons": n_neurons}

    def forward(bench):
        # The module call returns a sequence; element 0 is kept as the output
        # (presumably (output, state, record) -- confirm against Rockpool docs).
        bench["output"] = bench["model"](bench["input"])[0]
        return bench

    def backward(bench):
        # Reduce to a scalar loss; retain_graph=True keeps the autograd graph
        # alive so backward can be run again on the same forward output.
        bench["output"].sum().backward(retain_graph=True)

    return prepare, forward, backward, title

In [None]:
def sinabs():
    """Assemble the Sinabs benchmark hooks.

    Returns:
        A 4-tuple ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``,
        unpacked in that order by the benchmark loop.
    """
    from sinabs.layers import LIF
    import sinabs

    title = f"Sinabs v{sinabs.__version__}"

    def prepare(batch_size, n_steps, n_neurons, n_layers, device):
        # NOTE(review): n_layers is accepted for the shared signature but is
        # unused -- a single Linear -> LIF stack is always built.
        net = nn.Sequential(
            nn.Linear(n_neurons, n_neurons),
            LIF(tau_mem=torch.tensor(10.0)),
        )
        net = net.to(device)
        # Batch-major input: (batch_size, n_steps, n_neurons).
        x = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One gradient-free pass before handing the model to the harness.
        with torch.no_grad():
            net(x)
        return {"model": net, "input": x, "n_neurons": n_neurons}

    def forward(bench):
        net = bench["model"]
        # Reset the layers' neuron states via sinabs before every pass.
        sinabs.reset_states(net)
        bench["output"] = net(bench["input"])
        return bench

    def backward(bench):
        # Reduce to a scalar loss; retain_graph=True keeps the autograd graph
        # alive so backward can be run again on the same forward output.
        bench["output"].sum().backward(retain_graph=True)

    return prepare, forward, backward, title

In [None]:
def sinabs_exodus():
    """Assemble the Sinabs EXODUS benchmark hooks.

    Returns:
        A 4-tuple ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``,
        unpacked in that order by the benchmark loop.
    """
    from sinabs.exodus.layers import LIF
    import sinabs

    title = f"Sinabs EXODUS v{sinabs.exodus.__version__}"

    def prepare(batch_size, n_steps, n_neurons, n_layers, device):
        # NOTE(review): n_layers is accepted for the shared signature but is
        # unused -- a single Linear -> LIF stack is always built.
        net = nn.Sequential(
            nn.Linear(n_neurons, n_neurons),
            LIF(tau_mem=torch.tensor(10.0)),
        )
        net = net.to(device)
        # Batch-major input: (batch_size, n_steps, n_neurons).
        x = torch.randn(batch_size, n_steps, n_neurons).to(device)
        # One gradient-free pass before handing the model to the harness.
        with torch.no_grad():
            net(x)
        return {"model": net, "input": x, "n_neurons": n_neurons}

    def forward(bench):
        net = bench["model"]
        # Reset the layers' neuron states via sinabs before every pass.
        sinabs.reset_states(net)
        bench["output"] = net(bench["input"])
        return bench

    def backward(bench):
        # Reduce to a scalar loss; retain_graph=True keeps the autograd graph
        # alive so backward can be run again on the same forward output.
        bench["output"].sum().backward(retain_graph=True)

    return prepare, forward, backward, title

In [None]:
def norse():
    """Assemble the Norse benchmark hooks.

    Returns:
        A 4-tuple ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``,
        unpacked in that order by the benchmark loop.
    """
    from norse.torch.module.lif import LIF
    from norse.torch import SequentialState
    import norse

    title = f"Norse v{norse.__version__}"

    def prepare(batch_size, n_steps, n_neurons, n_layers, device):
        # NOTE(review): n_layers is accepted for the shared signature but is
        # unused -- a single Linear -> LIF stack is always built.
        net = SequentialState(
            nn.Linear(n_neurons, n_neurons),
            LIF(),
        )
        net = net.to(device)
        # Time-major input here: (n_steps, batch_size, n_neurons).
        x = torch.randn(n_steps, batch_size, n_neurons).to(device)
        # One gradient-free pass before handing the model to the harness.
        with torch.no_grad():
            net(x)
        return {"model": net, "input": x, "n_neurons": n_neurons}

    def forward(bench):
        # SequentialState returns a sequence; element 0 is kept as the output
        # (presumably (output, state) -- confirm against Norse docs).
        bench["output"] = bench["model"](bench["input"])[0]
        return bench

    def backward(bench):
        # Reduce to a scalar loss; retain_graph=True keeps the autograd graph
        # alive so backward can be run again on the same forward output.
        bench["output"].sum().backward(retain_graph=True)

    return prepare, forward, backward, title

In [None]:


def snntorch():
    """Assemble the snnTorch benchmark hooks.

    The model is stepped manually through time: input is time-major
    ``(n_steps, batch_size, n_neurons)`` and ``Model.forward`` returns the
    stacked spike train with that same shape.

    Returns:
        A 4-tuple ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``,
        unpacked in that order by the benchmark loop.
    """
    import snntorch

    benchmark_title = f"snnTorch v{snntorch.__version__}"

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        # NOTE(review): n_layers is accepted for the shared signature but is
        # unused -- a single Linear -> Leaky stack is always built.
        class Model(nn.Module):
            def __init__(self, beta: float = 0.95):
                super().__init__()
                self.fc = nn.Linear(n_neurons, n_neurons)
                self.lif = snntorch.Leaky(beta=beta)

            def forward(self, x):
                # Obtain the initial membrane per pass instead of caching a
                # tensor in __init__. A tensor stored as a plain attribute is
                # not moved by `.to(device)`, whereas the Leaky layer itself
                # is. Every pass still starts from the same initial state, as
                # before (the cached value was never updated).
                mem = self.lif.init_leaky()
                spikes = []
                for x_t in x:
                    spk, mem = self.lif(self.fc(x_t), mem)
                    spikes.append(spk)
                return torch.stack(spikes)

        model = Model().to(device)
        input_static = torch.randn(n_steps, batch_size, n_neurons).to(device)
        # One gradient-free pass before handing the model to the harness.
        with torch.no_grad():
            model(input_static)
        return dict(model=model, input=input_static, n_neurons=n_neurons)

    def forward_fn(bench_dict):
        model, input_static = bench_dict["model"], bench_dict["input"]
        # Model.forward returns a plain tensor, not an (output, state) tuple.
        # Indexing [0] would therefore keep only the first time step, and the
        # backward pass would cover a single step instead of the whole run.
        bench_dict["output"] = model(input_static)
        return bench_dict

    def backward_fn(bench_dict):
        # Reduce to a scalar loss; retain_graph=True keeps the autograd graph
        # alive so backward can be run again on the same forward output.
        output = bench_dict["output"]
        loss = output.sum()
        loss.backward(retain_graph=True)

    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [None]:
# Benchmark configuration passed to every framework's prepare_fn.
batch_size, n_steps = 10, 100
n_layers, n_neurons = 2, 256
device = "cuda"

In [None]:
# Run every framework's hooks through the shared harness and keep the raw
# timings: one [description, forward_times, backward_times] row per framework.
run_config = dict(
    n_neurons=n_neurons,
    n_layers=n_layers,
    n_steps=n_steps,
    batch_size=batch_size,
    device=device,
)
data = []
for benchmark in (rockpool_torch, sinabs, sinabs_exodus, norse, snntorch):
    prepare_fn, forward_fn, backward_fn, bench_desc = benchmark()
    print("Benchmarking:", bench_desc)
    forward_times, backward_times = benchmark_framework(
        prepare_fn=prepare_fn,
        forward_fn=forward_fn,
        backward_fn=backward_fn,
        benchmark_desc=bench_desc,
        **run_config,
    )
    # Flatten so each row holds a 1-D array of timing samples.
    fwd = np.array(forward_times).flatten()
    bwd = np.array(backward_times).flatten()
    data.append([bench_desc, fwd, bwd])

In [None]:
import pandas as pd

# One row per framework; "forward"/"backward" hold 1-D arrays of timings.
df = pd.DataFrame(data, columns=["framework", "forward", "backward"])
# Expand to one row per forward timing sample so it can be box-plotted.
df = df.explode("forward", ignore_index=True)
# explode() leaves the exploded column with object dtype; cast it back to a
# numeric dtype so pandas statistics and plotting treat it as numbers.
df["forward"] = df["forward"].astype(float)

In [None]:
import plotly.express as px

# Box plot of forward-pass timing samples, one box per framework.
# update_layout returns the same figure, so it can be chained.
fig = px.box(df, x="framework", y="forward").update_layout(template="plotly_white")
fig

In [None]:
# Deliberate stop so that "Run All" does not continue into the exploratory
# cells below. An explicit raise is used because a bare `assert` is stripped
# when Python runs with -O.
raise RuntimeError("Stopping here: the remaining cells are exploratory scratch code.")

In [None]:
# Scratch: snnTorch Leaky with init_hidden=True inside nn.Sequential, stepped
# manually through time. The input is time-major, (n_steps, batch_size, n_neurons),
# and the loop iterates over its first dimension.
import snntorch

beta = 0.9  # neuron decay rate

model = nn.Sequential(
    nn.Linear(n_neurons, n_neurons),
    snntorch.Leaky(beta=beta, init_hidden=True),
).to(device)

static_input = torch.randn(n_steps, batch_size, n_neurons).to(device)

output = []
for inp in static_input:
    # NOTE(review): `[0]` keeps the first element of the model output --
    # presumably the first batch sample if Leaky returns only spikes here.
    # Confirm against the snnTorch docs.
    output.append(model(inp)[0])

In [None]:
# Scratch: the same Linear -> Leaky time-step loop used by the snnTorch
# benchmark's Model, with the membrane state threaded through explicitly.
# `output` ends up as the stacked spike train of shape (n_steps, batch, n_neurons).
fc =  nn.Linear(n_neurons, n_neurons).to(device)
lif = snntorch.Leaky(beta=beta).to(device)
mem = lif.init_leaky()

output = []
for inp in static_input:
    cur = fc(inp) # post-synaptic current <-- spk_in x weight
    spk, mem = lif(cur, mem) # mem[t+1] <--post-syn current + decayed membrane
    output.append(spk)
output = torch.stack(output)


In [None]:
# Scratch: inspect the shape of the second time step of `output` (the stacked
# spike train, if the manual-loop cell above ran last).
output[1].shape

In [None]:
# Scratch: inspect the last per-time-step input slice left over from the loop.
inp.shape

In [None]:
# Scratch: build the Rockpool benchmark with a larger configuration on CPU
# for interactive inspection (not part of the timed benchmark run).
prepare, forward, backward, desc = rockpool_torch()
bench_dict = prepare(
    batch_size=10,
    n_steps=500,
    n_neurons=512,
    n_layers=4,
    device="cpu",
)

In [None]:
# Scratch: one manual, untimed forward pass through the Rockpool model built above.
ok = bench_dict['model'](bench_dict['input'])

In [None]:
import rockpool



In [None]:
import sinabs



In [None]:
# all_backward_times[1]

In [None]:
# Scratch: check that a backward pass runs through the current `output`
# (the manual snnTorch spike train, if that cell ran last).
output.sum().backward()