In [None]:
import pandas as pd
import random
import torch

from torch.utils.data import DataLoader, Subset

from sqp_snn import project_config
from sqp_snn.data.datasets import SQPDataset
from sqp_snn.models.utils import load_wandb_model, count_state_flops
from sqp_snn.neurobench_ext import StorkModel, StorkBenchmark

In [None]:
# Placeholders: replace with the W&B project and run id of the trained model.
wandb_project = 'your-project'
wandb_run = 'your-run'

# GPU index to load the model on. Previously hard-coded as 'cuda:1', which
# fails on machines with fewer than two GPUs; fall back to the first GPU,
# then to the CPU, so the notebook still runs end to end.
# NOTE(review): assumes load_wandb_model accepts any torch device string.
GPU_INDEX = 1
if torch.cuda.device_count() > GPU_INDEX:
    device = f'cuda:{GPU_INDEX}'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Load the trained model plus its run config, and wrap it for the benchmark.
model, config = load_wandb_model(wandb_project, wandb_run, device=device)
nb_model = StorkModel(model)

# State-update FLOPs are counted from the raw model, separately from the
# benchmark's synaptic-operation metrics, and combined in the report cell.
state_flops = count_state_flops(model)

In [None]:
# Test-split metrics table, wrapped as the evaluation dataset.
df_test = pd.read_pickle(project_config.METRICS_FILE_TEST)

# Built without `cache_fname`. To enable the on-disk cache, pass
#   cache_fname="%s/dns2020_sqp_test_cache.pkl.gz" % project_config.CACHE_DIR
# NOTE(review): cache behaviour is defined inside SQPDataset — confirm before enabling.
dataset_test = SQPDataset(df=df_test)

In [None]:
# Metric groups handed to the benchmark as a single list of lists.
static_metrics = ["parameter_count", "footprint", "connection_sparsity"]
workload_metrics = ["activation_sparsity", "synaptic_operations", "membrane_updates"]
metric_groups = [static_metrics, workload_metrics]

# One ordered (unshuffled) pass over the test split, using the batch size
# stored in the loaded run config.
dataloader = DataLoader(dataset=dataset_test, batch_size=config.batch_size, shuffle=False)

# The two empty lists are presumably the preprocessor / postprocessor slots of
# the NeuroBench Benchmark signature — TODO confirm against StorkBenchmark.
benchmark = StorkBenchmark(nb_model, dataloader, [], [], metric_groups)
results = benchmark.run()

In [None]:
# Pull the headline numbers out of the benchmark results.
params, act_sparsity = results['parameter_count'], results['activation_sparsity']
synops = results['synaptic_operations']
e_acs = synops['Effective_ACs']

# Report in three sections. eFLOPs is the state-update FLOP count plus the
# 'Effective_ACs' count.
# NOTE(review): the sum is only meaningful if both terms use the same unit
# (e.g. per sample) — confirm count_state_flops matches Effective_ACs.
print(f"\nParams: {params:.2e}\nAct. sparsity: {act_sparsity:.2f}")
print(f"\nState FLOPs: {state_flops:.3e}\neACs: {e_acs:.3e}")
print(f"\neFLOPs: {state_flops + e_acs:.3e}")