In [1]:
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils
from snntorch import surrogate
import torch.nn.functional as F
from snntorch import functional as SF
import brevitas.nn as qnn 
from tqdm import tqdm
from pathlib import Path

import sys
sys.path.append('../src')
from my_network import *
from dataloader import WisdmDatasetParser, WisdmDataset
from torch.utils.data import  DataLoader
from assistant import Assistant
from stats import LearningStats
from utils import *

In [2]:
#path = f"{Path.home()}/snntorch_network/nni_experiments/inibitory_lif_no_encoder_best/results/ot6eqima/trials/oqxoS/Trained/network_best.npz"
#path = f"{Path.home()}/snntorch_network/nni_experiments/inibitory_lif_no_encoder_worst/results/ipk2erm5/trials/whXvC/Trained/network_best.npz"
path = f"{Path.home()}/snntorch_network/nni_experiments/inibitory_lif_no_encoder_balanced/results/4m8j0yfa/trials/yAnF7/Trained/network_best.npz"
#path = f"{Path.home()}/snntorch_network/nni_experiments/inibitory_lif_no_encoder_balanced_no_net_loss/results/uhrgfo1n/trials/TbZxT/Trained/network_best.npz" # best accuracy
#path = f"{Path.home()}/snntorch_network/nni_experiments/inibitory_lif_no_encoder_balanced_no_net_loss/results/uhrgfo1n/trials/bnh2o/Trained/network_best.npz" #same HP as the one with net_loss

name = "balanced_with_loss"
data = np.load(path,allow_pickle=True)

linear1_w= data['linear1']
leaky1_vth= data['leaky1_vth']
leaky1_betas= 1-data['leaky1_betas'] 
leaky1_betas= leaky1_betas if leaky1_betas >= 0 else np.zeros(leaky1_betas.shape)
print(f"leaky1_betas: {leaky1_betas}")
print(f"leaky1_vth: {leaky1_vth}")
linear2_w = data['linear2']
leaky2_vth= data['recurrent_vth']
leaky2_betas= 1 - data['recurrent_betas']
leaky2_betas= leaky2_betas if  leaky2_betas >= 0 else np.zeros(leaky2_betas.shape)
print(f"leaky2_betas: {leaky2_betas}")
print(f"leaky2_vth: {leaky2_vth}")

recurrent_in_weights = data['input_dense']
recurrent_out_weights = - data['output_dense']
recurrent_vth = data['activation_vth']
recurrent_leaky_betas = 1 - data['activation_betas']
recurrent_leaky_betas= recurrent_leaky_betas if recurrent_leaky_betas >= 0 else np.zeros(recurrent_leaky_betas.shape)
print(f"recurrent_leaky_betas: {recurrent_leaky_betas}")
print(f"recurrent_vth: {recurrent_vth}")

linear3_w = data['linear3']
leaky3_vth= data['leaky2_vth']
leaky3_betas= 1 - data['leaky2_betas']
leaky3_betas= leaky3_betas if leaky3_betas >= 0 else np.zeros(leaky3_betas.shape)
print(f"leaky3_betas: {leaky3_betas}")
print(f"leaky3_vth: {leaky3_vth}")

# Given a list of numpy matrices, count the zeros inside
def count_zeros(matrix_list):
    return [np.count_nonzero(matrix == 0) for matrix in matrix_list]

# Example usage
matrix_list = [linear1_w, linear2_w, linear3_w, recurrent_in_weights, recurrent_out_weights]
zero_counts = count_zeros(matrix_list)
print(f"Zero counts in matrices: {zero_counts}")


leaky1_betas: 1.01234897878021
leaky1_vth: 2.195924758911133
leaky2_betas: 0.07083219289779663
leaky2_vth: 0.8788766860961914
recurrent_leaky_betas: 0.11015737056732178
recurrent_vth: 1.8151839971542358
leaky3_betas: 0.0
leaky3_vth: 0.830508828163147
Zero counts in matrices: [0, 0, 0, 0, 0]


In [3]:
def count_parameters(matrix_list):
    return sum(matrix.size for matrix in matrix_list)

# Example usage
matrix_list = [linear1_w, linear2_w, linear3_w, recurrent_in_weights, recurrent_out_weights]
total_parameters = count_parameters(matrix_list)
print(f"Total number of parameters: {total_parameters}")
print(f"model footprint int 8: {total_parameters/1024} KB")
print(f"model footprint fp 16: {total_parameters*4/1024} KB")

Total number of parameters: 604700
model footprint int 8: 590.52734375 KB
model footprint fp 16: 2362.109375 KB


In [4]:

#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
device = 'cuda'
print(f'Using device {device}')

Using device cuda


In [5]:
batch_size = 256
slope = 10
# network parameters
num_inputs = 6 
num_steps = 40
net_hidden_1 = 180
net_hidden_2 = 400
net_hidden_3 = 128
num_outputs = 7
pop_outputs = num_outputs * 10
num_epochs = 200
vth_in = 1.0
vth_out = 1.0
vth_recurrent = 1.0
vth_enc_value =  1.0
vth_std =  65 
beta_in = 0.5
beta_recurrent = 0.5
beta_back = 0.6
beta_out = 0.5
encoder_dim = 25
beta_std = 55
lr = 0.002
drop_recurrent =0.15
drop_back = 0.15
drop_out = 0.15
# spiking neuron parameters
beta = 0.8  # neuron decay rate


In [6]:
DATASET_NAME = 'data_watch_subset_0_40.npz'
DATASET_SUBSET = 'custom'
PATIENCE = 12
TRAIN_FOLDER_NAME = 'Trained'
NUM_WORKERS = 8
NET_OUTPUT_DIM = 7
NET_INPUT_DIM = 6
NUM_EPOCHS = 200
SEARCH_SPACE_SHUFFLE = 200

In [7]:
SUBSET_LIST = [0, 1, 4, 8, 9, 10, 14]
trained_folder = TRAIN_FOLDER_NAME
os.makedirs(trained_folder, exist_ok=True)
dataset = WisdmDatasetParser(f'{Path.home()}/snntorch_network/data/{DATASET_NAME}', norm=None, class_sublset=DATASET_SUBSET, subset_list=SUBSET_LIST)
val_set = dataset.get_validation_set(shuffle=False, subset=None)
print(f"val dataset shape: {val_set[0].shape}")
# val_set = (np.transpose(val_set[0], (0, 2, 1)), one_hot_encode(val_set[1],7))
val_dataset = WisdmDataset(val_set)
# print(f"val dataset shape: {val_set[0].shape}")
# print(f"val dataset shape: {val_set[1].shape}")
val_loader  = DataLoader(dataset= val_dataset, batch_size=int(batch_size), shuffle=True, num_workers=NUM_WORKERS, drop_last=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# device = 'cpu'
print(f'Using device {device}')

grad = surrogate.fast_sigmoid(slope) #use slope for HPO
stats = LearningStats()

net_loss = regularization_loss(0.1, 0.03, 40)

net = ExInbitoryNetwork(NET_INPUT_DIM, 200,500, NET_OUTPUT_DIM, grad,
                    vth_in=vth_in, vth_recurrent=vth_recurrent, vth_out=vth_out, vth_back=1.0,
                    beta_in=beta_in, beta_recurrent=beta_recurrent, beta_back=beta_back, beta_out=beta_out,
                    # encoder_dim=int(encoder_dim),
                    # vth_enc_value=vth_enc_value, vth_std=vth_std, beta_std=beta_std,
                    drop_recurrent=drop_recurrent, drop_back=drop_back, drop_out=drop_out,
                    time_dim=2).to(device)

net.from_npz(path)

net.to(device)
optimizer = torch.optim.Adam(net.parameters(), 0.01, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
        T_max=4690, 
        eta_min=0, 
        last_epoch=-1
    )


loss_fn = SF.loss.ce_count_loss()

assistant = Assistant(net, loss_fn, optimizer, stats, classifier=True, scheduler=scheduler)



(6,)
(6,)
ytrain shape (55404, 18)
yval shape (18468, 18)
ytest shape (18469, 18)
num classes train dataset: 7 occurrences of each class:[3127 3066 3044 3047 3150 3087 2973]
num classes eval dataset: 7 occurrences of each class:[1035  968 1048  996 1110 1053 1007]
num classes test dataset: 7 occurrences of each class:[1046 1061 1048 1036 1076 1026  982]
val dataset shape: (7217, 6, 40)
Using device cuda
type of self.linear2 is <class 'brevitas.nn.quant_linear.QuantLinear'>


  warn('Keyword arguments are being passed but they not being used.')


In [8]:

tqdm_dataloader = tqdm(val_loader)
for _, batch in enumerate(tqdm_dataloader): #eval loop
        input, label = batch
        output = assistant.valid(input, label)
        tqdm_dataloader.set_description(f'\r Validation: {stats.validation}')
    

        stats.update()

        torch.cuda.empty_cache()

  return super().rename(names)
  self.mem = torch.zeros_like(input_, device=self.mem.device)
 Validation: loss =     0.27931 (min =     0.04009)     accuracy = 0.93878 (max = 0.99219) : 100%|██████████| 29/29 [00:06<00:00,  4.20it/s]


In [9]:
# from neurobench.models import SNNTorchModel
# from neurobench.postprocessing.postprocessor import aggregate, choose_max_count
# from neurobench.benchmarks import Benchmark
# import neurobench.benchmarks.static_metrics

# torch.cuda.empty_cache()
# model = SNNTorchModel(net)
# postprocessors = [choose_max_count]
# for param in model.__net__().parameters():
#     print(param)

In [10]:

# static_metrics = ["footprint","parameter_count"]
# workload_metrics = ["activation_sparsity", "membrane_updates", "synaptic_operations"]

# benchmark = Benchmark(model, val_loader, [], postprocessors, [static_metrics, workload_metrics])
# results = benchmark.run(device=device)
# print(results)



In [11]:
# import json

# # Define the path to save the JSON file
# results_path = f'{name}.json'

# # Save the results dictionary to a JSON file
# with open(results_path, 'w') as json_file:
#     json.dump(results, json_file)

# print(f"Results saved to {results_path}")

In [12]:
import json
from tabulate import tabulate

# Load the JSON files
with open('balanced_with_loss.json', 'r') as f1, open('balanced_no_net_loss_big.json', 'r') as f2, open('balanced_no_net_loss_small.json', 'r') as f3:
    data1 = json.load(f1)
    data2 = json.load(f2)
    data3 = json.load(f3)

# Prepare data for tabulate
table_data = []
for key in data1.keys():
    value1 = data1.get(key, 'N/A')
    value2 = data2.get(key, 'N/A')
    value3 = data3.get(key, 'N/A')
    table_data.append([key, value1, value2, value3])

# Print the table
print(tabulate(table_data, headers=["Metric", "original", "no_loss_optimal", "original_no_loss"], tablefmt="grid"))


+---------------------+-------------------------------------------------------------+--------------------+-------------------------------------------------------------+
| Metric              | original                                                    | no_loss_optimal    | original_no_loss                                            |
| footprint           | 2419012                                                     | 2419012            | 630092                                                      |
+---------------------+-------------------------------------------------------------+--------------------+-------------------------------------------------------------+
| parameter_count     | 604710                                                      | 604710             | 157480                                                      |
+---------------------+-------------------------------------------------------------+--------------------+-------------------------------------------------