In [58]:
import torch
import torch.nn as nn
import snntorch as snn

To transform a Graph Convolutional Network (GCN) into a spiking version, we need to replace the traditional neural network layers with their spiking counterparts. 

This process involves introducing temporal dynamics and modeling the information propagation using spike trains instead of static activations.
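
One common way to turn a static activation into a spike train is rate coding: at every time step the neuron spikes with a probability equal to the (normalized) activation, so the average firing rate over the window recovers the value. The sketch below is plain Python with a hypothetical `rate_encode` helper, not an snntorch function. snntorch's `spikegen.rate`, used later in the dummy example, applies the same Bernoulli idea to tensors whose values lie in [0, 1].

```python
import random


def rate_encode(value, num_steps, rng):
    """Bernoulli rate code: spike at each step with probability `value`, clipped to [0, 1]."""
    p = min(max(value, 0.0), 1.0)
    return [1 if rng.random() < p else 0 for _ in range(num_steps)]


rng = random.Random(0)  # fixed seed so the example is reproducible
activation = 0.3        # a static activation, e.g. one normalized node feature
spikes = rate_encode(activation, num_steps=1000, rng=rng)

# The firing rate (spikes per step) over the window approximates the activation
firing_rate = sum(spikes) / len(spikes)
print(firing_rate)  # close to 0.3
```

Longer windows give a more precise rate at the cost of more simulation steps.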

# LIF: Leaky Integrate-and-Fire Neuron

**Replace Traditional Neurons with Spiking Neurons**:

The first step is to replace the traditional neurons in the GCN layers with spiking neurons. Spiking neurons model the behavior of biological neurons by generating and transmitting spike trains over time, rather than processing static activations.

There are various spiking neuron models, such as the Leaky Integrate-and-Fire (LIF), Izhikevich, and Hodgkin-Huxley models. For this example, we'll use the LIF model, a widely used and computationally efficient choice. Specifically, we use snntorch's `Synaptic` neuron, a LIF variant that tracks a synaptic current in addition to the membrane potential.
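
To make the basic LIF behaviour concrete, here is a framework-free sketch of a first-order LIF neuron driven by a constant input. It assumes a decay factor `beta = 0.9`, a threshold of 1.0, and reset to zero after a spike; `lif_spike_count` is a hypothetical helper, not part of snntorch.

```python
def lif_spike_count(current, num_steps=100, beta=0.9, threshold=1.0):
    """Count the spikes of a first-order LIF neuron driven by a constant input current."""
    mem = 0.0
    spikes = 0
    for _ in range(num_steps):
        mem = beta * mem + current  # leak part of the potential, then integrate the input
        if mem > threshold:         # fire once the threshold is crossed ...
            spikes += 1
            mem = 0.0               # ... and reset to the resting potential
    return spikes


for current in (0.1, 0.3, 0.5):
    print(current, lif_spike_count(current))
# 0.1 0
# 0.3 25
# 0.5 33
```

An input of 0.1 has a steady state of 0.1 / (1 - 0.9) = 1.0, which is approached but never exceeded, so the neuron stays silent. Stronger inputs cross the threshold sooner and fire more often, which is how a LIF neuron turns input magnitude into a firing rate.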

In [59]:
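import math

# Time constants, in simulation steps (dt = 1 step, e.g. 1 ms)
TAU_MEM = 20.0  # Membrane time constant
TAU_SYN = 5.0  # Synaptic time constant

# snn.Synaptic takes per-step decay rates rather than time constants:
# alpha = exp(-dt / tau_syn) for the synaptic current, beta = exp(-dt / tau_mem) for the membrane
ALPHA = math.exp(-1.0 / TAU_SYN)  # ~0.82
BETA = math.exp(-1.0 / TAU_MEM)  # ~0.95


class LIFNeuron(nn.Module):
    def __init__(self, threshold=1.0):
        super().__init__()
        self.tau_mem = TAU_MEM
        self.tau_syn = TAU_SYN
        # Synaptic LIF neuron: tracks a synaptic current and a membrane potential
        self.lif = snn.Synaptic(alpha=ALPHA, beta=BETA, threshold=threshold)

    def forward(self, x, syn=None, mem=None):
        # Start from rest (zero current and potential) if no previous state is given
        if syn is None:
            syn = torch.zeros_like(x)
        if mem is None:
            mem = torch.zeros_like(x)
        # Advance the neuron by one time step using the input current x
        spk, syn, mem = self.lif(x, syn, mem)
        return spk, syn, mem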

### Initialization:

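- **tau_mem** (membrane time constant): This parameter determines the rate at which the membrane potential decays over time. `snn.Synaptic` does not take time constants directly; it takes the per-step decay factor `beta = exp(-dt / tau_mem)`, with `dt` equal to one time step. A higher `tau_mem` gives a `beta` closer to 1, so the membrane potential decays more slowly.

- **tau_syn** (synaptic time constant): This parameter determines the rate at which the synaptic current decays over time, through the decay factor `alpha = exp(-dt / tau_syn)`. A higher value means the synaptic current decays more slowly.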
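
A time constant translates into a per-step decay factor `exp(-dt / tau)`, which is the form of the `alpha` and `beta` decay rates that snntorch's `Synaptic` neuron expects (the exponential mapping is the standard discretization and an assumption here). The plain-Python sketch below, with a hypothetical `decay_factor` helper and `dt` = 1 step, computes both factors and checks that after `tau` steps without input a state has decayed to about 37% of its starting value.

```python
import math


def decay_factor(tau, dt=1.0):
    """Per-step decay factor exp(-dt / tau) for a state with time constant tau."""
    return math.exp(-dt / tau)


beta = decay_factor(20.0)   # membrane potential, tau_mem = 20 steps
alpha = decay_factor(5.0)   # synaptic current, tau_syn = 5 steps
print(round(beta, 3), round(alpha, 3))  # 0.951 0.819

# Without input, a state shrinks by its factor every step, so after tau steps
# only exp(-1) (about 37%) of it is left
print(round(beta ** 20, 3), round(alpha ** 5, 3))  # 0.368 0.368
```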

### Forward Pass:

- **x**: The input current for the current time step (in the spiking GCN, the aggregated graph-convolution output).

- **syn, mem**: The synaptic current and membrane potential from the previous time step. If they are `None`, the neuron starts from rest (both zero).

- **spk, syn, mem = self.lif(x, syn, mem)**:

    `self.lif` is an `snn.Synaptic` neuron. Calling it advances the neuron by one time step and returns the output spikes (`spk`) together with the updated synaptic current (`syn`) and membrane potential (`mem`), computed from the input `x` and the previous state.

    The dynamics of this LIF neuron can be summarized as follows:

    - The synaptic current (`syn`) decays by the factor `alpha` (set by `tau_syn`) and adds the input `x`.

    - The membrane potential (`mem`) decays by the factor `beta` (set by `tau_mem`) and integrates the synaptic current.

    - If the membrane potential exceeds the threshold (1.0 by default), the neuron emits a spike (`spk = 1`) and the membrane potential is reset. snntorch's default reset subtracts the threshold from the membrane potential; `reset_mechanism="zero"` resets it to zero instead.

- **return spk, syn, mem**: The spikes and the updated state are returned. The caller passes `syn` and `mem` back in at the next time step, which is how the neuron remembers its past inputs.

The LIFNeuron class implements the Leaky Integrate-and-Fire neuron model, which is a spiking neuron model that integrates incoming synaptic currents over time and generates output spikes when the membrane potential exceeds a threshold.

At each time step, the class takes the input current and computes the output spikes together with the updated membrane potential and synaptic current, based on the LIF dynamics and the specified time constants (tau_mem and tau_syn).

The membrane potential and synaptic current represent the internal state variables of the LIF neuron, and they are updated at each time step based on the input and the previous state variables. 

These state variables are then used to determine whether the neuron should generate an output spike or not.
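
The update rule described above can be written out step by step. The plain-Python sketch below mirrors the equations in snntorch's `Synaptic` documentation, `syn[t+1] = alpha * syn[t] + x[t+1]` and `mem[t+1] = beta * mem[t] + syn[t+1] - reset * threshold`, where `reset` is 1 if the neuron fired on the previous step. `synaptic_step` is a hypothetical helper, not the library's code, and it uses the decay factors implied by `tau_syn = 5` and `tau_mem = 20` steps.

```python
import math

# Decay factors implied by tau_syn = 5 and tau_mem = 20 time steps
ALPHA = math.exp(-1.0 / 5.0)   # synaptic current decay (~0.82)
BETA = math.exp(-1.0 / 20.0)   # membrane potential decay (~0.95)


def synaptic_step(x, syn, mem, threshold=1.0):
    """Advance a two-state LIF neuron by one step (reset by subtraction)."""
    reset = 1.0 if mem > threshold else 0.0      # did it fire on the previous step?
    syn = ALPHA * syn + x                        # leak the current, add the new input
    mem = BETA * mem + syn - reset * threshold   # leak the potential, integrate, reset
    spk = 1 if mem > threshold else 0            # fire if above threshold
    return spk, syn, mem


# 1) A single weak input pulse at step 0: stays below threshold, so no spikes
syn = mem = 0.0
trace = []
for t in range(40):
    _, syn, mem = synaptic_step(0.2 if t == 0 else 0.0, syn, mem)
    trace.append(mem)
peak_step = max(range(len(trace)), key=trace.__getitem__)
print("membrane peaks at step", peak_step)

# 2) The same input sustained over 100 steps: the neuron fires repeatedly
syn = mem = 0.0
spike_count = 0
for t in range(100):
    spk, syn, mem = synaptic_step(0.2, syn, mem)
    spike_count += spk
print("spikes under sustained input:", spike_count)
```

Because the pulse first charges the synaptic current, which then leaks into the membrane, the membrane potential keeps rising for several steps after the input has stopped. The synaptic current thus acts as a low-pass filter that smooths incoming spikes over time.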

# Spiking Graph Convolution

**Modify the Graph Convolution Layer:**

The traditional GraphConvolution layer needs to be modified to handle spiking inputs and outputs. 

Instead of returning a static activation, the spiking layer feeds the result of the graph convolution into spiking neurons at every time step; the neurons integrate this input over time and emit output spike trains according to their dynamics.

The SpikingGraphConvolution layer essentially performs a spiking version of the graph convolution operation, where information is propagated and integrated across the graph structure using spiking neuron dynamics instead of static activations.

In [60]:
class SpikingGraphConvolution(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()
        self.neuron = LIFNeuron()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj, syn=None, mem=None):
        # x: (num_nodes, in_features) input for ONE time step
        # adj: (num_nodes, num_nodes) adjacency matrix
        support = torch.mm(x, self.weight)  # Linear transformation
        output = torch.mm(adj, support)  # Graph convolution: aggregate neighbour features
        # Spiking neuron: integrate the aggregated input for one time step
        spk, syn, mem = self.neuron(output + self.bias, syn, mem)
        return spk, syn, mem

The key difference from a standard graph convolution is the last step: instead of adding the bias and returning the result (typically followed by a ReLU), the SpikingGraphConvolution layer feeds the propagated features (`output + self.bias`) into a spiking neuron model (`LIFNeuron`), which integrates them over time.

The layer is called once per time step. `LIFNeuron` returns the output spikes (`spk`) together with the updated synaptic current (`syn`) and membrane potential (`mem`); the caller passes this state back in at the next time step, so the neurons accumulate input across steps and fire when their membrane potential crosses the threshold.

By using spiking neuron dynamics, the SpikingGraphConvolution layer introduces temporal dynamics and propagates information as spike trains instead of static activations. This is more biologically plausible and potentially more energy-efficient on suitable (e.g. neuromorphic) hardware: when the inputs are binary spikes, multiplying by the weights reduces to adding the weight rows of the inputs that fired, and silent inputs require no computation.

However, **the core idea of propagating and aggregating information across the graph structure remains the same**, with the **main difference being the use of spiking neuron dynamics instead of static activations** in the SpikingGraphConvolution layer.
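
The efficiency argument above rests on the inputs being binary. The plain-Python sketch below computes one time step of the layer's input current, `adj @ (spikes @ weight) + bias`, the same quantity as `torch.mm(adj, torch.mm(x, self.weight)) + self.bias`, but forms the spike-weight product with additions only. The 3-node graph, the weights, and the helper names (`matmul`, `spike_support`, `graph_conv_step`) are hypothetical; `matmul` is an ordinary matrix product included for comparison.

```python
def matmul(a, b):
    """Ordinary matrix product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def spike_support(spikes, weight):
    """spikes @ weight for a binary spike matrix, computed with additions only."""
    rows = []
    for node_spikes in spikes:
        acc = [0.0] * len(weight[0])
        for k, fired in enumerate(node_spikes):
            if fired:  # a silent input contributes nothing and is skipped
                acc = [a + w for a, w in zip(acc, weight[k])]
        rows.append(acc)
    return rows


def graph_conv_step(spikes, adj, weight, bias):
    """Input current for one time step: adj @ (spikes @ weight) + bias."""
    aggregated = matmul(adj, spike_support(spikes, weight))
    return [[v + b for v, b in zip(row, bias)] for row in aggregated]


# Hypothetical 3-node path graph 0-1-2 with self-loops, 2 input and 2 output features
adj = [[1, 1, 0],
       [1, 1, 1],
       [0, 1, 1]]
spikes = [[1, 0],   # node 0: only input feature 0 fired at this step
          [0, 0],   # node 1: silent
          [1, 1]]   # node 2: both input features fired
weight = [[0.5, -0.2],
          [0.1, 0.4]]
bias = [0.0, 0.1]

current = graph_conv_step(spikes, adj, weight, bias)
print(current)  # approximately [[0.5, -0.1], [1.1, 0.1], [0.6, 0.3]]
```

Node 1 fired nothing itself, yet it receives the largest current because it aggregates the spikes of both neighbours; this current is what the LIF neurons integrate at that step.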

# SpikingGCN

The SpikingGCN class defines the overall architecture of the Spiking Graph Convolutional Network (SGCN) model: two SpikingGraphConvolution layers with dropout between them, simulated step by step over `time_window` time steps.

In [61]:
class SpikingGCN(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features,
        out_features,
        dropout,
        time_window=100,
    ):
        super().__init__()
        self.gc1 = SpikingGraphConvolution(in_features, hidden_features)
        self.gc2 = SpikingGraphConvolution(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.time_window = time_window  # Number of simulation time steps (e.g. 1 ms each)

    def forward(self, x_spikes, adj):
        # x_spikes: (time_window, num_nodes, in_features) input spike trains
        # adj: (num_nodes, num_nodes) adjacency matrix, the same at every time step
        syn1 = mem1 = syn2 = mem2 = None  # Both layers start from rest
        spike_count = torch.zeros(adj.size(0), self.gc2.out_features, device=adj.device)
        for t in range(self.time_window):
            spk1, syn1, mem1 = self.gc1(x_spikes[t], adj, syn1, mem1)
            spk1 = self.dropout(spk1)  # Dropout on the hidden spikes (training mode only)
            spk2, syn2, mem2 = self.gc2(spk1, adj, syn2, mem2)
            spike_count = spike_count + spk2  # Accumulate output spikes
        # Spike rate: fraction of time steps on which each output neuron fired
        spike_rates2 = spike_count / self.time_window
        return spike_rates2

The SpikingGCN model follows the structure of the traditional GCN: two graph convolution layers with dropout regularization in between. The non-linearity is provided by the spiking neurons themselves, whose thresholded output is binary.

However, instead of using static activations, the SGCN model operates on spike trains: the snntorch `Synaptic` neurons inside each SpikingGraphConvolution layer carry their state from one time step to the next.

The key steps in the forward pass, repeated at every time step `t` of the simulation window, are:

1. Propagate the input spikes `x_spikes[t]` through the first SpikingGraphConvolution layer, which returns the hidden spikes and the hidden layer's updated synaptic current and membrane potential.
2. Apply dropout to the hidden spikes (active only in training mode).
3. Propagate the hidden spikes through the second SpikingGraphConvolution layer to obtain the output spikes and the output layer's updated state.
4. Add the output spikes to a running count.

After the last time step, the spike counts are divided by `time_window` to give spike rates, which are the final output of the SGCN model.

# Model Output

The actual output of the model is a tensor of spike rates, representing the predicted output for each node in the graph.

Specifically, the `forward` method returns `spike_rates2`: a tensor containing the spike rate of each node for each output feature (or class), measured over the simulation time window (`self.time_window` steps).

The shape of `spike_rates2` is `(num_nodes, out_features)`, where:

- `num_nodes` is the number of nodes in the input graph.
- `out_features` is the number of output features or classes, specified by the `out_features` parameter during the initialization of the `SpikingGCN` class.

Each element `spike_rates2[i, j]` is the spike rate (or firing rate) of node `i` for output feature (or class) `j` over the simulation time window.

In the context of a classification task, **the spike rates can be interpreted as the predicted "confidence" or "strength" of each node belonging to each class**. Higher spike rates for a particular class would indicate a stronger prediction for that class.
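
Turning output spikes into rates and predictions amounts to counting. The plain-Python sketch below uses a hypothetical 4-step recording `record[t][node][class]`: the rate of each output neuron is its spike count divided by the number of time steps, and each node is assigned the class with the highest rate. `spike_rates` and `predict` are hypothetical helpers.

```python
def spike_rates(record):
    """record[t][node][cls] in {0, 1} -> rates[node][cls] = spike count / number of steps."""
    num_steps = len(record)
    counts = [[0] * len(node) for node in record[0]]
    for step in record:
        for i, node in enumerate(step):
            for j, spike in enumerate(node):
                counts[i][j] += spike
    return [[c / num_steps for c in row] for row in counts]


def predict(rates):
    """Index of the class with the highest rate, for every node."""
    return [max(range(len(row)), key=row.__getitem__) for row in rates]


# Hypothetical output spikes of 2 nodes x 3 classes over 4 time steps
record = [
    [[1, 0, 0], [0, 1, 0]],
    [[1, 0, 1], [0, 1, 0]],
    [[0, 0, 0], [0, 1, 1]],
    [[1, 0, 0], [0, 0, 0]],
]
rates = spike_rates(record)
predictions = predict(rates)
print(rates)        # [[0.75, 0.0, 0.25], [0.0, 0.75, 0.25]]
print(predictions)  # [0, 1]
```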

Note that the spike rates are not class probabilities. Because a neuron fires at most once per time step, each rate lies between 0 and 1, but the rates of one node do not, in general, sum to 1 across classes. To turn them into a distribution over classes, apply a normalization.

For example, if you want to obtain class probabilities for a multi-class classification task, you could apply a softmax function to the spike rates:

```python
class_probabilities = nn.Softmax(dim=1)(spike_rates2)
```

Here `class_probabilities` would be a tensor of shape `(num_nodes, out_features)`, where each row sums to 1 and represents the probability distribution over the output classes for the corresponding node.
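
The sketch below applies softmax to one node's rates in plain Python and computes the cross-entropy that `nn.CrossEntropyLoss` evaluates for a single sample. That loss applies log-softmax internally, so during training it should receive the raw rates rather than probabilities. `softmax` and `cross_entropy` are hypothetical helpers. The sketch also shows a property of rate outputs: when every rate lies in [0, 1], two logits differ by at most 1, so with three classes the largest probability softmax can assign is e / (e + 2), about 0.58.

```python
import math


def softmax(scores):
    """Turn a list of scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, label):
    """-log(softmax(scores)[label]), the per-sample value nn.CrossEntropyLoss computes."""
    return -math.log(softmax(scores)[label])


rates = [0.75, 0.0, 0.25]  # hypothetical spike rates of one node over 3 classes
probs = softmax(rates)
print([round(p, 3) for p in probs])  # [0.481, 0.227, 0.292]
print(round(cross_entropy(rates, label=0), 3))

# Upper bound on the probability of the top class when all rates lie in [0, 1]
best_case = softmax([1.0, 0.0, 0.0])[0]
print(round(best_case, 3))  # 0.576
```

Softmax preserves the ranking of the rates, so the predicted class is unchanged; only the confidence is bounded.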


# Dummy example

In [62]:
torch.manual_seed(0)  # Make the random dummy data, spike trains and weight initialization reproducible

In [63]:
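# Create dummy data
num_nodes = 10
num_features = 5
num_classes = 3
time_window = 100  # Number of simulation time steps

# Node features in [0, 1], later used as per-step firing probabilities
node_features = torch.rand(num_nodes, num_features)

# Random sparse, undirected adjacency matrix with self-loops
adj = (torch.rand(num_nodes, num_nodes) < 0.2).float()
adj = ((adj + adj.T) > 0).float()  # Make it symmetric
adj.fill_diagonal_(1.0)  # Self-loops: each node also keeps its own features
# Symmetric normalization D^-1/2 A D^-1/2, as in the standard GCN propagation rule
deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
adj = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]

# Node labels (dummy data)
node_labels = torch.randint(0, num_classes, (num_nodes,))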

In [64]:
from snntorch import spikegen

# Rate-code the features: at every step, a feature with value p spikes with probability p
node_features_spike = spikegen.rate(node_features, num_steps=time_window)  # Shape: (time_window, num_nodes, num_features)

This is transductive node classification on a single graph, so the model is trained on the full graph at once (all nodes, the full adjacency matrix and all labels in every forward pass) instead of on mini-batches from a `DataLoader`.

In [66]:
model = SpikingGCN(
    in_features=num_features,
    hidden_features=8,
    out_features=num_classes,
    dropout=0.2,
    time_window=time_window,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [67]:
# Training loop: the whole graph is one sample, so each epoch is a single forward pass
num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()
    # spike_rates: (num_nodes, num_classes), measured over time_window steps
    spike_rates = model(node_features_spike, adj)
    # nn.CrossEntropyLoss applies log-softmax internally, so the rates are passed in directly
    loss = criterion(spike_rates, node_labels)
    loss.backward()  # Backpropagation through time via the surrogate gradient of the spikes
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

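Mini-batching this graph with a `DataLoader` breaks the graph convolution. `zip(node_features_spike, adj, node_labels)` pairs the i-th slice of the spike tensor with the i-th *row* of the adjacency matrix and the i-th label, so each batch stacks unrelated slices instead of passing the full graph. With 3-D `node_features` of shape `(num_nodes, num_features, time_window)` and `batch_size=2`, the first layer receives `x` of shape `torch.Size([2, 10, 5, 100])` and `adj` of shape `torch.Size([2, 10])` (its weight is `torch.Size([5, 8])`). Since `torch.mm` only accepts 2-D matrices, it raises `RuntimeError: self must be a matrix`.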