# Spike train
- Set hyperparameters
- Load data
- Generate spike train
- Prune spike train
- Train and test using various models

Note: This notebook is also used to measure memory usage and training time.

# Currently updating

1. Record the saved file size and compare with the original representation
2. Compare the number of MAC operations

In [None]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from spikenet import dataset, neuron
import scipy.sparse as sp
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from typing import Dict, List

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

from thop import profile

warnings.filterwarnings("ignore", category=UserWarning)

## Spike generation

In [None]:
def get_DADx(adj, x, a=0.5, b=0.5):
    """
    Compute the degree-normalised propagation D^(-a) @ adj @ D^(-b) @ x.

    adj  : adjacency matrix supporting `.sum(1)` and `@` (callers pass data.adj[t])
    x    : node features, one row per node
    a, b : exponents applied to the node degrees on the left / right of adj

    Nodes with degree 0 get a scaling factor of exactly 0 on both sides.
    Returns the result as a torch.FloatTensor.
    """
    degree = np.asarray(adj.sum(1), dtype=np.float64).flatten()
    has_neighbours = degree != 0

    # `where=` skips the masked-out entries, so the output buffer must be pre-filled:
    # without `out=` the zero-degree slots would hold uninitialised memory.
    D_inv_a = np.power(degree, -a, where=has_neighbours, out=np.zeros_like(degree))
    D_inv_b = np.power(degree, -b, where=has_neighbours, out=np.zeros_like(degree))

    transformed_x = sp.diags(D_inv_a) @ adj @ sp.diags(D_inv_b) @ x
    return torch.FloatTensor(transformed_x)

def _generate_dynamic_spike_train(data: dataset.Dataset, hp: Dict, snn) -> torch.Tensor:
    """
    Encode every graph snapshot as a spike train, reusing the previous snapshot's
    spikes for nodes whose propagated features changed by less than a threshold.

    data.adj : shape = (T, N, N) or list of length T of adjacency matrices
    data.x   : shape = (T, N, F) or list of length T of node features
    T = number of snapshots
    N = number of nodes
    F = number of node features

    hp["time_steps"] = how many internal SNN time steps to simulate per snapshot
    hp["a"], hp["b"] = exponents in D^(-a) A D^(-b)
    hp["threshold"]  = optional; a node whose largest per-feature change since the
                       previous snapshot is below this value is zeroed and inherits
                       the previous snapshot's spikes. Defaults to 0.0, which masks
                       nothing because the change is never negative.

    Returns a bool tensor in which the snapshot and time-step axes are flattened
    into one leading axis of length T * time_steps.

    Raises ValueError if `data` contains no snapshots.
    """
    threshold = hp.get("threshold", 0.0)
    time_steps = hp["time_steps"]

    num_snapshots = len(data.adj)
    if num_snapshots == 0:
        raise ValueError("Cannot generate a spike train: data contains no snapshots")

    spike_train_all = []
    DADx_prev = None
    spikes_prev = None

    for t in range(num_snapshots):
        snn.reset()
        DADx_t = get_DADx(data.adj[t], data.x[t], a=hp["a"], b=hp["b"])

        mask = None
        if DADx_prev is not None:
            # Largest absolute feature change per node -> shape (num_nodes,)
            delta = torch.abs(DADx_t - DADx_prev).max(dim=1)[0]
            mask = delta < threshold  # nodes whose change is insignificant
            DADx_t[mask] = 0

        # Feed the same input for `time_steps` steps; the neuron is only reset
        # once per snapshot (above), not between these calls.
        spikes_t = torch.stack([snn(DADx_t) for _ in range(time_steps)])

        if mask is not None:
            # Masked nodes keep the spikes of the previous snapshot at every step.
            spikes_t[:, mask] = spikes_prev[:, mask]

        spike_train_all.append(spikes_t)
        # NOTE(review): DADx_prev stores the tensor *after* masking, so a node masked
        # at snapshot t is compared against 0 at t+1 rather than against its last
        # real value -- confirm this is intended.
        DADx_prev = DADx_t
        spikes_prev = spikes_t

    # Stack to (T, time_steps, N, F), then merge the first two axes.
    spike_train_all = torch.stack(spike_train_all, dim=0)
    spike_train_all = spike_train_all.view(-1, spike_train_all.size(-2), spike_train_all.size(-1))

    # Store as bool for memory savings
    return spike_train_all.to(torch.bool)

def _generate_static_spike_train(data: dataset.Dataset, hp: Dict, snn) -> torch.Tensor:
    """
    Encode only the last snapshot of `data` as a spike train.

    The propagated features of the final snapshot are fed to `snn`
    hp["time_steps"] times; the outputs are stacked along a new leading axis
    and returned as a bool tensor. `snn` is not reset here.
    """
    last_adj, last_x = data.adj[-1], data.x[-1]
    propagated = get_DADx(last_adj, last_x, a=hp["a"], b=hp["b"])
    num_steps = hp["time_steps"]
    frames = [snn(propagated) for _ in range(num_steps)]
    return torch.stack(frames).to(torch.bool)

def generate_spike_train(data: dataset.Dataset, hp: Dict) -> torch.Tensor:
    """
    Build the spiking neuron selected by hp["act"] and encode `data` with it.

    hp["act"]        : "IF", "LIF" or "PLIF" (LIF and PLIF additionally read hp["tau"])
    hp["graph_type"] : "static" encodes only the last snapshot; any other value
                       takes the dynamic path and encodes every snapshot.

    Returns a bool tensor whose leading axis has length time_steps (static) or
    num_snapshots * time_steps (dynamic).

    Raises ValueError if hp["act"] is not one of the supported neuron types.
    """
    act = hp["act"]
    if act == "IF":
        snn = neuron.IF(alpha=hp["alpha"], surrogate=hp["surrogate"])
    elif act == "LIF":
        snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
    elif act == "PLIF":
        snn = neuron.PLIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
    else:
        # Without this branch `snn` would be unbound and fail further down with an
        # unrelated-looking UnboundLocalError.
        raise ValueError(
            f"Invalid activation function {act!r}, only IF, LIF and PLIF are allowed"
        )

    if hp["graph_type"] == "static":
        return _generate_static_spike_train(data, hp, snn)
    return _generate_dynamic_spike_train(data, hp, snn)

## Spike pruning

In [None]:
def prune_spikes(spike_train, hp: Dict) -> torch.Tensor:
    """
    Drop the leading low-activity steps of a spike train.

    The spike count of each step (sum over dims 1 and 2) is compared with
    median(count) * hp["prune_param"]; everything before the first step that
    reaches this cutoff is removed.

    Returns `spike_train[first_kept:]`. If no step reaches the cutoff (possible
    when prune_param > 1) the result is empty along dim 0 instead of raising an
    IndexError. An input with no steps is returned unchanged.
    """
    if spike_train.shape[0] == 0:
        return spike_train

    num_spikes = torch.sum(spike_train, dim=(1, 2))
    cutoff = torch.median(num_spikes) * hp["prune_param"]

    # Indices of all steps that reach the cutoff; the first one is where we start keeping.
    hits = torch.nonzero(num_spikes >= cutoff)
    first_kept = hits[0, 0].item() if hits.numel() > 0 else spike_train.shape[0]
    return spike_train[first_kept:]

## ML models to classify the spike train
- LSTM
- MLP

In [None]:
class SpikeTrainDataset(Dataset):
    """Pairs each sample in `X` with its label in `y` for use with a DataLoader."""

    def __init__(self, X, y):
        # Samples are stored exactly as given (the training loops call .float() on
        # every batch); labels are converted to int64 via .long().
        self.X, self.y = X, y.long()

    def __len__(self):
        """Number of samples, taken from the length of X."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (sample, label) pair stored at position `idx`."""
        sample, label = self.X[idx], self.y[idx]
        return sample, label

In [None]:
class LSTMClassifier(nn.Module):
  """Batch-first LSTM followed by a linear layer applied to the last layer's final hidden state."""

  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super().__init__()
    # batch_first=True -> inputs are laid out as [batch, seq_len, input_size]
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_classes)

  def forward(self, x):
    """Return class scores computed from the top LSTM layer's final hidden state."""
    _, (final_hidden, _) = self.lstm(x)
    # final_hidden holds one entry per layer; [-1] selects the last (top) layer.
    top_layer_state = final_hidden[-1]
    return self.fc(top_layer_state)
  
class MLPClassifier(nn.Module):
  """Two-layer perceptron: Linear -> ReLU -> Linear."""

  def __init__(self, input_size, hidden_size, num_classes):
    super().__init__()
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(hidden_size, num_classes)

  def forward(self, x):
    """Return the class scores produced by fc2(relu(fc1(x)))."""
    hidden = self.relu(self.fc1(x))
    return self.fc2(hidden)

## Helper functions

In [None]:
def check_hyperparameters(hyperparameters):
  """
  Validate a hyperparameter dict before it is used by the pipeline.

  Checks the dataset name, that a + b == 1 (within a small float tolerance) with
  both non-negative, the graph type, that time_steps is set, the neuron type and
  the classifier name.

  Returns None. Raises ValueError describing the first violated constraint.
  """
  if hyperparameters['dataset'] not in ['DBLP',]:
    raise ValueError("Invalid dataset name")

  a, b = hyperparameters['a'], hyperparameters['b']
  # Compare with a tolerance: sums of decimal fractions are not always exactly 1.0.
  if abs(a + b - 1) > 1e-9:
    raise ValueError("a+b must be equal to 1")
  if a < 0 or b < 0:
    raise ValueError("a and b must be non-negative")

  if hyperparameters["graph_type"] not in ["static", "dynamic"]:
    raise ValueError("Invalid graph type, only static and dynamic are allowed")
  # Both the static and the dynamic generator read hp["time_steps"].
  if hyperparameters.get("time_steps") is None:
    raise ValueError("time_steps is required")
  if hyperparameters["act"] not in ["IF", "LIF", "PLIF"]:
    raise ValueError("Invalid activation function, only IF, LIF and PLIF are allowed")
  # get_results() only knows how to build these two classifiers.
  if hyperparameters.get("model") not in ["LSTM", "MLP"]:
    raise ValueError("Invalid model, only LSTM and MLP are allowed")
  
def print_confusion_matrix(all_labels, all_preds):
  """
  Plot the confusion matrix of true labels vs. predictions as a heatmap.

  The rows/columns and their tick labels are the sorted set of classes that occur
  in either `all_labels` or `all_preds`, so the plot is labelled correctly for any
  number of classes.
  """
  classes = np.unique(np.concatenate([np.asarray(all_labels).ravel(),
                                      np.asarray(all_preds).ravel()]))
  # Pass the classes explicitly so the matrix layout and the tick labels always agree.
  cm = confusion_matrix(all_labels, all_preds, labels=classes)
  plt.figure(figsize=(10, 8))
  sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
              xticklabels=classes, yticklabels=classes)
  plt.xlabel('Predicted')
  plt.ylabel('True')
  plt.title('Confusion Matrix')
  plt.show()

def get_tensor_memory(tensor):
  """Return the size of the tensor's elements in megabytes (bytes per element * element count / 1024**2)."""
  bytes_per_mb = 1024 ** 2
  total_bytes = tensor.element_size() * tensor.numel()
  return total_bytes / bytes_per_mb

In [None]:
def pack_tensor(tensor: torch.Tensor):
    """
    Bit-pack a 0/1 tensor so that every element occupies a single bit.

    Args:
        tensor (torch.Tensor): Tensor of 1s and 0s; it is cast to bool first,
            so any non-zero value counts as 1.

    Returns:
        torch.Tensor: 1-D torch.uint8 tensor holding 8 elements per byte
            (the last byte is zero-padded by np.packbits).
        torch.Size: Shape of the input, needed to unpack again.
    """
    as_bool = tensor.to(torch.bool)

    # np.packbits works on a flat 0/1 array, so go through numpy.
    bits = as_bool.flatten().numpy().astype(np.uint8)
    packed_bytes = np.packbits(bits)

    return torch.from_numpy(packed_bytes).to(torch.uint8), as_bool.shape

def unpack_tensor(packed: torch.Tensor, original_shape: tuple):
    """
    Unpack a bit-packed tensor back into a float32 tensor of 0s and 1s.

    Args:
        packed (torch.Tensor): A packed tensor (torch.uint8) with 1 bit per element.
        original_shape (tuple): The shape the tensor had before packing.

    Returns:
        torch.Tensor: float32 tensor of shape `original_shape`.
    """
    # np.prod(()) is the float 1.0, which is not a valid slice bound, so force an
    # int; this keeps 0-dim shapes working.
    num_elements = int(np.prod(original_shape))

    # Unpack every bit, then drop the zero padding that filled up the last byte.
    bits = np.unpackbits(packed.numpy())[:num_elements]

    return torch.from_numpy(bits).to(torch.float32).reshape(original_shape)


## Single input to output pipeline
What to track
- Memory usage for spike train
- Memory usage in training model
- Time taken to train model
- Accuracy

In [None]:
def _collect_predictions(model, loader, device):
    """Run `model` over `loader` in eval mode without gradients; return (predictions, labels) as lists."""
    predictions, targets = [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.float().to(device)
            outputs = model(inputs)
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())
    return predictions, targets


def get_results(hyperparameters, system_params):
    """
    Run the whole pipeline once: load data, generate (and optionally prune) the
    spike train, bit-pack it, train a classifier and evaluate it.

    Args:
        hyperparameters: dict validated by check_hyperparameters(); this function
            additionally reads "prune_param" and "model".
        system_params: dict with the keys "save_tensor", "test_memory",
            "batch_size", "num_epochs" and "verbose".

    Returns:
        (final_accuracy, time_taken): test accuracy in percent after the last
        epoch (0 if num_epochs is 0) and the wall-clock seconds spent in the
        training loop, which includes the per-epoch evaluation.

    Raises:
        ValueError: for a dataset or model name this function cannot build.

    Side effects: prints progress; writes .npy files to the working directory when
    system_params["save_tensor"] is true; shows a confusion matrix when
    system_params["verbose"] is true.
    """
    check_hyperparameters(hyperparameters)

    dataset_name = hyperparameters['dataset']
    model_name = hyperparameters["model"]
    is_dynamic = hyperparameters["graph_type"] == "dynamic"

    # ----------------------
    # Data Loading and Setup
    # ----------------------
    if dataset_name == 'DBLP':
        data = dataset.DBLP()
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    # The dynamic setting uses every snapshot, the static one only the last.
    original_x = data.x if is_dynamic else data.x[-1]
    original_memory = get_tensor_memory(original_x)
    original_num_elements = original_x.numel()

    spike_train = generate_spike_train(data, hyperparameters)
    print(spike_train.shape)

    if hyperparameters["prune_param"] is not None:
        spike_train = prune_spikes(spike_train, hyperparameters)

    # Permute to [samples, time, features] if RNN-based model
    spike_train = spike_train.permute(1, 0, 2)

    # For MLP, flatten the [time, features] dimension
    if model_name == "MLP":
        spike_train = spike_train.reshape(spike_train.shape[0], -1)  # [samples, time * features]

    print(spike_train.shape)

    # --------------
    # Compression (packing) example
    # --------------
    # Round-trip through the packed form: the packed tensor is what gets measured
    # and saved, the unpacked float tensor is what the model trains on.
    compressed_spike_train, original_shape = pack_tensor(spike_train)
    spike_train = unpack_tensor(compressed_spike_train, original_shape)
    final_memory = get_tensor_memory(compressed_spike_train)

    if system_params["save_tensor"]:
        # np.save writes a real .npy file (header + data) that np.load can read
        # back; ndarray.tofile would dump raw bytes under a misleading extension.
        x_tag = "x" if is_dynamic else "x[-1]"
        np.save(f"{dataset_name}_{x_tag}_original.npy", original_x.numpy())
        np.save(f"{dataset_name}_spike_train_compressed.npy", compressed_spike_train.numpy())

    final_num_elements = spike_train.numel()
    if system_params["test_memory"]:
        print(f"Original memory: {original_memory:.2f} MB, Final memory: {final_memory:.2f} MB")
        print(f"Original num elements: {original_num_elements}, Final num elements: {final_num_elements}")

    # -----------------------
    # Train/Test Split
    # -----------------------
    y = data.y
    X_train, X_test, y_train, y_test = train_test_split(spike_train, y,
                                                        test_size=0.2,
                                                        random_state=42,
                                                        stratify=y)

    batch_size = system_params["batch_size"]
    train_loader = DataLoader(SpikeTrainDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(SpikeTrainDataset(X_test, y_test), batch_size=batch_size, shuffle=False)

    # ------------------------
    # Model Definition
    # ------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    if model_name == "LSTM":
        # LSTM expects [batch_size, seq_len, input_size]
        model = LSTMClassifier(input_size=spike_train.shape[-1],
                               hidden_size=256,
                               num_layers=2,
                               num_classes=data.num_classes).to(device)
    elif model_name == "MLP":
        # MLP expects [batch_size, input_size]
        model = MLPClassifier(input_size=spike_train.shape[-1],
                              hidden_size=256,
                              num_classes=data.num_classes).to(device)
    else:
        raise ValueError(f"Unsupported model: {model_name!r}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = system_params["num_epochs"]

    # ---------------------------------
    # Measure MACs for a Single Forward
    # ---------------------------------
    # Training MACs are approximated as forward + backward, with the backward pass
    # assumed to cost about 2x the forward pass (a rough rule of thumb). Every
    # batch is counted like the profiled one, so a smaller final batch is
    # over-counted.
    dummy_input, _ = next(iter(train_loader))
    dummy_input = dummy_input.float().to(device)

    macs, params = profile(model, inputs=(dummy_input,), verbose=False)
    print(f"Single-batch MACs (forward): {macs:.2f}, Number of parameters: {params}")

    training_macs_forward = macs * len(train_loader) * num_epochs
    training_macs_backward = 2 * training_macs_forward
    total_training_macs = training_macs_forward + training_macs_backward
    print(f"Approx. total training MACs (forward+backward): {total_training_macs:.2f}")

    # Inference (test) MACs: #batches x single forward pass
    inference_macs = macs * len(test_loader)
    print(f"Approx. total test inference MACs: {inference_macs:.2f}")

    # --------------
    # Training Loop
    # --------------
    start_time = time.time()
    final_accuracy = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)  # Convert back to float
            labels = labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Evaluation after every epoch; the value of the last epoch is returned.
        epoch_preds, epoch_labels = _collect_predictions(model, test_loader, device)
        correct = int((np.asarray(epoch_preds) == np.asarray(epoch_labels)).sum())
        final_accuracy = 100.0 * correct / len(epoch_labels)
        if system_params["verbose"]:
            print(f'Accuracy on test set: {final_accuracy:.2f}%\n')

    time_taken = time.time() - start_time

    # ----------------------------
    # Confusion matrix (verbose only)
    # ----------------------------
    if system_params["verbose"]:
        all_preds, all_labels = _collect_predictions(model, test_loader, device)
        print_confusion_matrix(all_labels, all_preds)

    return final_accuracy, time_taken

# Main

In [None]:
# Default experiment configuration; the sweeps below override individual entries.
baseline_hyperparameters = dict(
    dataset="DBLP",        # DBLP
    graph_type="dynamic",  # static, dynamic
    time_steps=20,         # SNN steps simulated per encoded snapshot
    tau=1.0,
    alpha=1.0,
    surrogate="triangle",
    act="LIF",             # IF, LIF, PLIF
    a=0.5,                 # a+b=1
    b=0.5,                 # a+b=1
    prune_param=None,      # Float or None
    model="MLP",           # LSTM, MLP
    threshold=0.5,
)

# Pristine snapshot used to restore the baseline after each sweep.
baseline_hyperparameters_copy = dict(baseline_hyperparameters)

# Run settings that are not model hyperparameters.
system_params = dict(
    batch_size=64,
    num_epochs=20,
    verbose=False,
    test_memory=True,
    save_tensor=True,
)

In [None]:
def main():
  """Run the pipeline once with the baseline configuration and print accuracy and training time."""
  accuracy, elapsed = get_results(baseline_hyperparameters, system_params)
  print(f"Accuracy: {accuracy:.2f}%, Time taken: {elapsed:.2f} seconds")

main()

## Tests
- Static vs Dynamic graphs
- MLP vs LSTM
- tau values
- Number of time steps for static graph
- a and b values
- Prune param

In [None]:
# Static and dynamic graphs for MLP and LSTM
graph_types = ["static", "dynamic"]
models = ["MLP", "LSTM"]
# Every (graph type, model) pair, in the order the runs are executed and reported.
settings = [(graph_type, model) for graph_type in graph_types for model in models]
acc_list = []
time_list = []
for graph_type, model in settings:
    print(f"Currently testing: Graph type: {graph_type}, Model: {model}")
    # Override a per-run copy so the shared baseline is never left half-modified
    # if a run raises.
    run_hyperparameters = {**baseline_hyperparameters, "graph_type": graph_type, "model": model}
    acc, time_taken = get_results(run_hyperparameters, system_params)
    acc_list.append(acc)
    time_list.append(time_taken)

for (graph_type, model), acc, time_taken in zip(settings, acc_list, time_list):
    print(f"Graph type: {graph_type}, Model: {model}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

# Restore the pristine baseline for the following cells.
baseline_hyperparameters = baseline_hyperparameters_copy.copy()

In [None]:
# Number of time steps and tau
time_steps_list = [10, 20, 30, 40, 50, 60]
tau_list = [1.0, 2.0, 5.0, 10.0]
# Every (time_steps, tau) pair, in the order the runs are executed and reported.
settings = [(time_steps, tau) for time_steps in time_steps_list for tau in tau_list]
acc_list = []
time_list = []

for time_steps, tau in settings:
    print(f"Currently testing: Time steps: {time_steps}, Tau: {tau}")
    # Override a per-run copy so the shared baseline is never left half-modified
    # if a run raises.
    run_hyperparameters = {**baseline_hyperparameters, "time_steps": time_steps, "tau": tau}
    acc, time_taken = get_results(run_hyperparameters, system_params)
    acc_list.append(acc)
    time_list.append(time_taken)

for (time_steps, tau), acc, time_taken in zip(settings, acc_list, time_list):
    print(f"Time steps: {time_steps}, Tau: {tau}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

# Restore the pristine baseline for the following cells.
baseline_hyperparameters = baseline_hyperparameters_copy.copy()

In [None]:
# a and b values and prune param
a_list = [0.1, 0.3, 0.5, 0.7, 0.9]
prune_list = [None, 0.6, 0.8, 1.0]
# Every (a, prune) pair, in the order the runs are executed and reported; b is always 1 - a.
settings = [(a, prune) for a in a_list for prune in prune_list]
acc_list = []
time_list = []

for a, prune in settings:
    print(f"Currently testing: a: {a}, b: {1-a}, Prune: {prune}")
    # Override a per-run copy so the shared baseline is never left half-modified
    # if a run raises.
    run_hyperparameters = {**baseline_hyperparameters, "a": a, "b": 1 - a, "prune_param": prune}
    acc, time_taken = get_results(run_hyperparameters, system_params)
    acc_list.append(acc)
    time_list.append(time_taken)

for (a, prune), acc, time_taken in zip(settings, acc_list, time_list):
    print(f"a: {a}, b: {1-a}, Prune: {prune}, Accuracy: {acc:.2f}%, Time: {time_taken:.2f} seconds")

# Restore the pristine baseline for the following cells.
baseline_hyperparameters = baseline_hyperparameters_copy.copy()

# Investigating the relationship between accuracy and time stamp

- Generate spike train (of length 10) for each time stamp
- Train and test a separate MLP on the spike train of each time stamp

In [None]:
# Configuration for the per-time-stamp experiment below.
hp = {
    "dataset": "DBLP", # DBLP
    "graph_type": "dynamic", # static, dynamic
    "time_steps": 10, # SNN steps simulated per encoded snapshot
    "tau": 1.0,
    "alpha": 1.0,
    "surrogate": "triangle",
    "act": "LIF", # IF, LIF, PLIF
    "a": 0.5, # a+b=1
    "b": 0.5, # a+b=1
    "prune_param": None, # Float or None
    "model": "LSTM", # LSTM, MLP
    # The dynamic generator reads hp["threshold"]; without the key it raises a
    # KeyError. 0.0 masks no node (the per-node change is never negative), so every
    # snapshot is encoded from its own features. The baseline config uses 0.5.
    "threshold": 0.0
}

system_params = {
  "batch_size": 64,
  "num_epochs": 20,
  "verbose": False,
  "test_memory": True,
  "save_tensor": True
}

In [None]:
# Load the dataset and count its snapshots.
data = dataset.DBLP()
num_time_stamps = len(data.adj) # number of snapshots

# Not consumed by generate_spike_train(), which builds its own neuron from `hp`.
snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
spike_train_all = []

# Encode the data according to `hp` (graph_type is "dynamic" above, so every
# snapshot is encoded, not just the last one).
spike_train = generate_spike_train(data, hp)
print(spike_train.shape)

In [None]:
# Split the leading axis back into (snapshot, SNN time step); the last two axes are kept.
spike_train = spike_train.view(num_time_stamps, hp["time_steps"], *spike_train.shape[-2:])
print(spike_train.shape)

In [None]:
# Train a MLP on the different time stamps to see the accuracy over time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Settings that do not change between time stamps are set up once.
y = data.y
batch_size = system_params["batch_size"]
num_epochs = system_params["num_epochs"]
criterion = nn.CrossEntropyLoss()
acc_list = []

for t in range(num_time_stamps):
    # [time, samples, features] -> [samples, time * features]
    cur_spike_train = spike_train[t].permute(1, 0, 2)
    cur_spike_train = cur_spike_train.reshape(cur_spike_train.shape[0], -1)

    # Same seed and stratification for every time stamp.
    X_train, X_test, y_train, y_test = train_test_split(cur_spike_train, y,
                                                        test_size=0.2,
                                                        random_state=42,
                                                        stratify=y)

    train_dataset = SpikeTrainDataset(X_train, y_train)
    test_dataset = SpikeTrainDataset(X_test, y_test)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # A fresh model and optimizer per time stamp, so nothing carries over between runs.
    model = MLPClassifier(input_size=cur_spike_train.shape[-1],
                          hidden_size=256,
                          num_classes=data.num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    final_accuracy = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)  # Convert back to float
            labels = labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.float().to(device)
                labels = labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100.0 * correct / total
        final_accuracy = accuracy
        if system_params["verbose"]:
            print(f'Accuracy on test set: {accuracy:.2f}%\n')

    # Only the accuracy after the last epoch is recorded for this time stamp.
    acc_list.append(final_accuracy)

In [None]:
# Plot the final test accuracy of each time stamp.
# The y axis is fixed to the full 0-100 percent range.
time_stamp_axis = range(num_time_stamps)
plt.plot(time_stamp_axis, acc_list)
plt.ylim(0, 100)
plt.title("Accuracy over Time")
plt.xlabel("Time Stamp")
plt.ylabel("Accuracy")

plt.show()