In [1]:
import snntorch as snn
import sys
import os
import platform
import torch
import pandas as pd
import numpy as np
import math
import pickle
import pprint as pp

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
from snntorch import surrogate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import itertools
import csv


# Detect which accelerator backends are usable in this environment.
has_gpu = torch.cuda.is_available()

# `torch.has_mps` only reports whether the build includes MPS support and is
# deprecated; ask the backend whether the device is actually usable. The
# getattr guard keeps this working on PyTorch builds without `backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
has_mps = bool(_mps_backend is not None and _mps_backend.is_available())

# Same preference order as before (MPS, then CUDA, then CPU). The CUDA device
# type string understood by torch is "cuda"; "gpu" is not a valid device.
device = "mps" if has_mps else "cuda" if has_gpu else "cpu"

# Report the environment so the notebook output documents how it was run.
print(f"Python Platform: {platform.platform()}")
print(f"PyTorch Version: {torch.__version__}")
print()
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print("GPU is", "available" if has_gpu else "NOT AVAILABLE")
print("MPS (Apple Metal) is", "AVAILABLE" if has_mps else "NOT AVAILABLE")
print(f"Target device is {device}")

Python Platform: Windows-10-10.0.22621-SP0
PyTorch Version: 2.0.1+cpu

Python 3.11.4 (tags/v3.11.4:d2340ef, Jun  7 2023, 05:45:37) [MSC v.1934 64 bit (AMD64)]
Pandas 2.0.3
GPU is NOT AVAILABLE
MPS (Apple Metal) is NOT AVAILABLE
Target device is cpu


# Define the Net78x78 (LIF network) and Netron Classes

In [2]:
class Net78x78(nn.Module):
    """Fully connected spiking network built from `snn.Leaky` neurons.

    Layer stack: an input_dim -> input_dim linear layer, a linear layer into
    the first hidden width, one linear layer per further entry of
    ``config['hid_layers']``, and a final linear layer into ``num_classes``.
    Every linear layer is followed by its own `snn.Leaky` neuron.

    Args:
        config: dict providing "sequence_length", "input_dim", "num_classes",
            "beta", "surrogate", "hid_layers", "out_threshold" and "dropout".
        device: stored on the instance as ``self.device``; it is not used to
            move any module or tensor inside this class.
    """

    def __init__(self, config, device):
        super().__init__()

        self.n_steps = config["sequence_length"]
        self.input_dim = config["input_dim"]
        self.num_classes = config["num_classes"]
        self.device = device

        # Surrogate-gradient functions selectable through config['surrogate'].
        self.spike_grads = {
            "fast_sigmoid": surrogate.fast_sigmoid(),
            "arctan": surrogate.atan(),
            "LSO": surrogate.LSO(),
        }
        grad_fn = self.spike_grads[config['surrogate']]
        decay = config['beta']
        widths = config['hid_layers']

        # Square input layer (input_dim -> input_dim); its weights keep the
        # default nn.Linear initialisation here.
        self.fc1 = nn.Linear(self.input_dim, self.input_dim)
        self.lif1 = snn.Leaky(beta=decay, spike_grad=grad_fn)  # learn beta to implement on Netron

        # Projection from the input width into the first hidden width.
        self.fc2 = nn.Linear(self.input_dim, widths[0])
        self.lif2 = snn.Leaky(beta=decay, spike_grad=grad_fn)

        # Remaining hidden layers: one Linear + Leaky pair per consecutive
        # pair of widths (none when only a single hidden width is given).
        self.lif_layers = nn.ModuleList()
        self.fc_layers = nn.ModuleList()
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            self.fc_layers.append(nn.Linear(n_in, n_out))
            self.lif_layers.append(snn.Leaky(beta=decay, spike_grad=grad_fn))

        # Output layer; its neuron uses a separately configured threshold.
        self.fc_final = nn.Linear(widths[-1], self.num_classes)
        self.lif_final = snn.Leaky(
            beta=decay,
            threshold=config['out_threshold'],
            spike_grad=grad_fn,
        )

        self.dropout = nn.Dropout(p=config["dropout"])

    def forward(self, x):
        """Run the network for ``self.n_steps`` time steps.

        Args:
            x: input indexed by time step on its first axis (``x[step]`` is
                fed to the first linear layer).

        Returns:
            Tuple ``(spikes, membranes)`` of the output layer, each stacked
            along a new leading time axis.
        """
        # Fresh membrane state for every neuron layer at t=0.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        hidden_mems = [neuron.init_leaky() for neuron in self.lif_layers]
        mem_out = self.lif_final.init_leaky()

        # Per-step recordings of the output layer only.
        out_spikes = []
        out_mems = []

        for t in range(self.n_steps):
            spk, mem1 = self.lif1(self.fc1(x[t]), mem1)
            spk, mem2 = self.lif2(self.fc2(spk), mem2)

            # Extra hidden layers: dropout acts on the synaptic current.
            for k, (linear, neuron) in enumerate(zip(self.fc_layers, self.lif_layers)):
                current = self.dropout(linear(spk))
                spk, hidden_mems[k] = neuron(current, hidden_mems[k])

            current = self.dropout(self.fc_final(spk))
            spk_out, mem_out = self.lif_final(current, mem_out)

            out_spikes.append(spk_out)
            out_mems.append(mem_out)

        return torch.stack(out_spikes, dim=0), torch.stack(out_mems, dim=0)

class Netron():
    """Inference wrapper: EDM spike encoding followed by the trained SNN.

    On construction it loads a model checkpoint and a pickled per-channel
    reference dict, then exposes the EDM stage, the SNN stage, and an
    end-to-end pass that chains both.

    Args:
        config: dict with 'model_path' (torch checkpoint containing 'config'
            and 'model_state') and 'EDM_thr_path' (pickle file whose values
            each contain an 'edm_thresholds' entry).
        device: passed to ``torch.load(map_location=...)`` and to the model
            constructor.
    """

    def __init__(self, config, device):
        self.model_path = config['model_path']
        self.edm_thr_path = config['EDM_thr_path']
        self.device = device
        self.load_model()
        self.load_EDM_thr()

    def load_model(self):
        """Load the checkpoint, rebuild ``Net78x78`` and put it in eval mode."""
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model_config = checkpoint['config']
        self.model = Net78x78(self.model_config, self.device)
        self.state_dict = checkpoint['model_state']
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated rather than raising.
        self.model.load_state_dict(self.state_dict, strict=False)
        self.model.eval()

    def load_EDM_thr(self):
        """Load the channel reference dict and extract per-channel thresholds."""
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only point 'EDM_thr_path' at files from a trusted source.
        with open(self.edm_thr_path, 'rb') as f:
            self.ch_ref = pickle.load(f)
        self.edm_thr = {k: v['edm_thresholds'] for k, v in self.ch_ref.items()}

    def forward_pass_e2e(self, input, *args, **kwargs):
        """Run EDM encoding then the SNN.

        Extra positional/keyword arguments are forwarded to
        ``forward_pass_EDM`` (e.g. ``drop_ch``).

        Returns:
            ``(spk_rec, mem_rec, edm_spiketrains)``.
        """
        edm_spiketrains = self.forward_pass_EDM(input, *args, **kwargs)
        spk_rec, mem_rec = self.forward_pass_SNN(edm_spiketrains)
        return spk_rec, mem_rec, edm_spiketrains

    def forward_pass_EDM(self, input, drop_ch=None):
        """Convert a raw recording into binary spike trains.

        For each channel, a spike (1) is emitted at a sample when the first
        EDM trace lies strictly between thresholds 0 and 1 AND the second EDM
        trace lies strictly between thresholds 2 and 3 of that channel.

        Args:
            input: array of shape (n_samples, n_channels). Column ``i`` is
                matched with the ``i``-th key of ``self.edm_thr`` (assumes the
                column order equals the key order of the reference dict --
                TODO confirm against the dataset).
            drop_ch: optional iterable of channel ids (keys of
                ``self.ch_ref``) whose columns are removed from the result.

        Returns:
            Float array of 0/1 values, shape (n_samples, n_channels) minus
            the dropped columns.
        """
        edm_output = self.edm_vec(input)  # (n_alpha, n_samples, n_channels)
        edm_output = np.moveaxis(edm_output, [2, 0], [0, 1])  # (n_channels, n_alpha, n_samples)
        edm_spiketrains = np.zeros(input.shape)

        for i, ch in enumerate(self.edm_thr):
            thr = self.edm_thr[ch]
            # Two window comparators, one per EDM trace (4 comparators total).
            condition1 = (thr[0] < edm_output[i, 0]) & (edm_output[i, 0] < thr[1])
            condition2 = (thr[2] < edm_output[i, 1]) & (edm_output[i, 1] < thr[3])
            st = condition1 & condition2  # final AND gate
            edm_spiketrains[st, i] = 1

        if drop_ch:
            # Translate channel ids to column indices using THIS instance's
            # reference dict (previously relied on a global named `netron`).
            drop_idx = [self.idx_translate(ch_id) for ch_id in drop_ch]
            edm_spiketrains = np.delete(edm_spiketrains, drop_idx, axis=1)

        return edm_spiketrains

    def forward_pass_SNN(self, input):
        """Feed EDM spike trains (output of ``forward_pass_EDM``) to the SNN.

        Args:
            input: numpy array of shape (n_samples, n_channels).

        Returns:
            ``(spk_rec, mem_rec)`` as returned by the model.
        """
        # Add a batch dimension and reorder to time-major:
        # (n_samples, n_channels) -> (n_samples, batch_size=1, n_channels)
        input = torch.from_numpy(input).unsqueeze(0).float().permute(1, 0, 2)
        # Call the module itself rather than .forward() so that any registered
        # hooks run as nn.Module intends.
        spk_rec, mem_rec = self.model(input)

        return spk_rec, mem_rec

    def edm_vec(self, x, alphas=[1, 3], init_edm=0.5):
        """Compute one recursively smoothed trace per alpha for every channel.

        Each trace follows ``edm <- edm - (1 / 2**a) * (edm - x[t])``, i.e. a
        first-order exponential moving average with rate ``1 / 2**a``.

        Args:
            x: array of shape (n_samples, n_channels).
            alphas: exponents ``a``; the smoothing rate used is ``1 / 2**a``.
            init_edm: initial value of every trace.

        Returns:
            Array of shape (n_alpha, n_samples, n_channels).
        """
        n_alpha = len(alphas)
        edm_output = np.zeros((n_alpha, x.shape[0], x.shape[1]))  # (n_alpha, n_samples, n_channels)
        edm = np.ones((n_alpha, x.shape[1])) * init_edm
        # Column vector so the rates broadcast across channels.
        rates = np.array([1 / 2**a for a in alphas]).reshape(-1, 1)

        for t in range(x.shape[0]):
            edm = edm - rates * (edm - x[t, :].reshape(1, -1))
            edm_output[:, t, :] = edm

        return edm_output

    def idx_translate(self, ch_id):
        """Return the positional index of ``ch_id`` among ``self.ch_ref`` keys.

        Raises:
            ValueError: if ``ch_id`` is not a key of ``self.ch_ref``.
        """
        return list(self.ch_ref.keys()).index(ch_id)
    

       


# Load Dataset

In [3]:
# Root folder of the "netron inference" bundle. Set the NETRON_INFERENCE_DIR
# environment variable to run this notebook on another machine; the default
# is the original author's local path.
NETRON_DIR = os.environ.get(
    'NETRON_INFERENCE_DIR',
    'C:/Users/mikae/Downloads/netron inference-20230812T013749Z-001/netron inference',
)

# Main dataset archive: arrays 'X', 'Y' and 'markers'.
loaded = np.load(f'{NETRON_DIR}/dataset/ap_dset.npz')
X, Y, markers = loaded['X'], loaded['Y'], loaded['markers']

# Held-out test archive: arrays 'X' and 'Y'.
loaded_test = np.load(f'{NETRON_DIR}/dataset/test_set_ap.npz')
X_test, Y_test = loaded_test['X'], loaded_test['Y']

# Cast markers to integers.
markers = markers.astype(int)


In [4]:
# Root folder of the "netron inference" bundle. Set the NETRON_INFERENCE_DIR
# environment variable to run this notebook on another machine; the default
# is the original author's local path.
NETRON_DIR = os.environ.get(
    'NETRON_INFERENCE_DIR',
    'C:/Users/mikae/Downloads/netron inference-20230812T013749Z-001/netron inference',
)

# Locations of the trained model checkpoint and the pickled EDM thresholds.
netron_config = {
    'model_path': f'{NETRON_DIR}/models/final_ngc536dn.pth',
    'EDM_thr_path': f'{NETRON_DIR}/models/EDMNet_liberal_thresholds_edm1driven.pkl',
}

# Building the wrapper loads both files immediately.
netron = Netron(netron_config, device)


In [5]:
# Show the training configuration stored inside the loaded checkpoint.
print(pp.pformat(netron.model_config))


{'X_shape': (142, 3000, 73),
 'aug_dynamic': None,
 'aug_prob': 0.4,
 'aug_strategy': 'single',
 'aug_trials': 1000,
 'augmented': None,
 'batch_size': 8,
 'beta': 0.9,
 'classes': {'PG': 0, 'SG': 1},
 'dropout': 0.0,
 'dset': 'edm1driven_dropped',
 'enable_wandb': True,
 'grad_clip': None,
 'grad_mask': True,
 'hid_layers': [256],
 'input_dim': 73,
 'loss': 'CECountLoss',
 'lr': 0.0005,
 'lr_scheduler': 'none',
 'n_electrodes': 73,
 'name': 'edm1driven_2152',
 'num_classes': 2,
 'num_epochs': 200,
 'optimizer': 'Adam',
 'out_threshold': 1.0,
 'overfit_minibatch': False,
 'project': 'lif_params',
 'save_model_threshold': 0.0,
 'sequence_length': 3000,
 'surrogate': 'arctan',
 'test_dataset_size': 15,
 'test_minibatch_size': 2,
 'test_split': 0.1,
 'train_dataset_size': 127,
 'train_minibatch_size': 16,
 'unique': True,
 'weight_decay': 0,
 'weight_init': 'even_78x78',
 'workers': 0}


# Calculate EDM

In [6]:
# Pick one trial from the dataset and compute its EDM traces.
trial = 10
sample_X, sample_Y = X[trial], Y[trial]
print(sample_X.shape)

# edm_vec returns one trace per alpha: (n_alpha, n_samples, n_channels).
edm_X = netron.edm_vec(sample_X)
print(edm_X.shape)

(3000, 78)
(2, 3000, 78)


In [7]:
# Channel ids to remove from the EDM spike trains before they reach the SNN.
drop_ch = [11, 33, 26, 58, 18]

edm_spiketrain_X = netron.forward_pass_EDM(sample_X, drop_ch=drop_ch)

# Report the resulting shape and the total number of spikes.
print(edm_spiketrain_X.shape, np.sum(edm_spiketrain_X), sep="\n")


(3000, 73)
324.0


In [11]:
# Quick look at the (abbreviated) spike-train matrix.
print(f"{edm_spiketrain_X}")

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


# Test End-to-End Forward Pass

In [8]:
# Full pipeline: raw sample -> EDM spike trains -> SNN output recordings.
spk_rec, mem_rec, edm_spiketrains = netron.forward_pass_e2e(
    sample_X,
    drop_ch=drop_ch,
)

# Report the shape of each returned array.
for _label, _array in (
    ("Spike recording: ", spk_rec),
    ("Membrane potential recording: ", mem_rec),
    ("EDM spike trains: ", edm_spiketrains),
):
    print(_label, _array.shape)

Spike recording:  torch.Size([3000, 1, 2])
Membrane potential recording:  torch.Size([3000, 1, 2])
EDM spike trains:  (3000, 73)


# Arduino > PYNQ Pipeline

## Export EDM Spike Train

In [12]:
# Destination file for the exported EDM spike trains.
csv_filename = "edm_spiketrains_data.csv"

# Wrap the spike-train array in a DataFrame and write it without the index column.
df_edm_spiketrains = pd.DataFrame(edm_spiketrains)
df_edm_spiketrains.to_csv(
    csv_filename,
    index=False,
)
print(f"Data saved to {csv_filename}")


Data saved to edm_spiketrains_data.csv
