In [44]:
%pip install torch torchvision matplotlib math

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


ERROR: Could not find a version that satisfies the requirement math (from versions: none)
ERROR: No matching distribution found for math


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import math
import matplotlib.pyplot as plt

In [None]:
# Hyperparameters

# -- Spike encoding window --
total_time = 100                    # length of every spike train, in milliseconds
timestep = 1                        # encoding resolution, in milliseconds
num_steps = total_time // timestep  # number of discrete steps per spike train
max_firing_rate = 200               # maximum firing rate in Hz

# -- Network size and training length --
hidden_size = 100                   # number of neurons in the hidden layer
epochs = 20                         # number of training epochs

# -- Leaky integrate-and-fire neuron constants --
tau = 45.0                          # membrane time constant in ms
dt = 1.0                            # simulation time step in ms
v_rest = 0.0                        # resting potential
v_reset = -0.5                      # potential assigned right after a spike
v_th = 1.1                          # threshold potential for spiking

# Load the MNIST training split (downloaded into the current directory if missing)
# and show the shape of one transformed sample as a sanity check.
mnist_train = datasets.MNIST(root='.', train=True, download=True, transform=transforms.ToTensor())
print (mnist_train[0][0].shape)

# `train_ds` is the raw image tensor converted to float and divided by 255;
# the ToTensor transform above is not involved because `.data` is read directly.
train_ds = mnist_train.data.float() / 255.0

def poisson_rate_encoding(train_ds, num_steps, max_firing_rate, num_samples=None,
                          timestep=None, batch_size=1000):
    """Rate-encode images as random spike trains.

    Every pixel acts as an independent input neuron. At each time step it
    spikes with probability ``pixel * max_firing_rate / 1000 * timestep``,
    decided by comparing that probability with a fresh uniform sample.

    Args:
        train_ds: tensor whose first dimension indexes the images; all
            remaining dimensions are flattened into one pixel axis.
            NOTE(review): intensities are presumably in [0, 1] — confirm
            against the caller.
        num_steps: number of time steps to generate per image.
        max_firing_rate: firing rate in Hz assigned to an intensity of 1.0.
        num_samples: how many images to encode. ``None`` (or any value at
            least as large as the dataset) encodes every image in order;
            a smaller value encodes a random subset without replacement.
        timestep: step length in milliseconds. When omitted, a module-level
            ``timestep`` is used if one exists, otherwise 1.0.
        batch_size: number of images encoded per vectorised draw; bounds the
            size of the temporary random tensor.

    Returns:
        ``(spike_train, indices)`` where ``spike_train`` is a float tensor of
        zeros and ones shaped ``[num_samples, num_pixels, num_steps]`` and
        ``indices`` holds, for each row, the position of the source image in
        ``train_ds``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if timestep is None:
        # Keeps the notebook-level setting authoritative while still letting
        # the function run on its own.
        timestep = globals().get('timestep', 1.0)

    total = train_ds.shape[0]
    if num_samples is None:
        num_samples = total
    # Clamp so the output never contains rows without a matching index
    # (a request larger than the dataset used to yield all-zero padding rows).
    num_samples = max(0, min(int(num_samples), total))

    # Sample random indices if using a subset, otherwise use all indices
    if num_samples < total:
        indices = torch.randperm(total)[:num_samples]
    else:
        indices = torch.arange(total)

    # Flatten every image once; reshape (unlike view) also accepts
    # non-contiguous input.
    num_pixels = math.prod(train_ds.shape[1:])
    flat_images = train_ds.reshape(total, num_pixels)

    # The full result is materialised in memory: num_samples * num_pixels *
    # num_steps values. Pass num_samples to limit it if that is too large.
    spike_train = torch.zeros((num_samples, num_pixels, num_steps))

    for batch_start in range(0, num_samples, batch_size):
        batch_end = min(batch_start + batch_size, num_samples)
        batch_indices = indices[batch_start:batch_end]

        # Per-pixel spike probability for one step, shape [batch, num_pixels]
        probability = flat_images[batch_indices] * max_firing_rate / 1000 * timestep

        # One uniform draw per (image, pixel, step) replaces the per-image,
        # per-step Python loops.
        random_values = torch.rand(len(batch_indices), num_pixels, num_steps)
        spike_train[batch_start:batch_end] = (random_values < probability.unsqueeze(-1)).float()

        print(f"Processed {batch_end}/{num_samples} images")

    return spike_train, indices

# Encode the whole training set (num_samples is left at its default of None).
spike_train, indices = poisson_rate_encoding(
    train_ds,
    num_steps=num_steps,
    max_firing_rate=max_firing_rate,
)


# # Visualize spike patterns for several samples (e.g., first 10) and show their labels and images
# num_samples_to_plot = 10
# fig, axes = plt.subplots(num_samples_to_plot, 2, figsize=(12, 2 * num_samples_to_plot), gridspec_kw={'width_ratios': [3, 1]}, sharex='col')

# mnist_targets = datasets.MNIST(root='.', train=True, download=True).targets

# for i in range(num_samples_to_plot):
#     # Raster plot
#     spike_sample = spike_train[i]
#     neuron_idx, time_idx = torch.nonzero(spike_sample, as_tuple=True)
#     axes[i, 0].scatter(time_idx.numpy(), neuron_idx.numpy(), s=1, color='black')
#     label = mnist_targets[indices[i]].item()
#     axes[i, 0].set_ylabel('Neuron idx')
#     axes[i, 0].set_title(f'Spike raster plot for sample {i} (Label: {label})')
#     # Image plot
#     img = train_ds[indices[i]].view(28, 28)
#     axes[i, 1].imshow(img.numpy(), cmap='gray')
#     axes[i, 1].axis('off')
#     axes[i, 1].set_title('Input image')

# axes[-1, 0].set_xlabel('Time step')
# plt.tight_layout()
# plt.show()


class SurrogateSpike(torch.autograd.Function):
    """Step-shaped spike function with a smooth surrogate gradient.

    Forward: returns 1.0 where ``input >= THRESHOLD`` and 0.0 elsewhere.
    Backward: scales the incoming gradient by
    ``1 / (1 + SLOPE * |input - THRESHOLD|) ** 2`` in place of the step's
    derivative, so gradients can flow through spiking units.
    """

    THRESHOLD = 1.0  # input value at which a spike is emitted
    SLOPE = 15       # steepness of the surrogate gradient around THRESHOLD

    @staticmethod
    def forward(ctx, input):
        # The raw input is needed again in backward to evaluate the surrogate.
        ctx.save_for_backward(input)
        return (input >= SurrogateSpike.THRESHOLD).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        distance = torch.abs(input - SurrogateSpike.THRESHOLD)
        surrogate_grad = 1 / (1 + SurrogateSpike.SLOPE * distance) ** 2
        return grad_output * surrogate_grad


class LIFLayerSG(nn.Module):
    """Fully connected layer of leaky integrate-and-fire neurons.

    Args:
        in_features: number of input spike channels.
        out_features: number of neurons in the layer.
        tau: membrane time constant; together with ``dt`` it sets the
            per-step decay factor ``exp(-dt / tau)``.
        dt: simulation step, in the same unit as ``tau``.
        v_rest: potential the membrane decays towards and starts from.
        v_reset: potential assigned to a neuron right after it spikes.
        v_th: potential at which a neuron spikes.
    """

    def __init__(self, in_features, out_features, tau=20.0, dt=1.0, v_rest=0.0, v_reset=0.0, v_th=1.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.decay = math.exp(-dt / tau)
        self.v_rest, self.v_reset, self.v_th = v_rest, v_reset, v_th

    def forward(self, spikes):
        """Simulate the layer over time.

        Args:
            spikes: tensor shaped ``[batch, in_features, time]``.

        Returns:
            Tensor of zeros and ones shaped ``[batch, out_features, time]``.
        """
        batch_size, _, num_steps = spikes.shape
        num_neurons = self.W.shape[0]

        if num_steps == 0:
            # torch.stack cannot build a tensor from an empty list.
            return torch.zeros((batch_size, num_neurons, 0), dtype=self.W.dtype, device=spikes.device)

        # Start every membrane at rest (identical to zeros when v_rest == 0).
        v = torch.full((batch_size, num_neurons), float(self.v_rest),
                       dtype=self.W.dtype, device=spikes.device)
        spike_history = []

        for t in range(num_steps):
            I_t = torch.matmul(spikes[:, :, t], self.W.t())
            # Leak towards v_rest, then add this step's input current.
            v = (v - self.v_rest) * self.decay + I_t + self.v_rest
            # SurrogateSpike fires at input >= SurrogateSpike.THRESHOLD, so the
            # potential is shifted to make that coincide with v >= v_th.
            # Passing plain (v - v_th) would require v >= v_th + THRESHOLD.
            z = SurrogateSpike.apply(v - self.v_th + SurrogateSpike.THRESHOLD)
            # Neurons that spiked are reset; the others keep their potential.
            v = torch.where(z.bool(), torch.full_like(v, self.v_reset), v)
            spike_history.append(z)

        return torch.stack(spike_history, dim=2)

class SNNModelSG(nn.Module):
    """Two stacked LIF layers whose output is the time-averaged spike activity.

    Any extra keyword arguments are forwarded unchanged to both LIFLayerSG
    instances, so the two layers always share the same neuron settings.
    """

    def __init__(self, input_size, hidden_size, output_size, **lif_kwargs):
        super().__init__()
        self.lif1 = LIFLayerSG(input_size, hidden_size, **lif_kwargs)
        self.lif2 = LIFLayerSG(hidden_size, output_size, **lif_kwargs)

    def forward(self, x):
        """Run a batch of spike images through both layers.

        Args:
            x: 4-D tensor unpacked as (batch, height, width, time).

        Returns:
            Tensor shaped [batch, output_size]: the mean over the time axis
            of the second layer's spikes.
        """
        batch, height, width, steps = x.shape
        # Merge the two spatial axes into a single feature axis.
        flat_spikes = x.view(batch, height * width, steps)
        output_spikes = self.lif2(self.lif1(flat_spikes))
        return output_spikes.mean(dim=2)

# Targets for training: reload the MNIST labels and pick the entries that
# correspond to the encoded images, in the same order as the rows of spike_train.
mnist_targets = datasets.MNIST(root='.', train=True, download=True).targets
labels = mnist_targets[indices]

# Training preparation with mini-batches
def train_model(spike_data, labels, model, criterion, optimizer, epochs=10, batch_size=64):
    """Train `model` with shuffled mini-batches and report per-epoch metrics.

    Args:
        spike_data: input tensor; its first dimension indexes the samples.
        labels: target tensor indexed the same way as `spike_data`.
        model: module called on each mini-batch of `spike_data`.
        criterion: callable producing the loss from (outputs, targets).
        optimizer: optimizer whose zero_grad()/step() are invoked per batch.
        epochs: number of passes over the data.
        batch_size: maximum number of samples per mini-batch.

    Returns:
        (epoch_losses, epoch_accuracies): two lists with one entry per epoch,
        the sample-weighted mean loss and the fraction of correct predictions.
        Both are measured on the training batches while the weights change.
    """
    model.train()
    dataset_size = spike_data.shape[0]

    epoch_losses, epoch_accuracies = [], []

    for epoch in range(epochs):
        # A fresh random visiting order for every epoch
        shuffled = torch.randperm(dataset_size)
        loss_sum = 0.0
        correct_sum = 0.0

        for start in range(0, dataset_size, batch_size):
            # Slicing past the end is clamped, so the last batch may be smaller
            picked = shuffled[start:start + batch_size]
            inputs = spike_data[picked]
            targets = labels[picked]

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so the epoch value is a true
            # per-sample average even when the final batch is short.
            loss_sum += loss.item() * len(picked)
            predicted = torch.max(outputs.data, 1).indices
            correct_sum += (predicted == targets).sum().item()

        epoch_loss = loss_sum / dataset_size
        epoch_acc = correct_sum / dataset_size

        epoch_losses.append(epoch_loss)
        epoch_accuracies.append(epoch_acc)

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.6f}, Accuracy: {epoch_acc*100:.4f}%")

    return epoch_losses, epoch_accuracies

# Define mini-batch size
mini_batch_size = 64

# Create the model, criterion and optimizer.
# The LIF constants from the hyperparameter section are forwarded explicitly:
# SNNModelSG passes extra keyword arguments on to its LIFLayerSG layers, and
# without them the layers fall back to their own defaults and the configured
# tau / dt / v_rest / v_reset / v_th have no effect on training.
model = SNNModelSG(
    input_size=28*28,
    hidden_size=hidden_size,
    output_size=10,
    tau=tau,
    dt=dt,
    v_rest=v_rest,
    v_reset=v_reset,
    v_th=v_th,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# spike_train: [B, F, T] → reshape to [B, H, W, T]
total_samples = spike_train.shape[0]
spike_data = spike_train.view(total_samples, 28, 28, -1)

# Train with mini-batches
losses, accuracies = train_model(spike_data, labels, model, criterion, optimizer, 
                                epochs=epochs, batch_size=mini_batch_size)

# Plot training progress: loss on the left panel, accuracy on the right
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(losses)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

acc_ax.plot(accuracies)
acc_ax.set_title('Training Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')

fig.tight_layout()
plt.show()

torch.Size([1, 28, 28])
Processed 1000/60000 images
Processed 2000/60000 images
Processed 3000/60000 images
Processed 4000/60000 images
Processed 5000/60000 images
Processed 6000/60000 images
Processed 7000/60000 images
Processed 8000/60000 images
Processed 9000/60000 images
Processed 10000/60000 images
Processed 11000/60000 images
Processed 12000/60000 images
Processed 13000/60000 images
Processed 14000/60000 images
Processed 15000/60000 images
Processed 16000/60000 images
Processed 17000/60000 images
Processed 18000/60000 images
Processed 19000/60000 images
Processed 20000/60000 images
Processed 21000/60000 images
Processed 22000/60000 images
Processed 23000/60000 images
Processed 24000/60000 images
Processed 25000/60000 images
Processed 26000/60000 images
Processed 27000/60000 images
Processed 28000/60000 images
Processed 29000/60000 images
Processed 30000/60000 images
Processed 31000/60000 images
Processed 32000/60000 images
Processed 33000/60000 images
Processed 34000/60000 images