**Training and benchmarking were run in a Kaggle Notebook with logging to WandB. The logging and plotting code is commented out here to prevent errors when running in a different environment and to shorten the notebook run time. Besides max_epoch, the Kaggle Notebook maximum session duration of 12 hours also acts as a training stopper.**

# 1. Install libraries

In [None]:
!pip install snntorch --quiet

In [None]:
!pip show snntorch

In [None]:
!pip install --quiet lightning

In [None]:
!pip show lightning

In [None]:
# !pip install --quiet wandb # logging code

In [None]:
# !pip show wandb # logging code

In [None]:
import numpy as np

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

import torch
from torch.utils.data import DataLoader, Subset

import tqdm

import matplotlib.pyplot as plt

import statistics

import lightning as L
import torch.nn.functional as F
# from torchmetrics.regression import MeanSquaredError
from torchmetrics.functional import r2_score
from torchmetrics.functional import pearson_corrcoef

# import wandb # logging code
# from kaggle_secrets import UserSecretsClient # logging code

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import Callback

from lightning.pytorch import seed_everything

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import time
import random

In [None]:
# setup
# A single seed drives every RNG used in this notebook: Lightning's
# seed_everything is called first, then torch, random and NumPy are seeded
# explicitly with the same value.
seed_num = 42
seed_everything(seed_num, workers=True)
torch.manual_seed(seed_num)
random.seed(seed_num)
np.random.seed(seed_num)

# dataset
# These values are forwarded to NHPDataModule when `dm` is created.
data_dir = './nonhuman-primate-reaching/'
filename = 'indy_20160622_01'
# LitModel.forward iterates over dimension 0 of each batch, so the batch size
# is also the number of steps processed per training batch.
batch_size = 1024
# Passed to PrimateReaching as `bin_width` (presumably seconds — TODO confirm
# against the NeuroBench documentation).
bin_width = 0.032
num_workers = 3

def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``.  The seed
    is derived from ``torch.initial_seed()``, which PyTorch sets per worker
    from the loader's base seed, so each worker gets its own stream instead of
    every worker replaying the same constant seed.  Runs stay reproducible as
    long as the ``DataLoader`` is given a seeded ``generator``.

    Args:
        worker_id: Index of the worker, supplied by the ``DataLoader``.  It is
            not used directly because ``torch.initial_seed()`` already differs
            per worker.
    """
    # NumPy only accepts seeds below 2**32, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seeded generator handed to every DataLoader built by NHPDataModule.
g = torch.Generator()
g.manual_seed(seed_num)


# scheduler & training
max_epoch = 50  # passed to L.Trainer as max_epochs
accelerator = 'cpu'
devices = '1'
# NOTE(review): SingleDeviceStrategy() is created with its default device;
# confirm it matches `accelerator` if the accelerator is changed.
strategy = L.pytorch.strategies.SingleDeviceStrategy()
# Read by the self.log(...) calls in LitModel; True only when training on GPU.
sync_dist = (accelerator=='gpu')

# # logging code
# user_secrets = UserSecretsClient()
# wandb_api = user_secrets.get_secret("wandb_api") 
# wandb.login(key=wandb_api)
# run_name = f"sLSTM1, {seed_num}, {filename}"

In [None]:
# wandb.init(project='grandChallenge', name=run_name) # logging code

In [None]:
# !pip install --quiet --force-reinstall -v "neurobench==1.0.3" 
!pip install neurobench --quiet

In [None]:
!pip show neurobench

In [None]:
#!pip install neurobench

from neurobench.datasets import PrimateReaching
from neurobench.models.torch_model import TorchModel
from neurobench.benchmarks import Benchmark

# 2. Prepare dataset

In [None]:
class NHPDataModule(L.LightningDataModule):
    """Lightning data module around the NeuroBench ``PrimateReaching`` dataset.

    The train, validation and test splits are ``Subset`` views selected with
    the dataset's own ``ind_train`` / ``ind_val`` / ``ind_test`` indices.  No
    loader shuffles; validation and test are each served as one full batch.
    """

    def __init__(
        self,
        data_dir='./nonhuman-primate-reaching/',
        filename='indy_20160622_01',
        batch_size=128,
        bin_width=0.004,
        num_workers=1,
        seed_worker=None,
        generator=None,
    ):
        """Store the settings; no data is loaded until ``setup`` runs.

        Args:
            data_dir: Passed to ``PrimateReaching`` as ``file_path``.
            filename: Recording name passed to ``PrimateReaching``.
            batch_size: Batch size of the training loader only.
            bin_width: Passed to ``PrimateReaching`` as ``bin_width``.
            num_workers: Worker count for every loader.
            seed_worker: ``worker_init_fn`` for every loader.
            generator: ``torch.Generator`` shared by every loader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.filename = filename
        self.batch_size = batch_size
        self.bin_width = bin_width
        self.num_workers = num_workers
        self.seed_worker = seed_worker
        self.generator = generator
        self.dataset = None

    def _load_dataset(self):
        """Instantiate ``PrimateReaching`` with this module's fixed options."""
        return PrimateReaching(
            file_path=self.data_dir,
            filename=self.filename,
            num_steps=1,
            train_ratio=0.5,
            bin_width=self.bin_width,
            biological_delay=0,
            remove_segments_inactive=False,
            download=False,
        )

    def _make_loader(self, subset, batch_size):
        """Build an unshuffled loader over ``subset`` with the shared options."""
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def prepare_data(self):
        """Instantiate the dataset once and discard it (``download=False``)."""
        self._load_dataset()

    def setup(self, stage=None):
        """Load the dataset and create the three ``Subset`` splits.

        ``stage`` is accepted for the Lightning hook signature but not used.
        """
        self.dataset = dataset = self._load_dataset()

        # Keep only the first 96 rows of `samples` (original note: "Handle
        # Loco's recording"; presumably those sessions carry more rows —
        # TODO confirm).
        dataset.samples = dataset.samples[0:96, :]

        # Splits used by the dataloaders below.
        self.nhp_train = Subset(dataset, dataset.ind_train)
        self.nhp_val = Subset(dataset, dataset.ind_val)
        self.nhp_test = Subset(dataset, dataset.ind_test)

    def train_dataloader(self):
        """Training loader with batches of ``self.batch_size``."""
        return self._make_loader(self.nhp_train, self.batch_size)

    def val_dataloader(self):
        """Validation loader that yields the whole split as one batch."""
        return self._make_loader(self.nhp_val, len(self.nhp_val))

    def test_dataloader(self):
        """Test loader that yields the whole split as one batch."""
        return self._make_loader(self.nhp_test, len(self.nhp_test))

In [None]:
# Data module configured from the module-level dataset settings and the
# seeded worker-init function / generator defined in the setup cell.
dm = NHPDataModule(
    data_dir=data_dir,
    filename=filename,
    batch_size=batch_size,
    bin_width=bin_width,
    num_workers=num_workers,
    seed_worker=seed_worker,
    generator=g,
)

# 3. Construct model

In [None]:
class LitModel(L.LightningModule):
    """Spiking-LSTM regressor that maps a feature sequence to a 2-D output.

    Architecture (all built in ``__init__``): three stacked ``snn.SLSTM``
    layers followed by two ``Linear -> snn.Leaky`` readout stages.  Both
    ``Leaky`` layers use ``reset_mechanism="none"``; the membrane potential of
    the last one is the model output at every step.

    The recurrent state is stored on the instance and re-initialised at the
    start of every ``forward`` call, so no state is carried between batches.
    """

    def __init__(self, window=50, input_size=96, hidden_size=256, tau=0.96, p=0.3):
        """Build the layers.

        Args:
            window: Unused; kept so existing callers keep working.
            input_size: Number of input features per step.
            hidden_size: Width of every SLSTM layer and of the first readout.
            tau: Unused; kept so existing callers keep working.
            p: Unused; kept so existing callers keep working.
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = 2
        self.surrogate = surrogate.atan(alpha=2)

        # NOTE: the construction order below is kept as-is because every
        # torch.rand(...) call and layer initialisation consumes the global
        # RNG; reordering would change the seeded initial weights.

        # Learnable per-neuron thresholds start uniformly in [0, 1).
        thr_lstm_1 = torch.rand(self.hidden_size)
        self.slstm_1 = snn.SLSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            spike_grad=self.surrogate,
            learn_threshold=True,
            threshold=thr_lstm_1,
            reset_mechanism="subtract",
        )

        thr_lstm_2 = torch.rand(self.hidden_size)
        self.slstm_2 = snn.SLSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            spike_grad=self.surrogate,
            learn_threshold=True,
            threshold=thr_lstm_2,
            reset_mechanism="subtract",
        )

        thr_lstm_3 = torch.rand(self.hidden_size)
        self.slstm_3 = snn.SLSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            spike_grad=self.surrogate,
            learn_threshold=True,
            threshold=thr_lstm_3,
            reset_mechanism="subtract",
        )

        # First readout stage: hidden_size -> hidden_size, non-resetting LIF.
        beta_out = torch.rand(self.hidden_size)
        thr_out = torch.rand(self.hidden_size)
        self.fc1 = torch.nn.Linear(self.hidden_size, out_features=self.hidden_size)
        self.lif1 = snn.Leaky(
            beta=beta_out,
            threshold=thr_out,
            learn_beta=True,
            learn_threshold=True,
            spike_grad=self.surrogate,
            reset_mechanism="none",
        )

        # Second readout stage: hidden_size -> output_size, non-resetting LIF.
        beta_out2 = torch.rand(self.output_size)
        thr_out2 = torch.rand(self.output_size)
        self.fc2 = torch.nn.Linear(self.hidden_size, out_features=self.output_size)
        self.lif2 = snn.Leaky(
            beta=beta_out2,
            threshold=thr_out2,
            learn_beta=True,
            learn_threshold=True,
            spike_grad=self.surrogate,
            reset_mechanism="none",
        )

        # Recurrent state; populated by reset().
        self.syn_lstm_1, self.mem_lstm_1 = None, None
        self.syn_lstm_2, self.mem_lstm_2 = None, None
        self.syn_lstm_3, self.mem_lstm_3 = None, None

        self.mem_out = None
        self.mem_out2 = None

    def reset(self):
        """Re-initialise the state of all SLSTM and Leaky layers."""
        self.syn_lstm_1, self.mem_lstm_1 = self.slstm_1.init_slstm()
        self.syn_lstm_2, self.mem_lstm_2 = self.slstm_2.init_slstm()
        self.syn_lstm_3, self.mem_lstm_3 = self.slstm_3.init_slstm()

        self.mem_out = self.lif1.init_leaky()
        self.mem_out2 = self.lif2.init_leaky()

    def forward(self, x):
        """Run the network step by step along dimension 0 of ``x``.

        Args:
            x: Input tensor, expected to be of shape
                ``(len_series, 1, input_dim)``.  Each slice ``x[step]`` is
                flattened from dimension 1 onwards before the first SLSTM.

        Returns:
            The output membrane potentials stacked over the steps, with the
            singleton dimension 1 removed, i.e. ``(len_series, output_size)``
            for the expected input shape.
        """
        self.reset()

        mem_out_rec = []

        for x_step in x:
            x_timestep = x_step.flatten(1)

            spk_lstm_1, self.syn_lstm_1, self.mem_lstm_1 = self.slstm_1(
                x_timestep, self.syn_lstm_1, self.mem_lstm_1
            )
            spk_lstm_2, self.syn_lstm_2, self.mem_lstm_2 = self.slstm_2(
                spk_lstm_1, self.syn_lstm_2, self.mem_lstm_2
            )
            _, self.syn_lstm_3, self.mem_lstm_3 = self.slstm_3(
                spk_lstm_2, self.syn_lstm_3, self.mem_lstm_3
            )

            # The readout is driven by membrane potentials, not by spikes.
            cur_out = self.fc1(self.mem_lstm_3)
            _, self.mem_out = self.lif1(cur_out, self.mem_out)
            cur_out2 = self.fc2(self.mem_out)
            _, self.mem_out2 = self.lif2(cur_out2, self.mem_out2)

            mem_out_rec.append(self.mem_out2)

        # Squeeze only dimension 1: a bare squeeze() would also drop the step
        # axis when the batch holds a single step, breaking the loss shapes.
        return torch.stack(mem_out_rec).squeeze(1)

    def configure_optimizers(self):
        """Return AdamW (lr=1e-3, weight_decay=1e-2); no LR scheduler is used."""
        optimizer = torch.optim.AdamW(
            params=self.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
        )
        return [optimizer]

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training metrics; return the MSE loss."""
        _, loss, r2score, cc = self._get_preds_loss_accuracy(train_batch, batch_idx)

        self.log('train_loss', loss, prog_bar=True, on_epoch=True, sync_dist=sync_dist)
        # on_epoch was previously tied to `sync_dist` by mistake; it now
        # matches the other training metrics.
        self.log('train_r2score', r2score, prog_bar=True, on_epoch=True, sync_dist=sync_dist)
        self.log('train_cc', cc, prog_bar=True, on_epoch=True, sync_dist=sync_dist)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation metrics; return the predictions."""
        mem, loss, r2score, cc = self._get_preds_loss_accuracy(val_batch, batch_idx)

        self.log('val_loss', loss, prog_bar=True, sync_dist=sync_dist)
        self.log('val_r2score', r2score, prog_bar=True, sync_dist=sync_dist)
        self.log('val_cc', cc, prog_bar=True, sync_dist=sync_dist)

        return mem

    def _get_preds_loss_accuracy(self, batch, batch_idx):
        """Shared train/validation logic.

        Args:
            batch: ``(feature, label)`` pair from the dataloader.
            batch_idx: Unused; kept for a uniform step signature.

        Returns:
            Tuple ``(mem, loss, r2, cc)``: predictions, MSE loss, uniformly
            averaged R2 score, and the Pearson correlation averaged over the
            output dimensions as a Python float.
        """
        feature, label = batch
        # Call this instance, not the notebook-level `model` global, which is
        # later rebound to a NeuroBench wrapper.
        mem = self(feature)
        loss = F.mse_loss(mem, label)
        r2 = r2_score(mem, label, multioutput='uniform_average')
        cc = pearson_corrcoef(mem, label).mean().item()

        return mem, loss, r2, cc

In [None]:
model = LitModel()

In [None]:
print(model)

In [None]:
model.reset()

# 4. Construct training loop

In [None]:
# Checkpoints go to ./model-checkpoint/ and are ranked by the highest
# logged 'val_r2score'.
checkpoint_callback = ModelCheckpoint(
    dirpath='./model-checkpoint/',
    monitor='val_r2score',
    mode='max',
)

# wandb_logger = WandbLogger(project='grandChallenge', name=run_name, log_model='all')  # logging code

In [None]:
# Trainer built from the module-level settings.  Training stops at
# `max_epoch` epochs or earlier when 'val_loss' stops improving
# (EarlyStopping, patience=10).  The WandB logger stays disabled.
trainer = L.Trainer(
    accelerator=accelerator,
    devices=devices,
    strategy=strategy,
    # logger=wandb_logger,  # logging code
    callbacks=[
        checkpoint_callback,
        EarlyStopping(monitor="val_loss", min_delta=0, patience=10, mode='min'),
    ],
    max_epochs=max_epoch,
    deterministic=True,
)

In [None]:
# Wall-clock timing of the whole fit; start_time and end_time are read by the
# next cell to print the elapsed seconds.
start_time = time.time()
trainer.fit(model, dm)
end_time = time.time()

In [None]:
 print(f"Time taken: {end_time - start_time} seconds\n")

# 5. Benchmark

In [None]:
checkpoint_callback.best_model_path

In [None]:
# # plotting code
# model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path)

# model.eval()

In [None]:
# dm.setup()

In [None]:
# # plotting code
# with torch.no_grad():
#     test_batch = iter(dm.test_dataloader())
#     for feature, label in test_batch:
#         mem = model(feature)

#     cc_test = sum(pearson_corrcoef(mem, label))/2
    
# fig, ax = plt.subplots()
# ax.plot(mem[:,0].detach().numpy(), label="Output")
# ax.plot(label[:,0], '--', label="Target")
# ax.set_xlabel("Time")
# ax.set_ylabel("mm/s")
# ax.legend(loc='best')

# # wandb.log({"Testing X-axis Output": fig}) # logging code

# fig.show()

In [None]:
# # plotting code
# fig, ax = plt.subplots()
# ax.plot(mem[:,1].detach().numpy(), label="Output")
# ax.plot(label[:,1], '--', label="Target")
# ax.set_xlabel("Time")
# ax.set_ylabel("mm/s")
# ax.legend(loc='best')

# # wandb.log({"Testing Y-axis Output": fig}) # logging code

# fig.show()

In [None]:
# Reload the checkpoint selected by checkpoint_callback and switch to eval mode.
model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path)
model.eval()
# From here on the global `model` is the NeuroBench wrapper, not the LitModel.
model = TorchModel(model) # using TorchModel instead of SNNTorchModel because the SNN iterates over dimension 0
# Register snn.SpikingNeuron as an activation module on the wrapper
# (presumably so activation-based metrics track the spiking layers — TODO
# confirm against the NeuroBench documentation).
model.add_activation_module(snn.SpikingNeuron)

In [None]:
# Metric names requested from the NeuroBench Benchmark.
static_metrics = ["footprint"]
workload_metrics = ["r2", "synaptic_operations"]

# Benchmark expects the following: the wrapped model, the test dataloader,
# two lists (left empty here; presumably pre-processors and post-processors —
# TODO confirm against the NeuroBench documentation) and the
# [static, workload] metric lists.
benchmark = Benchmark(model, dm.test_dataloader(), [], [], [static_metrics, workload_metrics])
results = benchmark.run()
print(results)

# # logging code
# columns = list(results.keys())
# data = list(results.values())
# synaptic_operations = data[-1]
# columns.pop()
# columns += list(synaptic_operations.keys())
# columns += ["cc"]
# columns += ["seed"]
# data.pop()
# data += list(synaptic_operations.values())
# data += [cc_test.item()]
# data += [seed_num]

# wandb_logger.log_table(key='benchmark_table', columns=columns, data=[data]) # logging code

In [None]:
# SLSTM1
# seed used: 42, 26, 0, 1, 1234

# recording order: [indy_20160622_01, indy_20160630_01, indy_20170131_02, loco_20170210_03, loco_20170215_02, loco_20170301_05]

# array of mean of 5 seeds for each recording, followed by mean of 6 recordings at the end of the array
# # CC: [0.851101851463318, 0.754076862335205, 0.820740652084351, 0.550093793869019, 0.547384750843048, 0.658520746231079] 0.696986442804337
# # R2: [0.709797763824463, 0.545500159263611, 0.632305586338043, 0.261918079853058, 0.158576899766922, 0.414933794736862] 0.453838713963826
# # Footprint: [5931092, 5931092, 5931092, 5931092, 5931092, 5931092] 5931092
# # Dense: [1214972.0434668, 1214971.11368713, 1214965.93015581, 1214972.18303101, 1214969.14050721, 1214971.3763332] 1214970.29786353
# # MACs: [110208.522818939, 93088.6087599122, 89603.3402333257, 79544.6241215602, 83077.8577263284, 90685.7489950368] 91034.7837758503
# # ACs: [536632.370331086, 538325.53474368, 533955.154562898, 534089.245275364, 532260.841493657, 532443.254005773] 534617.733402077

# array of standard error of 5 seeds for each recording with order: indy_20160622_01, indy_20160630_01, indy_20170131_02, loco_20170210_03, loco_20170215_02, loco_20170301_05
# # CC: [0.00307949979675383, 0.00427690195465703, 0.00330616434867517, 0.00640923480410287, 0.00605118053783611, 0.0056038898047367] 
# # R2: [0.00752700537786349, 0.00910423870285717, 0.00963272275834302, 0.0104894370981451, 0.0475101446991847, 0.0069340290717505] 
# # Footprint: [0, 0, 0, 0, 0, 0] 
# # Dense: [0.00189139095793311, 0.00190797066316009, 0, 0.00149042131379247, 0.00267844311892986, 0.00220944681934646] 
# # MACs: [0.000472386435601032, 3.72286885977047E-06, 7.67518649809076E-06, 0.00037224180996418, 0.000668957497691736, 0.000450561440084129] 
# # ACs: [1049.08713272015, 1296.97879936504, 338.680181635065, 484.066750655915, 430.964037229618, 257.08173220729] 

In [None]:
# wandb.finish() # logging code