# Attempting to optimise the spike generation process

## Status quo

In [None]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from spikenet import dataset, neuron
import scipy.sparse as sp
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from typing import Dict, List

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [101]:
data = dataset.DBLP()

hp = {
    "dataset": "DBLP", # DBLP
    "graph_type": "dynamic", # static, dynamic
    "time_steps": 10, # Required for static graph
    "tau": 1.0,
    "alpha": 1.0,
    "surrogate": "triangle", 
    "act": "LIF", # IF, LIF, PLIF
    "a": 0.5, # a+b=1
    "b": 0.5, # a+b=1
    "prune_param": None, # Float or None
    "model": "LSTM", # LSTM, MLP
    "threshold": 0.5,
}

In [102]:
def get_DADx(adj, x, a=0.5, b=0.5):
  degree = np.array(adj.sum(1)).flatten()
  D_inv_a = np.power(degree, -a, where=degree!=0)
  D_inv_b = np.power(degree, -b, where=degree!=0)
  D_inv_a = sp.diags(D_inv_a)
  D_inv_b = sp.diags(D_inv_b)
  transformed_x = D_inv_a @ adj @ D_inv_b @ x
  return torch.FloatTensor(transformed_x)

def _generate_static_spike_train(data: dataset.Dataset, hp: Dict, snn) -> torch.Tensor:
  spike_train = []
  DADx = get_DADx(data.adj[-1], data.x[-1], a=hp["a"], b=hp["b"])
  for _ in range(hp["time_steps"]):
    spike_train.append(snn(DADx))
  return torch.stack(spike_train).to(torch.bool)

def _generate_dynamic_spike_train(data: dataset.Dataset, hp: Dict, snn) -> torch.Tensor:
    spike_train_all = []
    T = len(data.adj)
    DADx_prev = None
    spikes_prev = None
    for t in range(T):
        DADx_t = get_DADx(data.adj[t], data.x[t], a=hp["a"], b=hp["b"])
        if t==1:
           print(DADx_t)
        snn.reset()
        if DADx_prev is not None:
          delta = torch.abs(DADx_t - DADx_prev).max(dim=1)[0]
        spike_trains_this_snapshot = []
        for _ in range(hp["time_steps"]):
            spikes = snn(DADx_t)
            spike_trains_this_snapshot.append(spikes)
        spikes_t = torch.stack(spike_trains_this_snapshot)
        spike_train_all.append(spikes_t)
    spike_train_all = torch.stack(spike_train_all, dim=0)
    spike_train_all = spike_train_all.view(-1, spike_train_all.size(-2), spike_train_all.size(-1))
    spike_train_all = spike_train_all.to(torch.bool)

    return spike_train_all

def generate_spike_train(data: dataset.Dataset, hp: Dict) -> torch.Tensor:
  snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
  if hp["graph_type"]=="static":
    return _generate_static_spike_train(data, hp, snn)
  else:
    return _generate_dynamic_spike_train(data, hp, snn)
   

### Proving that once DADx is calculated, setting some nodes to zero does not affect the other nodes

This is important to investigate as it means that we can optimise our calculation of the spike train by only calculate `snn(DADx)` for nodes that are affected (changed beyond a certain percentage threshold etc.)

The below block of code shows that settings several values of DADx to 0 does not affect the spike train generated from the other values.


In [None]:
hp["graph_type"] = "static"
original_spikes = generate_spike_train(data, hp)

snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
spike_train = []
DADx = get_DADx(data.adj[-1], data.x[-1], a=hp["a"], b=hp["b"])
print(f'DADx shape: {DADx.shape}')
for _ in range(hp["time_steps"]):
  spike_train.append(snn(DADx))
original_st = torch.stack(spike_train).to(torch.bool)
print(f'Spike train shape: {original_st.shape}')

# Investgate if setting DADx to zero for some nodes will affect the output for the other nodes
# Set DADx to zero for some nodes

print(f'Original DADx: {DADx}')

modified_nodes = [0, 1, 2, 3, 4, 5]
for node in modified_nodes:
  DADx[node, :] = 0

print(f'Modified DADx: {DADx}')

snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
spike_train = []
for _ in range(hp["time_steps"]):
  spike_train.append(snn(DADx))
modified_st = torch.stack(spike_train).to(torch.bool)

# Check if the spike trains are equal for the original and modified DADx besides the affected nodes

# Iterate over all the spikes and check if they are equal
for i in range(original_st.shape[0]):
  for j in range(original_st.shape[1]):
    if j in modified_nodes:
      continue
    if not torch.equal(original_st[i][j], modified_st[i][j]):
      print(f'Spike trains are not equal at time step {i} and node {j}')

# ==================================================================================================

DADx = get_DADx(data.adj[-1], data.x[-1], a=hp["a"], b=hp["b"]) # Reset DADx
modified_nodes = [(0,0), (0,2), (0,4), (1,2), (1,4), (2,1), (2,3), (2,5)]
for x, y in modified_nodes:
  DADx[x][y] = 0

print(f'Modified DADx: {DADx}')

snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
spike_train = []
for _ in range(hp["time_steps"]):
  spike_train.append(snn(DADx))
modified_st = torch.stack(spike_train).to(torch.bool)

# Iterate over all the spikes and check if they are equal
for i in range(original_st.shape[0]):
  print(f'Time step {i}')
  for j in range(original_st.shape[1]):
    for k in range(original_st.shape[2]):
      if (j, k) in modified_nodes:
        continue
      if not torch.equal(original_st[i][j][k], modified_st[i][j][k]):
        print(f'Spike trains are not equal at time step {i} and node {j} and feautres {k}')

Using this code as a baseline, I want to optimize the spike generation process by avoiding regenerating spikes for time=t+1 if we already have the spikes at t. That means that most of the spikes are identical but the affected nodes. For example, when DADX changes a lot for some nodes, we can only update the spikes for certain dimensions.

In [103]:
hp["graph_type"] = "dynamic"
original_spikes = generate_spike_train(data, hp)

snn = neuron.LIF(tau=hp["tau"], alpha=hp["alpha"], surrogate=hp["surrogate"])
spike_train_all = []
T = len(data.adj)
DADx_prev = None
spikes_prev = None
threshold = hp["threshold"]
for t in range(T):
    print(f't: {t}')
    snn.reset()
    DADx_t = get_DADx(data.adj[t], data.x[t], a=hp["a"], b=hp["b"])
    if DADx_prev is not None:
        # delta shape: (num_nodes,), mask shape: (num_nodes,)
        delta = torch.abs(DADx_t - DADx_prev).max(dim=1)[0] # Get the max feature difference for each node
        mask = delta < threshold # Mask nodes whose change is insignificant
        print(f't: {t} Mask: {mask}')
        DADx_t[mask] = 0 # Essentially we just ignore the nodes by taking the previous spike and changing the DADx to zero
    spike_trains_this_snapshot = []
    for _ in range(hp["time_steps"]):
        spikes = snn(DADx_t)
        spike_trains_this_snapshot.append(spikes)
    spikes_t = torch.stack(spike_trains_this_snapshot)
    # Where Mask is true, copy the spikes from the previous snapshot
    # (time_steps, num_nodes, num_features)
    if spikes_prev is not None:
        for i in range(hp["time_steps"]):
            x = spikes_prev[i][mask]
            print(f"x shape: {x.shape} x sum: {x.sum()}")
            spikes_t[i][mask] = spikes_prev[i][mask]
    spike_train_all.append(spikes_t)
    DADx_prev = DADx_t
    spikes_prev = spikes_t
spike_train_all = torch.stack(spike_train_all, dim=0)
spike_train_all = spike_train_all.view(-1, spike_train_all.size(-2), spike_train_all.size(-1))
spike_train_all = spike_train_all.to(torch.bool)
original_st = spike_train_all

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
t: 0
t: 1
t: 1 Mask: tensor([True, True, True,  ..., True, True, True])
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
x shape: torch.Size([26797, 128]) x sum: 0.0
t: 2
t: 2 Mask: tensor([True, True, True,  ..., True, True, True])
x shape: torch.Size([26177, 128]) x sum: 0.0
x shape: torch.Size([26177, 128]) x sum: 0.0
x shape: torch.Size([26177, 128]) x sum: 0.0
x shape: torch.Size([2