In [None]:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(usecwd=True), override=False)

import torch
import hydra.utils

from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

from neurobench.models import TorchModel
from neurobench.benchmarks import Benchmark
from neurobench.metrics.static import ParameterCount, Footprint, ConnectionSparsity
from neurobench.metrics.workload import ActivationSparsity, SynapticOperations

from omegaconf import OmegaConf

from torch.utils.data import DataLoader, Subset

from sqp_ann.utils import register_resolvers, pretty_configs, DnsmosPreProcessor

In [None]:
# Use the model variant with externalized preprocessing for NeuroBench
# compatibility: feature extraction runs in DnsmosPreProcessor (next cell)
# rather than inside the network.
model_type = 'dnsmos_no_preproc'

# Add your model path here. Set to None (or "") to skip loading weights.
checkpoint_path = "/path/to/pytorch_model.bin"

# Clear any Hydra instance left over from a previous run of this cell so that
# `initialize` can be called again (keeps the cell re-runnable).
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config/")

register_resolvers()

config = compose(config_name='train.yaml',
                 overrides=[f'model={model_type}'])

print("Model configs:")
print(OmegaConf.to_yaml(config.model))

model = hydra.utils.instantiate(config.model)
dataset_test = hydra.utils.instantiate(config.dataset_test)

if checkpoint_path:
    # Load onto CPU first so a checkpoint saved from a GPU also loads on a
    # CPU-only machine; the model is moved to `device` below.
    # weights_only=True refuses arbitrary pickled objects when unpickling.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # Drop parameters under `preproc.`. Presumably the checkpoint comes from the
    # variant with built-in preprocessing, which this variant lacks (TODO confirm).
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith('preproc.')}
    model.load_state_dict(state_dict)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Put the network in inference mode (affects e.g. dropout / batch-norm layers
# if present) before benchmarking.
model = model.to(device).eval()

In [None]:
# Wrap the PyTorch module for NeuroBench and run feature extraction as a
# NeuroBench preprocessor on the same device as the model.
nb_model = TorchModel(model)
preprocessors = [DnsmosPreProcessor(device=device)]

# Metrics from neurobench.metrics.static (model-only) and
# neurobench.metrics.workload (accumulated over the test data).
static_metrics = [ParameterCount, Footprint, ConnectionSparsity]
workload_metrics = [ActivationSparsity, SynapticOperations]

# One clip per batch, fixed order, so the run is deterministic.
data_loader = DataLoader(dataset_test, batch_size=1, shuffle=False)

# Model outputs are used as-is: no postprocessing step.
postprocessors = []

benchmark = Benchmark(
    nb_model,
    data_loader,
    preprocessors,
    postprocessors,
    [static_metrics, workload_metrics],
)

results = benchmark.run()
results

In [None]:
# Extract the headline numbers from the NeuroBench results dict.
params = results['ParameterCount']
footprint = results['Footprint']
act_sparsity = results['ActivationSparsity']
e_macs = results['SynapticOperations']['Effective_MACs']

# One multiply-accumulate counts as two floating-point operations
# (a multiply and an add).
FLOPS_PER_MAC = 2
e_flops = e_macs * FLOPS_PER_MAC

print(f"\nParams: {params:.2e}")
print(f"Footprint: {footprint:.2e} bytes")
print(f"Activation sparsity: {act_sparsity:.3f}")
print(f"Effective FLOPs: {e_flops:.3e}")