# Range Classification with a Spiking Neural Network (snntorch)

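Classify echo delays as close/medium/far using a small spiking neural network (SNN) trained with surrogate gradients. Each sample has two input channels: channel 0 spikes at t=0 (the emitted pulse) and channel 1 spikes at t=delay (the returning echo).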

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import snntorch as snn
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from snntorch import functional as SF
from snntorch import spikeplot
from snntorch import surrogate as surrogate
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: NumPy's global RNG drives the data generator below, and
# torch's RNG drives layer initialisation and DataLoader shuffling.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

## Part 1: Synthetic Data Generation (Bat Simulator)
- Channel 0: pulse spike at t=0
- Channel 1: echo spike at t=delay
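- Each class maps to a delay range (close: 5–15, medium: 20–35, far: 40–60 time steps). A ±1-step jitter is added, so a delay can fall one step outside its nominal range.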

In [None]:
time_steps = 70  # enough to cover max delay + margin
class_ranges = {
    0: (5, 15),   # Close
    1: (20, 35),  # Medium
    2: (40, 60),  # Far
}

def generate_bat_data(num_samples: int):
    """Simulate ``num_samples`` pulse/echo spike trains.

    Each sample is a ``[time_steps, 2]`` binary spike train:
    - channel 0 fires once at t=0 (the emitted pulse);
    - channel 1 fires once at the echo delay.

    Sampling steps:
    - The label is drawn uniformly from the keys of ``class_ranges``. The keys
      must be the integer class ids.
    - The delay is drawn uniformly from that class's inclusive ``(min, max)``
      range.
    - The delay is shifted by a jitter of -1, 0 or +1 step and clipped to the
      time window.

    Args:
        num_samples: number of samples to generate. Draws from NumPy's global
            RNG, so results are reproducible after ``np.random.seed``.

    Returns:
        ``(data, labels)``:
        - ``data`` is a float tensor of shape ``[num_samples, time_steps, 2]``;
        - ``labels`` is a long tensor of shape ``[num_samples]``.
    """
    # Derive the class set from class_ranges (instead of a hard-coded list)
    # so that adding a range actually produces samples of that class.
    classes = list(class_ranges)
    data = torch.zeros(num_samples, time_steps, 2)
    labels = torch.zeros(num_samples, dtype=torch.long)
    for i in range(num_samples):
        cls = int(np.random.choice(classes))
        dmin, dmax = class_ranges[cls]
        delay = np.random.randint(dmin, dmax + 1)  # inclusive upper bound
        jitter = np.random.choice([-1, 0, 1])
        # Keep the echo inside the simulated window even after jitter.
        delay = int(np.clip(delay + jitter, 0, time_steps - 1))
        # Channel 0: pulse at t=0
        data[i, 0, 0] = 1.0
        # Channel 1: echo at t=delay
        data[i, delay, 1] = 1.0
        labels[i] = cls
    return data, labels

# Build the dataset, then hold out the final 20% for testing. Every sample's
# class is drawn independently, so a contiguous split is not biased.
num_samples = 1200
data, labels = generate_bat_data(num_samples)
split = int(0.8 * num_samples)
train_data, train_labels = data[:split], labels[:split]
test_data, test_labels = data[split:], labels[split:]

train_loader = DataLoader(
    TensorDataset(train_data, train_labels), batch_size=64, shuffle=True
)
test_loader = DataLoader(
    TensorDataset(test_data, test_labels), batch_size=128, shuffle=False
)

## Part 2: SNN Architecture
- Surrogate gradient (fast sigmoid) enables backprop through spikes.
- LIF hidden and output layers.

In [None]:
beta = 0.9
hidden_size = 32
num_classes = 3

class BatSNN(nn.Module):
    """Two-layer spiking classifier: 2 input channels -> LIF hidden -> LIF output.

    At every time step, the input frame passes through ``fc1`` into a leaky
    integrate-and-fire layer. That layer's spikes drive ``fc2`` and a second
    leaky layer with one neuron per class. Both ``snn.Leaky`` layers use decay
    ``beta`` and a fast-sigmoid surrogate gradient, so the network can be
    trained with backprop despite the non-differentiable spike.
    """

    def __init__(self):
        super().__init__()
        # Surrogate gradient lets gradients flow through spike nonlinearity.
        self.surrogate_grad = surrogate.fast_sigmoid()
        self.fc1 = nn.Linear(2, hidden_size)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=self.surrogate_grad)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=self.surrogate_grad)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: input spikes of shape ``[batch, time, 2]``.

        Returns:
            Output spikes of shape ``[batch, time, num_classes]``.
        """
        batch_size, num_steps = x.size(0), x.size(1)
        # Membrane potentials start at rest. They are allocated on x's
        # device and dtype because plain torch.zeros() lands on the CPU and
        # clashes with a CUDA model/input.
        mem1 = x.new_zeros(batch_size, self.fc1.out_features)
        mem2 = x.new_zeros(batch_size, self.fc2.out_features)
        spk2_rec = []
        for t in range(num_steps):
            spk1, mem1 = self.lif1(self.fc1(x[:, t, :]), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk2_rec.append(spk2)
        return torch.stack(spk2_rec, dim=1)  # [batch, time, classes]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BatSNN().to(device)

## Part 3: Training Loop
- Loss: cross-entropy on output spike counts (`SF.ce_rate_loss`). The predicted class is the output neuron that spikes most often.
- Optimizer: Adam.
- Track per-epoch loss and test accuracy.

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50

def test(model, loader):
    """Return the classification accuracy of ``model`` on ``loader``.

    Batches are moved to the module-level ``device``. ``model`` is expected
    to return output spikes shaped ``[batch, time, classes]``, as BatSNN does.
    Each sample's prediction is the class with the most output spikes over
    the whole window (rate decoding); ties go to the lowest class index.

    The model is switched to eval mode and left there.

    Raises:
        ValueError: if ``loader`` yields no samples, because accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            spk_out = model(x)
            rates = spk_out.sum(dim=1)  # [batch, classes]: spike counts
            pred = rates.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    if total == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return correct / total

train_losses = []
test_accuracies = []
# Rate-coded cross-entropy. The loss callable is built once and reused for
# every batch instead of being re-created inside the loop.
loss_fn = SF.ce_rate_loss()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        spk_out = model(x)  # [batch, time, classes]
        # ce_rate_loss expects [time, batch, classes], so swap the first two axes.
        loss = loss_fn(spk_out.permute(1, 0, 2), y)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size. The epoch value is then
        # a per-sample mean even if the final batch is smaller.
        epoch_loss += loss.item() * x.size(0)
    epoch_loss /= len(train_loader.dataset)
    train_losses.append(epoch_loss)
    acc = test(model, test_loader)
    test_accuracies.append(acc)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Test Acc: {acc:.3f}")

## Part 4: Evaluation & Visualization

In [None]:
# 1) Training metrics
# Size the x-axis from the recorded history so the plot still works if
# training was stopped before num_epochs.
epochs = np.arange(1, len(train_losses) + 1)
fig, ax1 = plt.subplots(figsize=(7, 3))
(loss_line,) = ax1.plot(epochs, train_losses, color='tab:blue', label='Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')

ax2 = ax1.twinx()
(acc_line,) = ax2.plot(epochs, test_accuracies, color='tab:orange', label='Accuracy')
ax2.set_ylabel('Accuracy', color='tab:orange')
ax2.tick_params(axis='y', labelcolor='tab:orange')

# The two series live on twin axes, so build one legend from both handles.
# It is drawn on ax2 so it sits above both lines.
ax2.legend(handles=[loss_line, acc_line], loc='center right')
# Set the title before tight_layout() so the layout reserves room for it.
# Calling plt.title() afterwards targets the twin axis and may be clipped.
ax1.set_title("Training Loss and Test Accuracy")
fig.tight_layout()
plt.show()

# 2) Spike raster on a random test sample
model.eval()
idx = np.random.randint(0, len(test_data))
sample_x = test_data[idx:idx+1].to(device)  # slice keeps the batch dim: [1, time, 2]
sample_y = test_labels[idx].item()
with torch.no_grad():
    spk_out = model(sample_x)
spk_out_cpu = spk_out.cpu().squeeze(0)  # [time, classes]
# Same decoding rule as test(): the class whose output neuron spikes most.
sample_pred = spk_out_cpu.sum(dim=0).argmax().item()

fig, axes = plt.subplots(2, 1, figsize=(7, 4), sharex=True)
# spikeplot.raster takes a [time, neurons] tensor and draws one dot per spike.
spikeplot.raster(sample_x.cpu().squeeze(0), ax=axes[0])
axes[0].set_title(f"Input Spikes (True class: {sample_y})")
axes[0].set_ylabel("Channel")
axes[0].set_yticks([0, 1])
axes[0].set_yticklabels(["pulse", "echo"])

spikeplot.raster(spk_out_cpu, ax=axes[1])
axes[1].set_title(f"Output Spikes (Classes 0,1,2) | Predicted: {sample_pred}")
axes[1].set_xlabel("Time step")
axes[1].set_ylabel("Class")
# Integer ticks per class; this also keeps silent neurons visible on the axis.
axes[1].set_yticks(range(spk_out_cpu.size(1)))
plt.tight_layout()
plt.show()

# 3) Confusion matrix on test set
# Rate-decode every test batch (argmax of per-class spike counts over time)
# and collect the predictions next to the ground-truth labels.
all_preds = []
all_trues = []
with torch.no_grad():
    for x, y in test_loader:
        spike_counts = model(x.to(device)).sum(dim=1)
        all_preds.extend(spike_counts.argmax(dim=1).cpu().numpy())
        all_trues.extend(y.numpy())

tick_names = ['Close', 'Med', 'Far']
cm = confusion_matrix(all_trues, all_preds, labels=[0, 1, 2])
fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues', cbar=False,
    xticklabels=tick_names, yticklabels=tick_names, ax=ax,
)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.tight_layout()
plt.show()