## Imports

In [214]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import numpy as np
import os, sys
import torch
import random

from torch.utils.data import TensorDataset, random_split
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from utils import get_shd_dataset
import h5py

In [215]:
## Load in the N-MNIST dataset
# NOTE(review): `tonic` is imported here rather than in the import cell at the
# top of the notebook; consider moving it there so dependencies are visible up front.
import tonic

# Train and test splits are requested with `save_to='./data'`; no sample or
# target transform is applied at load time, so samples come back in whatever
# raw form tonic provides.
train_dataset = tonic.datasets.NMNIST(save_to='./data', train=True, transform=None, target_transform=None)
test_dataset = tonic.datasets.NMNIST(save_to='./data', train=False, transform=None, target_transform=None)

## Universal Model Settings

We define these settings here because some of the random-manifold components depend on them; setting them first ensures those components are built with the right structure.

In [216]:
# The coarse network structure and the time steps are dictated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 20

# NOTE(review): the data generator bins spikes with np.linspace(0, max_time, nb_steps),
# i.e. roughly max_time / nb_steps = 14 ms per step, while `time_step` (1 ms) is what
# the decay constants alpha/beta are derived from — confirm this mismatch is intended.
time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 256

# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Ask PyTorch to use deterministic kernels.
# NOTE(review): this may make operations without a deterministic implementation
# raise at runtime — confirm every op used below is covered on the target device.
torch.use_deterministic_algorithms(True)
# NOTE(review): `torch.backends.mps` does not appear to document `benchmark` /
# `deterministic` switches; these two assignments may have no effect — confirm.
torch.backends.mps.benchmark = False
torch.backends.mps.deterministic = True


In [217]:
# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Floating-point type used for every weight and state tensor below.
dtype = torch.float

In [218]:
# Membrane and synaptic time constants (same unit as `time_step`).
tau_mem = 10e-3
tau_syn = 5e-3

# Per-step exponential decay factors for the synaptic current (alpha) and the
# membrane potential (beta), derived from the time constants above.
alpha   = float(np.exp(-time_step/tau_syn))
beta    = float(np.exp(-time_step/tau_mem))

weight_scale = 7*(1.0-beta) # this should give us some spikes to begin with

# Trainable weights: input->hidden (w1), hidden->output (w2) and hidden->hidden
# recurrent (v1). torch.empty leaves them uninitialised; they are filled by
# torch.nn.init.normal_ inside train_and_print_results before any training.
# NOTE(review): w1's shape is fixed by the value of nb_inputs at this point (700);
# a later cell reassigns nb_inputs without re-creating w1.
w1 = torch.empty((nb_inputs, nb_hidden),  device=device, dtype=dtype, requires_grad=True)
w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
v1 = torch.empty((nb_hidden, nb_hidden), device=device, dtype=dtype, requires_grad=True)

## Import SHD Data

In [219]:
# Here we load the Dataset
# The HDF5 files are expected under ~/data/hdspikes; get_shd_dataset is called
# first with that location (presumably it downloads the files when they are
# missing — verify against utils.get_shd_dataset).
cache_dir = os.path.expanduser("~/data")
cache_subdir = "hdspikes"
get_shd_dataset(cache_dir, cache_subdir)

# NOTE(review): both files are opened read-only and never closed; the handles
# stay open for the lifetime of the kernel because x_*/y_* below are lazy
# references into them.
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

# 'spikes' is indexed later with ['times'] and ['units']; 'labels' holds the targets.
x_train = train_file['spikes']
y_train = train_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

Available at: /Users/naliniramanathan/data/hdspikes/shd_train.h5
Available at: /Users/naliniramanathan/data/hdspikes/shd_test.h5


In [220]:
import numpy as np
import torch

def non_sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, device=device, shuffle=True):
    """Yield dense (batch, time, unit) spike tensors and labels from an SHD-style spike group.

    Args:
        X: Mapping with 'times' and 'units' entries; indexing either with a
            sample index gives that sample's firing times / firing unit ids.
        y: The labels, one per sample.
        batch_size: Number of samples per yielded batch.
        nb_steps: Number of time bins in the dense tensor.
        nb_units: Number of input units in the dense tensor.
        max_time: Upper edge of the time axis used to build the bins.
        device: Torch device the yielded tensors are created on.
        shuffle: If True, the sample order is shuffled once before batching.

    Yields:
        ``(X_batch, y_batch)`` where ``X_batch`` is a float32 tensor of shape
        ``(batch_size, nb_steps, nb_units)`` holding 1.0 where a spike occurred
        and ``y_batch`` holds the matching int64 labels.

    Only full batches are produced; a trailing partial batch is skipped.
    """
    labels = np.array(y, dtype=np.int64)
    n_batches = len(labels) // batch_size
    order = np.arange(len(labels))

    spike_times = X['times']
    spike_units = X['units']

    bin_edges = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(order)

    for batch_no in range(n_batches):
        chosen = order[batch_no * batch_size:(batch_no + 1) * batch_size]

        dense = torch.zeros((batch_size, nb_steps, nb_units), dtype=torch.float32, device=device)
        targets = torch.tensor(labels[chosen], device=device)

        for row, sample in enumerate(chosen):
            # NOTE(review): np.digitize returns 1-based bin indices for values at or
            # above the first edge, so time step 0 is presumably never populated.
            step_idx = np.digitize(spike_times[sample], bin_edges)
            unit_idx = spike_units[sample]

            # Discard events whose bin index falls outside the dense time axis.
            keep = (step_idx >= 0) & (step_idx < nb_steps)

            dense[row, step_idx[keep], unit_idx[keep]] = 1.0

        yield dense, targets

In [221]:
def get_representative_subset_indices(data, limit, num_inputs = 10):
    """
    Draw an equal number of random indices from each of ``num_inputs``
    contiguous, equally sized index ranges of ``data``.

    The index space ``[0, len(data))`` is cut into ``num_inputs`` consecutive
    chunks of ``len(data) // num_inputs`` indices, and ``limit // num_inputs``
    distinct indices are sampled (via ``np.random.choice``) from each chunk.

    NOTE(review): this only yields a class-balanced subset if ``data`` is
    ordered by class and every class has the same number of samples — confirm
    for the dataset in use.

    Args:
        data: Any sized collection; only ``len(data)`` is used.
        limit: Total number of indices wanted (rounded down to a multiple of
            ``num_inputs``).
        num_inputs: Number of contiguous chunks (intended: number of classes).

    Returns:
        torch.Tensor: The selected indices, chunk by chunk.
    """
    chunk_size = len(data) // num_inputs
    quota = limit // num_inputs

    picked = []
    for chunk in range(num_inputs):
        first = chunk * chunk_size
        candidates = np.arange(first, first + chunk_size)
        picked += list(np.random.choice(candidates, quota, replace=False))

    return torch.tensor(picked)

In [222]:
from torch.utils.data import DataLoader, Subset, TensorDataset

# Number of samples to keep from each split.
train_limit = 200
test_limit = 200
# train_dataset = test_dataset
train_subset_indices = get_representative_subset_indices(train_dataset, train_limit)
test_subset_indices = get_representative_subset_indices(test_dataset, test_limit)

train_subset = torch.utils.data.Subset(train_dataset, train_subset_indices)
test_subset = torch.utils.data.Subset(test_dataset, test_subset_indices)

# NOTE(review): `preprocess_spike_events` is not defined anywhere in the visible
# notebook, so this cell fails with NameError on a fresh kernel — it needs to be
# defined or imported before this point.
train_x_data = torch.stack([preprocess_spike_events(sample, nb_steps, nb_units=nb_inputs) for sample, _ in train_subset])
train_y_tensor = torch.tensor([label for _, label in train_subset], dtype=torch.int64)  # Shape: (num_samples,)

train_data = TensorDataset(train_x_data, train_y_tensor)

test_x_data = torch.stack([preprocess_spike_events(sample, nb_steps, nb_units=nb_inputs) for sample, _ in test_subset])
test_y_tensor = torch.tensor([label for _, label in test_subset], dtype=torch.int64)  # Shape: (num_samples,)

test_data = TensorDataset(test_x_data, test_y_tensor)


# NOTE(review): with 200 samples per split, batch_size = 256 and drop_last=True,
# both loaders would yield zero batches — confirm the intended limits/batch size.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    # collate_fn=custom_collate_fn,  # Use the custom collate function
    # num_workers=4,
    pin_memory=True,
    drop_last=True
)

test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=True,
    # collate_fn=custom_collate_fn,  # Use the custom collate function
    # num_workers=4,
    pin_memory=True,
    drop_last=True
)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# nb_steps = train_dataset[0][0].shape[0] # Get the number of time steps
# NOTE(review): nb_inputs is reassigned here, after it was already used above
# (with its previous value) and after w1 was allocated with the previous value.
# The SHD training code further down also reads nb_inputs — confirm which value
# each consumer is meant to see.
nb_inputs = 34 * 34

NameError: name 'preprocess_spike_events' is not defined

## Create a Baseline SNN Hybrid Model

### Model Implementation - No Recurrent Weights

The input is a set of spikes. At each timestep, certain SNN neurons fire in response to those inputs. The ANNs act as recurrent units within the network. We'll exclude the recurrent weights to simplify things to start.

#### SNN Only

In [None]:
def spike_fn(x):
    """Heaviside step: 1.0 where ``x`` is strictly positive, 0.0 elsewhere (same dtype as ``x``)."""
    return (x > 0).to(x.dtype)

def SNN(inputs, timesteps=nb_steps):
    """Feed-forward spiking network: one LIF hidden layer and a leaky-integrator readout.

    Uses the module-level weights ``w1`` / ``w2`` and decay factors ``alpha`` / ``beta``.

    Args:
        inputs: Tensor indexed as (batch, time, input unit); its last axis is
            contracted with ``w1``.
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec])``. ``out_rec`` stacks the readout state
        along dim 1 (``timesteps + 1`` entries, starting with the zero initial
        state); ``mem_rec`` and ``spk_rec`` stack the hidden membrane potential
        and spikes along dim 1 (``timesteps`` entries).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch so partial batches work, instead of
    # assuming every batch has exactly the global `batch_size` rows.
    nb_batch = inputs.shape[0]

    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []

    # Compute hidden layer activity
    for t in range(timesteps):
        mthr = mem - 1.0  # distance to the firing threshold of 1.0
        out = spike_fn(mthr)
        rst = out.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1[:, t]
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)

    # Readout layer: non-spiking leaky integrator driven by the hidden spikes
    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec]
    return out_rec, other_recs

In [None]:
def plot_voltage_traces(mem, spk=None, dim=(3,5), spike_height=5):
    """Plot the first ``prod(dim)`` traces of ``mem`` on a grid of axis-free subplots.

    Args:
        mem: Tensor whose first axis indexes the traces to plot.
        spk: Optional tensor of the same shape; wherever it is positive the
            plotted value is replaced by ``spike_height`` to mark a spike.
        dim: ``(rows, cols)`` of the subplot grid.
        spike_height: Value drawn at spike positions.
    """
    grid = GridSpec(*dim)

    if spk is None:
        traces = mem.detach().cpu().numpy()
    else:
        marked = 1.0 * mem
        marked[spk > 0.0] = spike_height
        traces = marked.detach().cpu().numpy()

    first_ax = None
    for panel in range(np.prod(dim)):
        if first_ax is None:
            # The first panel defines the y-axis every later panel shares.
            ax = first_ax = plt.subplot(grid[panel])
        else:
            ax = plt.subplot(grid[panel], sharey=first_ax)
        ax.plot(traces[panel])
        ax.axis("off")

In [None]:
def classification_accuracy(timesteps=nb_steps, model = SNN):
    """Compute and print the classification accuracy of ``model`` on one batch.

    NOTE(review): the inputs and labels are read from the module-level names
    ``x_data`` and ``y_data``, which must already be tensors the model accepts;
    no cell in view defines them, so confirm they exist before calling.

    Args:
        timesteps: Number of simulation steps forwarded to ``model``.
        model: Callable ``model(inputs, timesteps=...)`` returning
            ``(output, other_recs)``.

    Returns:
        The fraction of samples whose predicted class matches ``y_data``.
    """
    # Forward `timesteps` so the argument actually controls the simulation length.
    output, _ = model(x_data, timesteps=timesteps)
    m, _ = torch.max(output, 1)  # max over time
    _, am = torch.max(m, 1)      # argmax over output units
    am = am.detach().cpu().numpy()  # convert to numpy
    # Compute accuracy
    acc = np.mean((y_data.detach().cpu().numpy() == am))
    print("Accuracy %.3f"%acc)
    return acc

##### SNN Regularization

##### SNN Recurrent

In [None]:
def SNN_recurrent(inputs, timesteps=nb_steps):
    """Spiking network with a recurrently connected LIF hidden layer and leaky readout.

    Uses the module-level weights ``w1`` (input), ``v1`` (recurrent) and ``w2``
    (readout) and the decay factors ``alpha`` / ``beta``.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec])`` with the readout trace
        (``timesteps + 1`` entries along dim 1) and the hidden membrane
        potentials and spikes (``timesteps`` entries along dim 1).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch so partial batches work, instead of
    # assuming every batch has exactly the global `batch_size` rows.
    nb_batch = inputs.shape[0]

    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem_rec = []
    spk_rec = []
    # Spikes of the previous step; they feed back through v1.
    out = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))

    # Compute hidden layer activity
    for t in range(timesteps):
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - 1.0
        out = spike_fn(mthr)
        rst = out.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec]
    return out_rec, other_recs

#### ANN Only

Let's model an RNN with a similar structure to our SNN. It will take in our temporal data (spikes), and then transform it accordingly in each layer using the weights.

Since we take spikes, weight them, and then output them using traditional tanh, we don't really need to worry about conversion.

We'd usually have recurrent weights for the RNN, but since we excluded them for the SNN we'll do the same here.


In [None]:
def ANN_non_recurrent_with_LIF_output(inputs, timesteps=nb_steps):
    """Analog (tanh) hidden layer applied per time step, followed by the leaky readout.

    Uses the module-level weights ``w1`` / ``w2`` and decay factors ``alpha`` / ``beta``.
    There is no recurrence and no bias term.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of time steps processed by both the hidden layer and
            the readout; ``inputs`` must provide at least this many slices.

    Returns:
        ``(out_rec, interim_rec)``: the readout trace (``timesteps + 1`` entries
        along dim 1) and the hidden tanh activations (``timesteps`` entries
        along dim 1). Note the second element is a tensor, not a list.
    """
    inputs = inputs.to(device)
    # Size the readout state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weights for the hidden layer for RNN is just w1 -- multiplying by inputs just gives us the output at each timestep
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # We could also add a bias term but let's exclude for now to make the merging more simple

    # The difference for a pure ANN is that the hidden and the output layer usually happen at the same time
    # It's really fine that they're different here though if we just compute them sequentially
    # Use a tanh function similar to what's done in other models
    interim_rec = [torch.tanh(h1[:, t]) for t in range(timesteps)]
    interim_rec = torch.stack(interim_rec, dim=1)
    # We will use the same LIF model as before, and just use the raw output as an input to start since it has a window of 1

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (interim_rec, w2))
    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]

    # Iterate over `timesteps` (not the global nb_steps): h2 only has `timesteps`
    # slices, so a shorter run would otherwise index past its end.
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = interim_rec
    return out_rec, other_recs

##### With Recurrent Network

In [None]:
def RNN_with_LIF_output(inputs, timesteps=nb_steps):
    """Recurrent tanh hidden layer followed by the leaky readout.

    Uses the module-level weights ``w1`` (input), ``v1`` (recurrent) and ``w2``
    (readout) and the decay factors ``alpha`` / ``beta``. No bias term is used.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of time steps processed by both the hidden layer and
            the readout; ``inputs`` must provide at least this many slices.

    Returns:
        ``(out_rec, interim_rec)``: the readout trace (``timesteps + 1`` entries
        along dim 1) and the hidden tanh activations (``timesteps`` entries
        along dim 1). Note the second element is a tensor, not a list.
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weights for the hidden layer for RNN is just w1 -- multiplying by inputs just gives us the output at each timestep
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # We could also add a bias term but let's exclude for now to make the merging more simple

    interim_rec = []

    # The difference for a pure ANN is that the hidden and the output layer usually happen at the same time
    # It's really fine that they're different here though if we just compute them sequentially
    out = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    for t in range(timesteps):
        # Input drive plus the previous hidden state fed back through v1
        layer_weights = h1[:, t] + torch.einsum("ab,bc->ac", (out, v1))

        # Use a tanh function similar to what's done in other models
        out = torch.tanh(layer_weights)
        interim_rec.append(out)

    interim_rec = torch.stack(interim_rec, dim=1)
    # We will use the same LIF model as before, and just use the raw output as an input to start since it has a window of 1

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (interim_rec, w2))
    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]

    # Iterate over `timesteps` (not the global nb_steps): h2 only has `timesteps`
    # slices, so a shorter run would otherwise index past its end.
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = interim_rec
    return out_rec, other_recs

#### ANN and SNN - Network to Network

For this implementation, we just have 1 timestep, so we don't need to worry about converting ANN to SNN or any of the intermediary weights.

In [None]:
# Re-seed so the mask is the same regardless of how much randomness earlier
# cells consumed.
torch.manual_seed(42)
# One 0/1 entry per hidden unit, stored as float so it can be multiplied
# directly into activations by the hybrid models.
snn_mask = torch.randint(0, 2, (nb_hidden,), dtype=torch.float32, device=device)  # Randomly assign spiking or analog neurons
snn_mask # 1 is spiking

tensor([1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 0., 1.,
        1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0.,
        0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1.,
        1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1.,
        0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1.,
        1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
        1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.,
        1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0.,
        1., 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1.,
        1., 1.], device='mps:0')

Options for SNN to ANN conversion include:
1. Time of first spike (this makes less sense with a window of size one)
2. Rate (this also makes less sense with a window of size one)
3. ISI (it's a difference between spikes so I think this also makes less sense)
4. Just take it and weight it (learnable) --> trying this first

In [None]:
def Hybrid_ANN_non_recurrent(inputs, timesteps=nb_steps):
    """Hybrid hidden layer: units flagged in ``snn_mask`` are LIF neurons, the rest are tanh units.

    Both populations share the input weights ``w1`` and readout weights ``w2``;
    there is no recurrence. Uses the module-level ``alpha`` / ``beta`` decays.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec, ann_rec])``: the readout trace
        (``timesteps + 1`` entries along dim 1) plus the hidden membrane
        potentials, spikes and analog activations (``timesteps`` entries each).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weight matrix for training, I think this can be used for both SNN and ANN but the training may have to be different
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # Split the drive by population; the multiplication already yields new tensors,
    # so no explicit clone is needed.
    h1_ann = h1 * (1.0 - snn_mask)  # ANN neurons
    h1_snn = h1 * snn_mask          # SNN neurons
    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []
    ann_rec = []

    # Compute hidden layer activity
    for t in range(timesteps):

        # SNN neurons

        # Apply the mask to the synaptic input to ensure SNN neurons are only updated with spiking rules
        mem = mem * snn_mask
        syn = syn * snn_mask

        mthr = mem - 1.0
        out_snn = spike_fn(mthr)
        rst = out_snn.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1_snn[:, t]
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out_snn)

        # ANN neurons - don't interact with the synaptic input at all to start
        out_ann = torch.tanh(h1_ann[:, t])
        ann_rec.append(out_ann)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    ann_rec = torch.stack(ann_rec, dim=1)

    h2_snn = torch.einsum("abc,cd->abd", (spk_rec, w2))
    h2_ann = torch.einsum("abc,cd->abd", (ann_rec, w2))

    # We can add the two together to get the final output since the two do not interact at all
    h2 = h2_snn + h2_ann

    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec, ann_rec]
    return out_rec, other_recs

##### With Recurrent Network

In [None]:
def Hybrid_RNN(inputs, timesteps=nb_steps):
    """Hybrid hidden layer with recurrence on the analog population only.

    Units flagged in ``snn_mask`` are feed-forward LIF neurons; the analog
    activity additionally receives its own previous value through ``v1``.
    Uses the module-level weights ``w1`` / ``v1`` / ``w2`` and ``alpha`` / ``beta``.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec, ann_rec])``: the readout trace
        (``timesteps + 1`` entries along dim 1) plus the hidden membrane
        potentials, spikes and analog activations (``timesteps`` entries each).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weight matrix for training, I think this can be used for both SNN and ANN but the training may have to be different
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # Split the drive by population; the multiplication already yields new tensors,
    # so no explicit clone is needed.
    h1_ann = h1 * (1.0 - snn_mask)  # ANN neurons
    h1_snn = h1 * snn_mask          # SNN neurons
    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []
    ann_rec = []

    out_ann = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    # Compute hidden layer activity
    for t in range(timesteps):

        # SNN neurons

        # Apply the mask to the synaptic input to ensure SNN neurons are only updated with spiking rules
        mem = mem * snn_mask
        syn = syn * snn_mask

        mthr = mem - 1.0
        out_snn = spike_fn(mthr)
        rst = out_snn.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1_snn[:, t]
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out_snn)

        # ANN neurons - don't interact with the synaptic input at all to start
        # NOTE(review): the recurrent term is added outside the tanh (unlike
        # RNN_with_LIF_output, which squashes input + recurrence together) and is
        # not masked, so the analog state is unbounded and also non-zero for
        # spiking units — confirm this is intended.
        out_ann = torch.tanh(h1_ann[:, t]) + torch.einsum("ab,bc->ac", (out_ann, v1))
        ann_rec.append(out_ann)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    ann_rec = torch.stack(ann_rec, dim=1)

    h2_snn = torch.einsum("abc,cd->abd", (spk_rec, w2))
    h2_ann = torch.einsum("abc,cd->abd", (ann_rec, w2))

    # We can add the two together to get the final output since the two do not interact at all
    h2 = h2_snn + h2_ann

    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec, ann_rec]
    return out_rec, other_recs

##### Recurrent SNN Network

In [None]:
def Hybrid_RNN_SNN_rec(inputs, timesteps=nb_steps):
    """Hybrid hidden layer where each population is recurrent onto itself through ``v1``.

    Spikes of the previous step feed the spiking drive, and the previous analog
    activity feeds the analog update; the two populations do not exchange
    activity. Uses the module-level weights ``w1`` / ``v1`` / ``w2`` and
    ``alpha`` / ``beta``.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec, ann_rec])``: the readout trace
        (``timesteps + 1`` entries along dim 1) plus the hidden membrane
        potentials, spikes and analog activations (``timesteps`` entries each).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weight matrix for training, I think this can be used for both SNN and ANN but the training may have to be different
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # Split the drive by population; the multiplication already yields new tensors,
    # so no explicit clone is needed.
    h1_ann = h1 * (1.0 - snn_mask)        # ANN neurons
    h1_snn_input = h1 * snn_mask          # SNN neurons
    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []
    ann_rec = []

    out_ann = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    out_snn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    # Compute hidden layer activity
    for t in range(timesteps):
        # Spiking drive: masked input plus last step's spikes through v1
        h1_snn = h1_snn_input[:, t] + torch.einsum("ab,bc->ac", (out_snn, v1))
        # SNN neurons

        # Apply the mask to the synaptic input to ensure SNN neurons are only updated with spiking rules
        mem = mem * snn_mask
        syn = syn * snn_mask

        mthr = mem - 1.0
        out_snn = spike_fn(mthr)
        rst = out_snn.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1_snn
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out_snn)

        # ANN neurons - don't interact with the synaptic input at all to start
        # NOTE(review): the recurrent term is added outside the tanh and is not
        # masked, so the analog state is unbounded — confirm this is intended.
        out_ann = torch.tanh(h1_ann[:, t]) + torch.einsum("ab,bc->ac", (out_ann, v1))
        ann_rec.append(out_ann)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    ann_rec = torch.stack(ann_rec, dim=1)

    h2_snn = torch.einsum("abc,cd->abd", (spk_rec, w2))
    h2_ann = torch.einsum("abc,cd->abd", (ann_rec, w2))

    # We can add the two together to get the final output since the two do not interact at all
    h2 = h2_snn + h2_ann

    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec, ann_rec]
    return out_rec, other_recs

### ANN and SNN - Layer to Layer

In [None]:
def Hybrid_RNN_SNN_V1_same_layer(inputs, timesteps=nb_steps):
    """Hybrid hidden layer where both populations receive the combined previous activity.

    At each step the previous analog activity and previous spikes are summed and
    fed back through ``v1`` into both the spiking drive and the analog update.
    Uses the module-level weights ``w1`` / ``v1`` / ``w2`` and ``alpha`` / ``beta``.

    Args:
        inputs: Tensor indexed as (batch, time, input unit).
        timesteps: Number of simulation steps; ``inputs`` must provide at least
            this many time slices.

    Returns:
        ``(out_rec, [mem_rec, spk_rec, ann_rec])``: the readout trace
        (``timesteps + 1`` entries along dim 1) plus the hidden membrane
        potentials, spikes and analog activations (``timesteps`` entries each).
    """
    inputs = inputs.to(device)
    # Size all state from the actual batch rather than the global `batch_size`.
    nb_batch = inputs.shape[0]

    # Weight matrix for training, I think this can be used for both SNN and ANN but the training may have to be different
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # Split the drive by population; the multiplication already yields new tensors,
    # so no explicit clone is needed.
    h1_ann = h1 * (1.0 - snn_mask)        # ANN neurons
    h1_snn_input = h1 * snn_mask          # SNN neurons
    syn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []
    ann_rec = []
    # (O_A + O_S) * W_2 = (O_A * W_2 + O_S * W_2) by distributive property, no theoretical change there apart from H1
    # Prove/think about this after more after test it
    out_ann = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    out_snn = torch.zeros((nb_batch, nb_hidden), device=device, dtype=dtype)
    # Compute hidden layer activity
    for t in range(timesteps):
        # Combined activity of both populations from the previous step
        hidden_prev = out_ann + out_snn
        h1_snn = h1_snn_input[:, t] + torch.einsum("ab,bc->ac", (hidden_prev, v1))
        # SNN neurons

        # Apply the mask to the synaptic input to ensure SNN neurons are only updated with spiking rules
        mem = mem * snn_mask
        syn = syn * snn_mask

        mthr = mem - 1.0
        out_snn = spike_fn(mthr)
        rst = out_snn.detach()  # We do not want to backprop through the reset

        new_syn = alpha * syn + h1_snn
        new_mem = (beta * mem + syn) * (1.0 - rst)

        mem_rec.append(mem)
        spk_rec.append(out_snn)

        # ANN neurons - don't interact with the synaptic input at all to start
        # NOTE(review): the recurrent term is added outside the tanh and is not
        # masked, so the analog state is unbounded — confirm this is intended.
        out_ann = torch.tanh(h1_ann[:, t]) + torch.einsum("ab,bc->ac", (hidden_prev, v1))
        ann_rec.append(out_ann)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    ann_rec = torch.stack(ann_rec, dim=1)

    h2_snn = torch.einsum("abc,cd->abd", (spk_rec, w2))
    h2_ann = torch.einsum("abc,cd->abd", (ann_rec, w2))

    # We can add the two together to get the final output since the two do not interact at all
    h2 = h2_snn + h2_ann

    flt = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((nb_batch, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(timesteps):
        new_flt = alpha * flt + h2[:, t]
        new_out = beta * out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec, ann_rec]
    return out_rec, other_recs

### With Training (Surrogate Gradient)

In [None]:
class SurrGradSpike(torch.autograd.Function):
    """
    Spiking nonlinearity with a surrogate gradient.

    The forward pass is a hard threshold at zero. Because that step has no
    useful derivative, the backward pass substitutes the normalized negative
    part of a fast sigmoid, following Zenke & Ganguli (2018). Subclassing
    torch.autograd.Function lets the rest of PyTorch's autograd treat it like
    any other differentiable op.
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        Return 1.0 where the input is strictly positive and 0.0 elsewhere.
        The input is stashed on ctx because the surrogate gradient is a
        function of it.
        """
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scale the incoming gradient by 1 / (scale * |input| + 1)^2, the
        surrogate derivative evaluated at the saved forward input.
        """
        (saved_input,) = ctx.saved_tensors
        damping = (SurrGradSpike.scale * saved_input.abs() + 1.0) ** 2
        return grad_output / damping

# From here on, every model uses the surrogate-gradient spike instead of the naive step function
spike_fn = SurrGradSpike.apply

In [None]:
def compute_classification_accuracy(x_data, y_data, model):
    """
    Computes classification accuracy, confusion matrix and spike count on the
    supplied data, processed in batches.

    The data is batched with non_sparse_data_generator_from_hdf5_spikes using
    the module-level batch_size / nb_steps / nb_inputs / max_time and no
    shuffling, so only full batches are evaluated.

    Args:
        x_data: Spike group accepted by the data generator (indexed with
            'times' and 'units').
        y_data: Labels matching ``x_data``.
        model: Callable ``model(inputs)`` returning ``(output, other_recs)``.

    Returns:
        accuracy: Mean of the per-batch accuracies.
        conf_matrix: Confusion matrix computed from all evaluated samples.
        spike_count: Total number of hidden-layer spikes over all batches
            (0 for models that do not return a spike record).
    """
    accs = []
    all_preds = []
    all_labels = []
    spike_count = 0
    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for x_local, y_local in non_sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size, nb_steps, nb_inputs, max_time, shuffle=False):
            # Move data to the appropriate device
            x_local, y_local = x_local.to(device), y_local.to(device)

            # Run the model
            output, other_recs = model(x_local)

            # Spiking models return [mem_rec, spk_rec, ...]; analog-only models
            # return a bare tensor, which has no spike record to count.
            if isinstance(other_recs, (list, tuple)):
                # .item() gives a Python int, so the running total does not
                # lose precision the way a float32 accumulator would.
                spike_count += (other_recs[1] >= 1.0).sum().item()

            # Compute predictions
            m, _ = torch.max(output, 1)  # Max over time
            _, preds = torch.max(m, 1)   # Argmax over output units

            # Compute accuracy for the current batch
            correct = (y_local == preds).float()  # Shape: [batch_size]
            accs.append(correct.mean().item())

            # Store predictions and labels for the confusion matrix
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y_local.cpu().numpy())

    # Compute overall accuracy
    accuracy = np.mean(accs)

    # Compute confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)

    return accuracy, conf_matrix, spike_count

def plot_confusion_matrix(conf_matrix, class_names, title):
    """Draw ``conf_matrix`` as an annotated heatmap labelled with ``class_names`` and show it.

    Args:
        conf_matrix: Integer matrix of counts (rows: true class, columns: predicted).
        class_names: Tick labels used on both axes.
        title: Prefix for the figure title.
    """
    plt.figure(figsize=(8, 6))
    heatmap_options = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    sns.heatmap(conf_matrix, **heatmap_options)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"{title} Confusion Matrix - Random Manifolds")
    plt.show()

In [None]:
def train_and_print_results(model, x_data=x_train, y_data=y_train, epochs=100, regularization=False, l1=1e-5, l2=1e-5):
    """Re-initialise the shared weights, train ``model``, and report train/test results.

    The module-level weights w1, w2 and v1 are re-drawn from a normal
    distribution in place, then optimised with Adam on a negative
    log-likelihood loss over the time-maximum of the readout. After training,
    the loss curve, a few readout traces from the last batch, and train/test
    accuracies with confusion matrices are shown.

    Args:
        model: Callable ``model(inputs)`` returning ``(output, other_recs)``.
        x_data: Training spike group passed to the data generator.
        y_data: Training labels.
        epochs: Number of passes over the training data.
        regularization: If True, add L1/L2 spike-activity penalties computed
            from ``other_recs[1]``.
        l1: Weight of the penalty on the total number of spikes.
        l2: Weight of the penalty on squared spikes per neuron.

    Returns:
        Total number of hidden-layer spikes emitted during training (0 for
        models whose name contains neither "SNN" nor "Hybrid").
    """
    global w1, w2, v1
    # The following lines will reinitialize the weights
    torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))
    torch.nn.init.normal_(w2, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))
    torch.nn.init.normal_(v1, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

    print("init done")

    params = [w1, w2, v1]
    optimizer = torch.optim.Adam(params, lr=2e-3, betas=(0.9, 0.999))

    log_softmax_fn = nn.LogSoftmax(dim=1)
    loss_fn = nn.NLLLoss()

    # Only spiking/hybrid models return a spike record at other_recs[1]. Each
    # substring needs its own membership test: `"SNN" or ...` is always truthy.
    model_name = model.__name__
    has_spike_record = "SNN" in model_name or "Hybrid" in model_name

    spike_count = 0
    loss_hist = []
    for e in range(epochs):
        local_loss = []
        for x_local, y_local in non_sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size, nb_steps, nb_inputs, max_time, shuffle=False):
            x_local = x_local.to(device)
            y_local = y_local.to(device)
            output, other_recs = model(x_local)
            m, _ = torch.max(output, 1)  # max over time
            log_p_y = log_softmax_fn(m)
            if has_spike_record:
                # .item() gives a Python int, so the running total does not
                # lose precision the way a float32 accumulator would.
                spike_count += (other_recs[1] >= 1.0).sum().item()
            loss_val = loss_fn(log_p_y, y_local)
            if regularization:
                spks = other_recs[1]
                reg_loss = l1*torch.sum(spks) # L1 loss on total number of spikes
                reg_loss += l2*torch.mean(torch.sum(torch.sum(spks,dim=0),dim=0)**2) # L2 loss on spikes per neuron
                loss_val += reg_loss
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            local_loss.append(loss_val.item())
        loss_hist.append(np.mean(local_loss))

    plt.plot(loss_hist)
    plt.title(f"{model.__name__} Loss - Regularization: {regularization}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    sns.despine()

    # Readout traces of the last training batch
    plt.figure(dpi=100)
    plot_voltage_traces(output, dim=(2,5))
    train_accuracy, train_conf_matrix, _ = compute_classification_accuracy(x_data, y_data, model=model)

    # Print results
    print(f"Train Accuracy: {train_accuracy * 100:.2f}%")
    class_names = [f"Class {i}" for i in range(nb_outputs)]

    # Plot the confusion matrix
    plot_confusion_matrix(train_conf_matrix, class_names, f"{model.__name__} Train Results - Regularization: {regularization}")
    test_accuracy, test_conf_matrix, _ = compute_classification_accuracy(x_test, y_test, model=model)

    # Print results
    print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

    # Plot the confusion matrix (labelled as test results, not train)
    plot_confusion_matrix(test_conf_matrix, class_names, f"{model.__name__} Test Results - Regularization: {regularization}")
    return spike_count

In [None]:
# Train the recurrent SNN for 40 epochs with spike regularization (l1 = l2 = 1e-6).
train_and_print_results(SNN_recurrent, epochs=40, regularization=True, l1=1e-6, l2=1e-6)

init done


KeyboardInterrupt: 

In [None]:
# Train the feed-forward SNN for 40 epochs with spike regularization (l1 = l2 = 1e-6).
train_and_print_results(SNN, epochs=40, regularization=True, l1=1e-6, l2=1e-6)

init done


#### Outputs

In [None]:
models = [SNN, SNN_recurrent, ANN_non_recurrent_with_LIF_output, RNN_with_LIF_output, Hybrid_ANN_non_recurrent, Hybrid_RNN, Hybrid_RNN_SNN_rec]
# Create a table with all the model accuracies:
# run name -> (test accuracy, training spike count, test spike count)
model_accuracies = {}
for model in models:
    print(f"Training model: {model.__name__}")
    print("Regularization disabled")
    spike_count = train_and_print_results(model, epochs=40)
    print("--------------------------------------------------")
    acc, matrix, spike_count_t = compute_classification_accuracy(x_test, y_test, model=model)
    model_accuracies[model.__name__+"_No_Regularization"] = (acc, spike_count, spike_count_t)
    # Every model whose name contains "SNN" is trained a second time with
    # spike regularization enabled.
    if "SNN" in model.__name__:
        print(f"Training model: {model.__name__}")
        print("Regularization enabled")
        # Capture this run's training spike count; otherwise the count from the
        # unregularized run above would be stored under the regularized key.
        spike_count = train_and_print_results(model, epochs=40, regularization=True)
        acc, matrix, spike_count_t = compute_classification_accuracy(x_test, y_test, model=model)
        model_accuracies[model.__name__+"_Regularization"] = (acc, spike_count, spike_count_t)


Training model: SNN
Regularization disabled
init done


#### Skip to end

In [None]:
# Report the test accuracy of every run. The loop variable is named `run_name`
# so it does not rebind the global `model` (a function) to a string.
for run_name, (run_acc, _train_spikes, _test_spikes) in model_accuracies.items():
    print(f"{run_name}: {run_acc * 100:.2f}%")

{'SNN': 1.0,
 'SNN_recurrent': 0.98046875,
 'ANN_non_recurrent_with_LIF_output': 0.98046875,
 'RNN_with_LIF_output': 0.96875,
 'Hybrid_ANN_non_recurrent': 1.0,
 'Hybrid_RNN': 0.99609375,
 'Hybrid_RNN_SNN_rec': 0.515625}

In [None]:
# Report training and test spike counts of every run. The loop variable is named
# `run_name` so it does not rebind the global `model` (a function) to a string.
for run_name, (_run_acc, train_spikes, test_spikes) in model_accuracies.items():
    print(f"{run_name}: {train_spikes:,} Training, {test_spikes:,} Testing")