# SNN vs ANN Comparison - Variant C Demo

This notebook demonstrates the implementation and training of a Spiking Neural Network (SNN) using PyTorch and snnTorch, comparing it with conventional ANNs.

## Setup Instructions
1. First, we'll install the required packages
2. Then we'll implement and train an SNN model
3. Finally, we'll evaluate and visualize the results

In [21]:
# Install required packages
!pip install torch torchvision snntorch matplotlib numpy


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [22]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as sf
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


## SNN Model Definition

We'll implement a spiking CNN using Leaky Integrate-and-Fire (LIF) neurons from snnTorch. The model consists of:
1. Two convolutional layers with LIF neurons
2. MaxPooling layers for dimensionality reduction
3. A fully connected output layer with LIF neurons

In [23]:
class SNNCNN(nn.Module):
    def __init__(self, num_inputs=784, num_hidden=256, num_outputs=10, beta=0.95):
        super().__init__()
        
        # Initialize with kaiming normal for better gradient flow
        self.conv1 = nn.Conv2d(1, 12, 5)
        nn.init.kaiming_normal_(self.conv1.weight)
        self.bn1 = nn.BatchNorm2d(12)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())
        
        self.conv2 = nn.Conv2d(12, 64, 5)
        nn.init.kaiming_normal_(self.conv2.weight)
        self.bn2 = nn.BatchNorm2d(64)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())
        
        # Calculate size after convolutions and pooling
        self.num_flat_features = 64 * 4 * 4
        
        self.fc1 = nn.Linear(self.num_flat_features, num_outputs)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())
        
        # Total number of neurons for AFR calculation
        self.num_neurons = (
            12 * 24 * 24 +  # After conv1
            64 * 8 * 8 +    # After conv2
            num_outputs     # Output layer
        )
        
        # Initialize membrane potentials
        self.reset_states()
    
    def reset_states(self):
        """Reset all membrane potentials."""
        self.mem1 = None
        self.mem2 = None
        self.mem3 = None
    
    def forward(self, x):
        batch_size = x.size(1)
        
        # Initialize hidden states if None
        if self.mem1 is None:
            self.mem1 = self.lif1.init_leaky()
        if self.mem2 is None:
            self.mem2 = self.lif2.init_leaky()
        if self.mem3 is None:
            self.mem3 = self.lif3.init_leaky()
        
        # Record spikes for AFR calculation
        spk1_rec = []
        spk2_rec = []
        out_rec = []
        
        for step in range(x.size(0)):
            cur1 = self.bn1(F.max_pool2d(self.conv1(x[step]), 2))
            spk1, self.mem1 = self.lif1(cur1, self.mem1)
            
            cur2 = self.bn2(F.max_pool2d(self.conv2(spk1), 2))
            spk2, self.mem2 = self.lif2(cur2, self.mem2)
            
            cur3 = self.fc1(spk2.flatten(1))
            spk3, self.mem3 = self.lif3(cur3, self.mem3)
            
            spk1_rec.append(spk1)
            spk2_rec.append(spk2)
            out_rec.append(spk3)
        
        return torch.stack(out_rec, dim=0)
    
    def compute_afr(self, spk_rec):
        """Compute Average Firing Rate as percentage of total possible spikes."""
        total_spikes = torch.sum(spk_rec)
        total_neurons = self.num_neurons
        total_timesteps = spk_rec.size(0)
        total_samples = spk_rec.size(1)
        
        max_possible_spikes = total_neurons * total_timesteps * total_samples
        afr = 100.0 * total_spikes.float() / max_possible_spikes
        return afr.item()

# Create model instance
model = SNNCNN().to(device)
print(model)

# Print parameter count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal trainable parameters: {total_params:,}")

SNNCNN(
  (conv1): Conv2d(1, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif1): Leaky()
  (conv2): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lif2): Leaky()
  (fc1): Linear(in_features=1024, out_features=10, bias=True)
  (lif3): Leaky()
)

Total trainable parameters: 29,978


## Data Generation

For this demo, we'll generate synthetic image-like data similar to MNIST format:
- Input shape: (batch_size, 1, 28, 28)
- 10 output classes
- Simple patterns for class determination

In [24]:
def generate_synthetic_data(num_samples=1000, input_size=28, num_classes=10):
    """Generate synthetic image-like data with distinct patterns per class."""
    # Create base patterns for each class
    patterns = torch.randn(num_classes, 1, input_size, input_size)
    patterns = F.avg_pool2d(patterns, kernel_size=4, stride=2, padding=1)
    patterns = F.interpolate(patterns, size=(input_size, input_size), mode='bilinear')
    
    # Initialize storage
    images = torch.zeros(num_samples, 1, input_size, input_size)
    labels = torch.zeros(num_samples, dtype=torch.long)
    
    # Generate samples
    for i in range(num_samples):
        label = i % num_classes
        # Get base pattern and add noise
        base = patterns[label].clone()
        noise = torch.randn_like(base) * 0.1
        images[i] = base + noise
        labels[i] = label
    
    # Normalize to [-1, 1]
    images = (images - images.mean()) / (images.std() + 1e-5)
    
    return images, labels

# Generate training and test data
num_samples = 1000
num_steps = 25
batch_size = 32

# Create train/test sets
X_train, y_train = generate_synthetic_data(num_samples)
X_test, y_test = generate_synthetic_data(num_samples // 5)

# Create data loaders with pin_memory for faster GPU transfer
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True
)
test_loader = DataLoader(
    TensorDataset(X_test, y_test),
    batch_size=batch_size,
    pin_memory=True
)

# Print dataset information
print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Input shape: {X_train[0].shape}")
print(f"Number of classes: {len(torch.unique(y_train))}")
print(f"Data range: [{X_train.min():.2f}, {X_train.max():.2f}]")

Training samples: 1000
Test samples: 200
Input shape: torch.Size([1, 28, 28])
Number of classes: 10
Data range: [-4.57, 4.22]


## Training Functions

We'll implement functions for:
1. Training one epoch
2. Validation with AFR computation
3. Training loop with loss and accuracy tracking

In [None]:
def train(model, train_loader, optimizer, device, num_steps):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    total_acc = 0
    total_samples = 0
    
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)
        
        # Reset model states before each batch
        model.reset_states()
        
        # Expand input for time steps
        data = data.unsqueeze(0).repeat(num_steps, 1, 1, 1, 1)
        spk_rec = model(data)
        
        # Loss and accuracy (average over time steps)
        output = spk_rec.mean(0)  # Average over time steps
        loss = F.cross_entropy(output, targets)
        acc = (output.argmax(1) == targets).float().mean() * 100
        
        # Gradient computation and weight updates
        optimizer.zero_grad()
        loss.backward(retain_graph=True)  # Added retain_graph=True
        optimizer.step()
        
        batch_size = len(targets)
        total_loss += loss.item() * batch_size
        total_acc += acc.item() * batch_size
        total_samples += batch_size
    
    return {
        'loss': total_loss / total_samples,
        'accuracy': total_acc / total_samples
    }

Training setup complete


## Training Loop

Now we'll train the model for multiple epochs, tracking:
1. Training loss and accuracy
2. Validation accuracy
3. Average Firing Rate (AFR)
4. Training time

In [26]:
import time
from IPython.display import clear_output

# Training parameters
num_epochs = 10
history = {
    'loss': [], 'accuracy': [], 
    'val_accuracy': [], 'val_afr': [],
    'time_per_epoch': []
}

def plot_training_progress():
    """Plot real-time training progress."""
    clear_output(wait=True)
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
    
    # Plot training loss
    ax1.plot(history['loss'], 'b-')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    
    # Plot accuracies
    epochs = range(1, len(history['accuracy']) + 1)
    ax2.plot(epochs, history['accuracy'], 'b-', label='Train')
    ax2.plot(epochs, history['val_accuracy'], 'r-', label='Validation')
    ax2.set_title('Model Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    
    # Plot AFR
    ax3.plot(epochs, history['val_afr'], 'g-')
    ax3.set_title('Average Firing Rate')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('AFR (%)')
    
    plt.tight_layout()
    plt.show()
    
    # Print latest metrics
    print(f"\nLatest metrics (Epoch {len(epochs)}):")
    print(f"  Train Loss: {history['loss'][-1]:.4f}")
    print(f"  Train Accuracy: {history['accuracy'][-1]:.2f}%")
    print(f"  Val Accuracy: {history['val_accuracy'][-1]:.2f}%")
    print(f"  Average Firing Rate: {history['val_afr'][-1]:.2f}%")
    print(f"  Time per epoch: {history['time_per_epoch'][-1]:.2f}s")

print("Starting training...")
total_start_time = time.time()

# Make sure model is in training mode
model.train()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    
    # Training
    train_stats = train(model, train_loader, optimizer, device, num_steps)
    
    # Validation
    val_stats = validate(model, test_loader, device, num_steps)
    
    # Record history
    epoch_time = time.time() - epoch_start_time
    history['loss'].append(train_stats['loss'])
    history['accuracy'].append(train_stats['accuracy'])
    history['val_accuracy'].append(val_stats['accuracy'])
    history['val_afr'].append(val_stats['afr'])
    history['time_per_epoch'].append(epoch_time)
    
    # Plot progress
    if (epoch + 1) % 1 == 0:  # Update every epoch
        plot_training_progress()

total_time = time.time() - total_start_time
print(f"\nTraining completed in {total_time:.2f} seconds")
print(f"Average time per epoch: {np.mean(history['time_per_epoch']):.2f} seconds")

Starting training...


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

## Results Visualization

Let's plot the training results:
1. Training loss over epochs
2. Training and validation accuracy
3. Average Firing Rate progression

In [None]:
# Create figure for multiple plots
fig = plt.figure(figsize=(20, 10))

# 1. Training curves
ax1 = plt.subplot(231)
ax1.plot(history['loss'], 'b-', label='Loss')
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.grid(True)

# 2. Accuracy curves
ax2 = plt.subplot(232)
epochs = range(1, len(history['accuracy']) + 1)
ax2.plot(epochs, history['accuracy'], 'b-', label='Train')
ax2.plot(epochs, history['val_accuracy'], 'r-', label='Validation')
ax2.set_title('Model Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True)

# 3. AFR progression
ax3 = plt.subplot(233)
ax3.plot(epochs, history['val_afr'], 'g-')
ax3.set_title('Average Firing Rate')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('AFR (%)')
ax3.grid(True)

# 4. Time per epoch
ax4 = plt.subplot(234)
ax4.plot(epochs, history['time_per_epoch'], 'm-')
ax4.set_title('Time per Epoch')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Seconds')
ax4.grid(True)

# 5. Example predictions
model.eval()
with torch.no_grad():
    # Get sample batch
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)
    
    # Reset model states
    model.reset_states()
    
    # Forward pass
    spk_rec = model(images.unsqueeze(0).repeat(num_steps, 1, 1, 1, 1))
    predictions = spk_rec.mean(0).argmax(1)
    
    # Plot first 9 examples
    for idx in range(min(9, len(images))):
        plt.subplot(2, 3, 5 + (idx > 4))
        plt.imshow(images[idx].cpu().squeeze(), cmap='gray')
        plt.title(f'True: {labels[idx].item()}\nPred: {predictions[idx].item()}')
        plt.axis('off')

plt.tight_layout()
plt.show()

# Print final metrics
if len(history['val_accuracy']) > 0:  # Check if we have any history
    print("\nFinal Results:")
    print(f"Validation Accuracy: {history['val_accuracy'][-1]:.2f}%")
    print(f"Average Firing Rate: {history['val_afr'][-1]:.2f}%")
    print(f"Final Loss: {history['loss'][-1]:.4f}")
    print(f"Training Time: {total_time:.2f} seconds")
    print(f"Parameters: {total_params:,}")

# Model summary
print("\nModel Architecture:")
print(model)

OSError: 'seaborn' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)

## Save Model and Results

Finally, let's save the trained model and results for future use. This includes:
1. Model state dictionary
2. Training history
3. Final metrics

In [None]:
# Save model checkpoint
if len(history['val_accuracy']) > 0:  # Only save if we have training history
    checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
    }

    import os
    os.makedirs('models', exist_ok=True)
    torch.save(checkpoint, 'models/snn_model.pth')

    # Save final metrics to text file
    with open('models/results.txt', 'w') as f:
        f.write(f"Final Validation Accuracy: {history['val_accuracy'][-1]:.2f}%\n")
        f.write(f"Final Average Firing Rate: {history['val_afr'][-1]:.2f}%\n")
        f.write(f"Training Time: {total_time:.2f} seconds\n")

    print("Model and results saved successfully!")

IndexError: list index out of range