# Spike train
- Set hyperparameters
- Load data
- Generate spike train
- Prune spike train
- Train and test using various models

Note: the pipeline also measures the memory footprint of the spike train and the model training time.

In [1]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from spikenet import dataset, neuron
import scipy.sparse as sp
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from typing import Dict, List

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Hide every UserWarning for the rest of the session to keep notebook output readable.
warnings.filterwarnings("ignore", category=UserWarning)

## Spike generation

In [2]:
def get_DADx(adj, x, a=0.5, b=0.5):
  """Return D^{-a} A D^{-b} x as a float32 torch tensor.

  D is the diagonal matrix of row sums of `adj`. Nodes whose row sum is zero get an
  inverse degree of 0, so they contribute nothing instead of an undefined value.

  Args:
    adj: square adjacency matrix (presumably scipy sparse, as used in this notebook).
    x: node features with one row per node of `adj`.
    a: exponent applied to the degree on the left of `adj`.
    b: exponent applied to the degree on the right of `adj`.

  Returns:
    torch.FloatTensor holding the normalised, propagated features.
  """
  degree = np.asarray(adj.sum(1), dtype=np.float64).flatten()
  nonzero = degree != 0
  # Start from explicit zeros: np.power(..., where=...) without `out=` would leave the
  # masked (zero-degree) entries uninitialised.
  D_inv_a = np.zeros_like(degree)
  D_inv_b = np.zeros_like(degree)
  D_inv_a[nonzero] = degree[nonzero] ** -a
  D_inv_b[nonzero] = degree[nonzero] ** -b
  transformed_x = sp.diags(D_inv_a) @ adj @ sp.diags(D_inv_b) @ x
  return torch.FloatTensor(transformed_x)

def generate_spike_train(data: dataset.Dataset, hp: Dict) -> torch.Tensor:
  """Encode graph features as a boolean spike train stacked along a new leading time axis.

  A single spiking neuron module (IF, LIF or PLIF, chosen by hp["act"]) is built once and
  called once per time step.
  - "static": D^{-a} A D^{-b} x is computed from the LAST snapshot (data.adj[-1],
    data.x[-1]) and fed to the neuron hp["time_steps"] times. Presumably the neuron keeps
    membrane state between calls, so the spikes differ per step; confirm in spikenet.neuron.
  - otherwise ("dynamic"): one step per (adj, x) snapshot pair.

  Args:
    data: dataset exposing per-snapshot `adj` and `x` sequences.
    hp: hyperparameters; reads "act", "alpha", "surrogate", "tau" (LIF/PLIF only),
      "graph_type", "a", "b" and, for static graphs, "time_steps".

  Returns:
    Bool tensor of spikes with time as dimension 0.

  Raises:
    ValueError: if hp["act"] is not "IF", "LIF" or "PLIF".
  """
  act = hp["act"]
  if act == "IF":
    snn = neuron.IF(alpha=hp["alpha"], surrogate=hp["surrogate"])
  elif act == "LIF":
    snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
  elif act == "PLIF":
    snn = neuron.PLIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
  else:
    raise ValueError(f"Invalid activation function {act!r}, only IF, LIF and PLIF are allowed")

  if hp["graph_type"] == "static":
    DADx = get_DADx(data.adj[-1], data.x[-1], a=hp["a"], b=hp["b"])
    spike_train = [snn(DADx) for _ in range(hp["time_steps"])]
  else:
    spike_train = [snn(get_DADx(adj, x, a=hp["a"], b=hp["b"])) for adj, x in zip(data.adj, data.x)]
  return torch.stack(spike_train).to(torch.bool)

## Spike pruning

In [3]:
def prune_spikes(spike_train, hp: Dict) -> torch.Tensor:
  """Drop the leading time steps whose spike count is below prune_param * median count.

  Spike counts are summed over dimensions 1 and 2 of each time step (dimension 0 is time).
  Pruning stops at the first time step whose count reaches the threshold. All later steps
  are kept, even if they fall below it.

  Args:
    spike_train: tensor with at least 3 dimensions, time first.
    hp: hyperparameters; reads hp["prune_param"], the scale applied to the median count.

  Returns:
    spike_train[start:], where start is the first time step meeting the threshold.

  Raises:
    ValueError: if no time step reaches the threshold (possible when prune_param > 1),
      since nothing would remain to train on.
  """
  prune_param = hp["prune_param"]
  num_spikes = torch.sum(spike_train, dim=(1, 2))
  # torch.median returns the lower of the two middle values for even-length input.
  threshold = torch.median(num_spikes) * prune_param
  qualifying = torch.nonzero(num_spikes >= threshold).flatten()
  if qualifying.numel() == 0:
    raise ValueError(
      f"No time step has at least {prune_param} x the median spike count; "
      "pruning would remove the whole spike train")
  return spike_train[int(qualifying[0]):]

## ML models to classify the spike train
- LSTM
- MLP

In [4]:
class SpikeTrainDataset(Dataset):
    """Map-style dataset pairing each spike-train sample with its class label.

    Samples are stored and returned unchanged; the training loop casts them to
    float just before the forward pass. Labels are converted to int64, the
    class-index dtype nn.CrossEntropyLoss expects.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y.long()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

In [5]:
class LSTMClassifier(nn.Module):
  """Stacked LSTM followed by a linear head on the last layer's final hidden state.

  Args:
    input_size: number of features per time step.
    hidden_size: LSTM hidden width.
    num_layers: number of stacked LSTM layers.
    num_classes: number of output logits.

  Inputs are batch-first, (batch, seq_len, input_size); the output is (batch, num_classes).
  """
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super().__init__()
    self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                        num_layers=num_layers, batch_first=True)
    self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

  def forward(self, x):
    _, (h_n, _) = self.lstm(x)
    # h_n is (num_layers, batch, hidden); classify from the top layer only.
    last_layer_hidden = h_n[-1]
    return self.fc(last_layer_hidden)
  
class MLPClassifier(nn.Module):
  """Two-layer perceptron producing class logits: Linear -> ReLU -> Linear.

  Args:
    input_size: number of input features per sample.
    hidden_size: width of the hidden layer.
    num_classes: number of output logits.
  """
  def __init__(self, input_size, hidden_size, num_classes):
    super().__init__()
    self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)

  def forward(self, x):
    hidden = self.relu(self.fc1(x))
    return self.fc2(hidden)

## Helper functions

In [6]:
def check_hyperparameters(hyperparameters):
  """Validate an experiment configuration, raising ValueError on the first problem found.

  ValueError subclasses Exception, so callers that catch Exception keep working.

  Checks: dataset is "DBLP"; a + b == 1 (within float tolerance); a, b >= 0;
  graph_type is "static" or "dynamic"; time_steps is set for static graphs;
  act is IF/LIF/PLIF; model is LSTM/MLP.
  """
  import math

  if hyperparameters['dataset'] not in ['DBLP',]:
    raise ValueError("Invalid dataset name")
  a, b = hyperparameters['a'], hyperparameters['b']
  # Exact float equality is fragile when b is computed, e.g. as 1 - a.
  if not math.isclose(a + b, 1.0):
    raise ValueError("a+b must be equal to 1")
  if a < 0 or b < 0:
    raise ValueError("a and b must be non-negative")
  if hyperparameters["graph_type"] not in ["static", "dynamic"]:
    raise ValueError("Invalid graph type, only static and dynamic are allowed")
  if hyperparameters["graph_type"] == "static" and hyperparameters["time_steps"] is None:
    raise ValueError("time_steps is required for static graph")
  if hyperparameters["act"] not in ["IF", "LIF", "PLIF"]:
    raise ValueError("Invalid activation function, only IF, LIF and PLIF are allowed")
  # Without this check an unknown model only fails later, when no classifier gets built.
  if hyperparameters["model"] not in ["LSTM", "MLP"]:
    raise ValueError("Invalid model, only LSTM and MLP are allowed")
  
def print_confusion_matrix(all_labels, all_preds):
  """Plot the confusion matrix of true vs. predicted labels as an annotated heatmap.

  Despite the name, nothing is printed. The matrix is drawn with seaborn and displayed
  with plt.show().

  Args:
    all_labels: sequence of true class labels.
    all_preds: sequence of predicted class labels, aligned with all_labels.
  """
  # Use the classes that actually occur (sorted, the same set confusion_matrix infers)
  # instead of assuming exactly 10 classes, so the tick labels always match the matrix.
  classes = np.unique(np.concatenate([np.asarray(all_labels).ravel(),
                                      np.asarray(all_preds).ravel()])).tolist()
  cm = confusion_matrix(all_labels, all_preds, labels=classes)
  plt.figure(figsize=(10, 8))
  sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
              xticklabels=classes, yticklabels=classes)
  plt.xlabel('Predicted')
  plt.ylabel('True')
  plt.title('Confusion Matrix')
  plt.show()

def get_tensor_memory(tensor):
  """Return the storage needed for `tensor`'s elements, in megabytes (1 MB = 1024**2 bytes).

  Computed as numel * element_size; storage overhead and sharing are ignored.
  """
  bytes_per_mb = 1024 ** 2
  return tensor.numel() * tensor.element_size() / bytes_per_mb

In [7]:
def pack_tensor(tensor: torch.Tensor):
    """
    Pack a tensor of 0/1 (or boolean) values into 1 bit per element.

    Any non-zero element is treated as 1. The tensor may be of any dtype and
    live on any device; it is detached and moved to the CPU before packing.

    Args:
        tensor (torch.Tensor): Tensor of 0/1 values or booleans.

    Returns:
        torch.Tensor: 1-D torch.uint8 tensor holding 8 elements per byte in
            np.packbits' default (big-endian) bit order, last byte zero-padded.
        torch.Size: The original shape, needed by ``unpack_tensor``.
    """
    original_shape = tensor.shape
    # .numpy() only works on CPU tensors; flatten() copies if the input is non-contiguous.
    bits = tensor.detach().to(device="cpu", dtype=torch.bool).flatten()
    packed = np.packbits(bits.numpy().astype(np.uint8))
    return torch.from_numpy(packed), original_shape

def unpack_tensor(packed: torch.Tensor, original_shape: tuple, dtype: torch.dtype = torch.float32):
    """
    Unpack the output of ``pack_tensor`` back into a tensor of 0/1 values.

    Args:
        packed (torch.Tensor): torch.uint8 tensor holding 8 elements per byte.
        original_shape (tuple): Shape to restore; ``()`` is allowed for a scalar.
        dtype (torch.dtype): dtype of the result. Defaults to float32 for
            backward compatibility; pass ``torch.bool`` to keep 1 byte per element.

    Returns:
        torch.Tensor: CPU tensor of shape ``original_shape`` and dtype ``dtype``.
    """
    # torch.Size(...).numel() is an int and is 1 for (), unlike np.prod(()) == 1.0,
    # which cannot be used as a slice bound.
    num_elements = torch.Size(original_shape).numel()
    # Drop the zero padding np.packbits added to fill the last byte.
    bits = np.unpackbits(packed.detach().cpu().numpy())[:num_elements]
    return torch.from_numpy(bits).to(dtype).reshape(original_shape)


## Single input to output pipeline
Metrics tracked:
- Memory usage for spike train
- Memory usage in training model
- Time taken to train model
- Accuracy

In [8]:
def _build_model(model_name, input_size, num_classes, device):
  """Instantiate the classifier named model_name ("LSTM" or "MLP") and move it to device."""
  if model_name == "LSTM":
    model = LSTMClassifier(input_size=input_size, hidden_size=256, num_layers=2, num_classes=num_classes)
  elif model_name == "MLP":
    model = MLPClassifier(input_size=input_size, hidden_size=256, num_classes=num_classes)
  else:
    raise ValueError(f"Invalid model {model_name!r}, only LSTM and MLP are allowed")
  return model.to(device)


def _predict(model, loader, device):
  """Run model (eval mode, no grad) over loader; return (predictions, labels) as 1-D CPU tensors."""
  model.eval()
  all_preds, all_labels = [], []
  with torch.no_grad():
    for inputs, labels in loader:
      # Spikes are stored compactly; the model needs float input.
      outputs = model(inputs.float().to(device))
      all_preds.append(outputs.argmax(dim=1).cpu())
      all_labels.append(labels.cpu())
  if not all_preds:
    empty = torch.empty(0, dtype=torch.long)
    return empty, empty
  return torch.cat(all_preds), torch.cat(all_labels)


def get_results(hyperparameters, system_params):
  """Build a spike train for the configured graph, train a classifier on it and evaluate it.

  Args:
    hyperparameters: experiment configuration (dataset, graph_type, time_steps, tau, alpha,
      surrogate, act, a, b, prune_param, model), validated by check_hyperparameters.
    system_params: run settings with keys "batch_size", "num_epochs", "verbose" and
      "test_memory".

  Returns:
    (final_accuracy, time_taken): test accuracy in percent after the last epoch (0 if
    num_epochs is 0), and the seconds spent in the training/per-epoch evaluation loop.
  """
  check_hyperparameters(hyperparameters)

  if hyperparameters['dataset'] == 'DBLP':
    data = dataset.DBLP()
  else:
    raise ValueError(f"Unsupported dataset {hyperparameters['dataset']!r}")

  # Baseline memory: all snapshots for dynamic graphs, only the last one for static graphs.
  features = data.x if hyperparameters["graph_type"] == "dynamic" else data.x[-1]
  original_memory = get_tensor_memory(features)
  original_num_elements = features.numel()

  spike_train = generate_spike_train(data, hyperparameters)
  if hyperparameters["prune_param"] is not None:
    spike_train = prune_spikes(spike_train, hyperparameters)
  # Move time (dim 0, as stacked by generate_spike_train) to dim 1 so samples come first.
  spike_train = spike_train.permute(1, 0, 2)
  if hyperparameters["model"] == "MLP":
    spike_train = spike_train.reshape(spike_train.shape[0], -1)  # flatten the spike train per sample
  # Pack to 1 bit per spike to report the theoretical space savings.
  compressed_spike_train, original_shape = pack_tensor(spike_train)
  # Keep the unpacked spikes as bool (1 byte/element instead of 4 for float32);
  # batches are cast to float right before the forward pass.
  spike_train = unpack_tensor(compressed_spike_train, original_shape).bool()
  final_memory = get_tensor_memory(compressed_spike_train)
  final_num_elements = spike_train.numel()
  if system_params["test_memory"]:
    print(f"Original memory: {original_memory:.2f} MB, Final memory: {final_memory:.2f} MB")
    print(f"Original num elements: {original_num_elements}, Final num elements: {final_num_elements}")

  y = data.y
  X_train, X_test, y_train, y_test = train_test_split(spike_train, y, test_size=0.2, random_state=42, stratify=y)
  batch_size = system_params["batch_size"]
  train_loader = DataLoader(SpikeTrainDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
  test_loader = DataLoader(SpikeTrainDataset(X_test, y_test), batch_size=batch_size, shuffle=False)

  # Model training
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f'Using device: {device}')
  model = _build_model(hyperparameters["model"], spike_train.shape[-1], data.num_classes, device)
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=0.001)
  num_epochs = system_params["num_epochs"]

  start_time = time.perf_counter()
  final_accuracy = 0
  for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
      inputs = inputs.float().to(device)  # convert to float only per batch to save memory
      labels = labels.to(device)
      optimizer.zero_grad()
      loss = criterion(model(inputs), labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    test_preds, test_labels = _predict(model, test_loader, device)
    final_accuracy = 100 * (test_preds == test_labels).sum().item() / test_labels.numel()
    if system_params["verbose"]:
      print(f'Accuracy on test set: {final_accuracy:.2f}%\n')
  time_taken = time.perf_counter() - start_time

  # The confusion matrix is only shown in verbose mode, so skip the extra pass otherwise.
  if system_params["verbose"]:
    all_preds, all_labels = _predict(model, test_loader, device)
    print_confusion_matrix(all_labels.tolist(), all_preds.tolist())

  return final_accuracy, time_taken
  

# Main

In [9]:
# Default experiment configuration. The experiment cells below override a few keys per
# run and then reset this dict from `baseline_hyperparameters_copy`.
baseline_hyperparameters = dict(
    dataset="DBLP",          # only "DBLP" is supported
    graph_type="static",     # "static" or "dynamic"
    time_steps=20,           # required for the static graph
    tau=1.0,
    alpha=1.0,
    surrogate="triangle",
    act="LIF",               # "IF", "LIF" or "PLIF"
    a=0.5,                   # a + b must equal 1
    b=0.5,
    prune_param=None,        # float, or None to disable pruning
    model="MLP",             # "LSTM" or "MLP"
)

# Snapshot of the defaults used to restore the baseline after each experiment.
baseline_hyperparameters_copy = dict(baseline_hyperparameters)

# Training and reporting settings shared by every run.
system_params = dict(
  batch_size=64,
  num_epochs=20,
  verbose=False,      # per-epoch test accuracy and confusion-matrix plot
  test_memory=True,   # print spike-train memory statistics
)

In [10]:
def main():
  """Run the pipeline once with the baseline configuration; metrics are printed, not returned."""
  get_results(baseline_hyperparameters, system_params)

main()

Original memory: 13.71 MB, Final memory: 8.57 MB
Original num elements: 3594880, Final num elements: 71897600
Using device: cpu
Epoch [1/20], Loss: 1.0037
Epoch [2/20], Loss: 0.8301
Epoch [3/20], Loss: 0.7212
Epoch [4/20], Loss: 0.6217
Epoch [5/20], Loss: 0.5279
Epoch [6/20], Loss: 0.4374
Epoch [7/20], Loss: 0.3597
Epoch [8/20], Loss: 0.2758
Epoch [9/20], Loss: 0.2108
Epoch [10/20], Loss: 0.1710
Epoch [11/20], Loss: 0.1119
Epoch [12/20], Loss: 0.0891
Epoch [13/20], Loss: 0.0636
Epoch [14/20], Loss: 0.0520
Epoch [15/20], Loss: 0.0409
Epoch [16/20], Loss: 0.0534
Epoch [17/20], Loss: 0.1015
Epoch [18/20], Loss: 0.0350
Epoch [19/20], Loss: 0.0177
Epoch [20/20], Loss: 0.0166


## Tests
- Static vs Dynamic graphs
- MLP vs LSTM
- tau values
- Number of time steps for static graph
- a and b values
- Prune param

In [11]:
# Static and dynamic graphs for MLP and LSTM.
# Each run gets its own copy of the baseline, so an interrupted or failing run cannot
# leave the shared baseline_hyperparameters dict modified for later cells.
graph_types = ["static", "dynamic"]
models = ["MLP", "LSTM"]
runs = [(graph_type, model) for graph_type in graph_types for model in models]
acc_list = []
time_list = []
for graph_type, model in runs:
  print(f"Currently testing: Graph type: {graph_type}, Model: {model}")
  run_hyperparameters = {**baseline_hyperparameters, "graph_type": graph_type, "model": model}
  acc, time_taken = get_results(run_hyperparameters, system_params)
  acc_list.append(acc)
  time_list.append(time_taken)

for (graph_type, model), acc, time_taken in zip(runs, acc_list, time_list):
  print(f"Graph type: {graph_type}, Model: {model}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

baseline_hyperparameters = baseline_hyperparameters_copy.copy()

Currently testing: Graph type: static, Model: MLP
Original memory: 13.71 MB, Final memory: 8.57 MB
Original num elements: 3594880, Final num elements: 71897600
Using device: cpu
Epoch [1/20], Loss: 0.9991
Epoch [2/20], Loss: 0.8224
Epoch [3/20], Loss: 0.7062
Epoch [4/20], Loss: 0.6170
Epoch [5/20], Loss: 0.5119
Epoch [6/20], Loss: 0.4297
Epoch [7/20], Loss: 0.3419
Epoch [8/20], Loss: 0.2689
Epoch [9/20], Loss: 0.2024
Epoch [10/20], Loss: 0.1611
Epoch [11/20], Loss: 0.1313
Epoch [12/20], Loss: 0.0835
Epoch [13/20], Loss: 0.0595
Epoch [14/20], Loss: 0.0598
Epoch [15/20], Loss: 0.0554
Epoch [16/20], Loss: 0.0553
Epoch [17/20], Loss: 0.0445
Epoch [18/20], Loss: 0.0454
Epoch [19/20], Loss: 0.0447
Epoch [20/20], Loss: 0.0497
Currently testing: Graph type: static, Model: LSTM
Original memory: 13.71 MB, Final memory: 8.57 MB
Original num elements: 3594880, Final num elements: 71897600
Using device: cpu
Epoch [1/20], Loss: 1.1015
Epoch [2/20], Loss: 0.8966
Epoch [3/20], Loss: 0.8265
Epoch [4/20

In [12]:
# Number of time steps and tau.
# Each run gets its own copy of the baseline, so an interrupted or failing run cannot
# leave the shared baseline_hyperparameters dict modified for later cells.
time_steps_list = [10, 20, 30, 40, 50, 60]
tau_list = [1.0, 2.0, 5.0, 10.0]
runs = [(time_steps, tau) for time_steps in time_steps_list for tau in tau_list]
acc_list = []
time_list = []

for time_steps, tau in runs:
  print(f"Currently testing: Time steps: {time_steps}, Tau: {tau}")
  run_hyperparameters = {**baseline_hyperparameters, "time_steps": time_steps, "tau": tau}
  acc, time_taken = get_results(run_hyperparameters, system_params)
  acc_list.append(acc)
  time_list.append(time_taken)

for (time_steps, tau), acc, time_taken in zip(runs, acc_list, time_list):
  print(f"Time steps: {time_steps}, Tau: {tau}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

baseline_hyperparameters = baseline_hyperparameters_copy.copy()

Currently testing: Time steps: 10, Tau: 1.0
Original memory: 13.71 MB, Final memory: 4.29 MB
Original num elements: 3594880, Final num elements: 35948800
Using device: cpu
Epoch [1/20], Loss: 1.0165
Epoch [2/20], Loss: 0.8331
Epoch [3/20], Loss: 0.7386
Epoch [4/20], Loss: 0.6475
Epoch [5/20], Loss: 0.5688
Epoch [6/20], Loss: 0.5005
Epoch [7/20], Loss: 0.4303
Epoch [8/20], Loss: 0.3550
Epoch [9/20], Loss: 0.2927
Epoch [10/20], Loss: 0.2298
Epoch [11/20], Loss: 0.1844
Epoch [12/20], Loss: 0.1414
Epoch [13/20], Loss: 0.1077
Epoch [14/20], Loss: 0.0804
Epoch [15/20], Loss: 0.0663
Epoch [16/20], Loss: 0.0503
Epoch [17/20], Loss: 0.0439
Epoch [18/20], Loss: 0.0437
Epoch [19/20], Loss: 0.0364
Epoch [20/20], Loss: 0.0480
Currently testing: Time steps: 10, Tau: 2.0
Original memory: 13.71 MB, Final memory: 4.29 MB
Original num elements: 3594880, Final num elements: 35948800
Using device: cpu
Epoch [1/20], Loss: 1.0218
Epoch [2/20], Loss: 0.8095
Epoch [3/20], Loss: 0.7010
Epoch [4/20], Loss: 0.59

In [13]:
# a and b values and prune param.
# Each run gets its own copy of the baseline, so an interrupted or failing run cannot
# leave the shared baseline_hyperparameters dict modified for later cells.
a_list = [0.1, 0.3, 0.5, 0.7, 0.9]
prune_list = [None, 0.6, 0.8, 1.0]
# round() strips float noise such as 1 - 0.7 == 0.30000000000000004.
runs = [(a, round(1 - a, 10), prune) for a in a_list for prune in prune_list]
acc_list = []
time_list = []

for a, b, prune in runs:
  print(f"Currently testing: a: {a}, b: {b}, Prune: {prune}")
  run_hyperparameters = {**baseline_hyperparameters, "a": a, "b": b, "prune_param": prune}
  acc, time_taken = get_results(run_hyperparameters, system_params)
  acc_list.append(acc)
  time_list.append(time_taken)

for (a, b, prune), acc, time_taken in zip(runs, acc_list, time_list):
  print(f"a: {a}, b: {b}, Prune: {prune}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

baseline_hyperparameters = baseline_hyperparameters_copy.copy()

Currently testing: a: 0.1, b: 0.9, Prune: None
Original memory: 13.71 MB, Final memory: 8.57 MB
Original num elements: 3594880, Final num elements: 71897600
Using device: cpu
Epoch [1/20], Loss: 1.0258
Epoch [2/20], Loss: 0.8263
Epoch [3/20], Loss: 0.7129
Epoch [4/20], Loss: 0.6040
Epoch [5/20], Loss: 0.5099
Epoch [6/20], Loss: 0.4042
Epoch [7/20], Loss: 0.3222
Epoch [8/20], Loss: 0.2485
Epoch [9/20], Loss: 0.1817
Epoch [10/20], Loss: 0.1421
Epoch [11/20], Loss: 0.1068
Epoch [12/20], Loss: 0.0860
Epoch [13/20], Loss: 0.0643
Epoch [14/20], Loss: 0.0494
Epoch [15/20], Loss: 0.0386
Epoch [16/20], Loss: 0.0462
Epoch [17/20], Loss: 0.0556
Epoch [18/20], Loss: 0.0551
Epoch [19/20], Loss: 0.0515
Epoch [20/20], Loss: 0.0447
Currently testing: a: 0.1, b: 0.9, Prune: 0.6
Original memory: 13.71 MB, Final memory: 7.71 MB
Original num elements: 3594880, Final num elements: 64707840
Using device: cpu
Epoch [1/20], Loss: 1.0403
Epoch [2/20], Loss: 0.8496
Epoch [3/20], Loss: 0.7394
Epoch [4/20], Loss: