In [2]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

# Samples per minibatch, shared by the DataLoaders in this notebook
batch_size = 1
data_path = './data/mnist'

dtype = torch.float

# Fixed seed so weight initialisation, DataLoader shuffling and the stochastic
# rate coding (spikegen.rate) repeat across Restart & Run All. Bit-exact GPU
# runs may additionally need PyTorch's deterministic-algorithm settings.
SEED = 0
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [4]:
# Pre-processing pipeline. ToTensor scales pixels to [0, 1] and
# Normalize((0,), (1,)) is an identity, so values stay in that range
# (the rate coder in Net.forward treats them as firing probabilities).
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Training split first, then test split (both downloaded on first use)
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Both loaders share the same batching policy
loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True)
train_loader = DataLoader(mnist_train, **loader_kwargs)
test_loader = DataLoader(mnist_test, **loader_kwargs)

# Network dimensions: flattened 28x28 image in, one hidden layer, 10 classes out
num_inputs, num_hidden, num_outputs = 28 * 28, 256, 10

# Temporal dynamics: simulation length and the LIF membrane decay factor
num_steps, beta = 20, 0.95

In [None]:

class Net(nn.Module):
    """Two-layer fully connected spiking network of leaky integrate-and-fire neurons.

    Topology: num_inputs -> fc1 -> LIF -> num_hidden -> fc2 -> LIF -> num_outputs.
    Layer sizes and the membrane decay ``beta`` come from the notebook constants.
    Inputs are rate-coded into ``num_steps`` spike frames inside ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer: linear projection followed by LIF neurons
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        # Output layer: one LIF neuron per class
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps.

        Args:
            x: input batch; every dimension after the first (batch) one is
               flattened. Values are fed to ``spikegen.rate`` as firing
               probabilities (presumably in [0, 1] -- TODO confirm for new inputs).

        Returns:
            (spikes, membranes): output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension of length ``num_steps``.
        """
        flat = x.view(x.size(0), -1)
        # Time-major spike frames: the first dimension indexes the time step
        frames = spikegen.rate(flat, num_steps=num_steps)

        # Fresh membrane state for every forward call
        mem_hidden = self.lif1.init_leaky()
        mem_out = self.lif2.init_leaky()

        out_spikes, out_mems = [], []
        for frame in frames:
            spk_hidden, mem_hidden = self.lif1(self.fc1(frame), mem_hidden)
            spk_out, mem_out = self.lif2(self.fc2(spk_hidden), mem_out)
            out_spikes.append(spk_out)
            out_mems.append(mem_out)

        return torch.stack(out_spikes, dim=0), torch.stack(out_mems, dim=0)

In [6]:
# Model, objective and optimiser
net = Net().to(device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Training schedule and history buffers (appended to once per training iteration)
num_epochs = 1
loss_hist, test_loss_hist = [], []
counter = 0  # global iteration count across all epochs

def print_batch_accuracy(data, targets, train=False, model=None):
    """Print (and return) a model's accuracy on a single minibatch.

    The predicted class of each sample is the output neuron with the largest
    spike count summed over the time dimension.

    Args:
        data: input batch; each sample is flattened before the forward pass.
        targets: integer class labels, one per sample, on the model's device.
        train: only selects the "Train"/"Test" wording of the message.
        model: network to evaluate; defaults to the global ``net``.

    Returns:
        The minibatch accuracy as a fraction in [0, 1].
    """
    if model is None:
        model = net
    # Accuracy needs no gradients; this keeps the call cheap even when it is
    # made outside a torch.no_grad() block.
    with torch.no_grad():
        # Flatten with the batch's own size rather than the global `batch_size`,
        # so partial or differently sized batches do not break the reshape.
        output, _ = model(data.view(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    split = "Train" if train else "Test"
    print(f"{split} set accuracy for a single minibatch: {acc*100:.2f}%")
    return acc

def train_printer():
    """Print a progress report for the current training iteration.

    Reads notebook globals set by the training loop: ``epoch``, ``iter_counter``,
    ``counter`` (index into the loss histories), the loss histories themselves,
    and the current train (``data``/``targets``) and test
    (``test_data``/``test_targets``) batches.
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    for split, history in (("Train", loss_hist), ("Test", test_loss_hist)):
        print(f"{split} Set Loss: {history[counter]:.2f}")
    for inputs, labels, is_train in ((data, targets, True), (test_data, test_targets, False)):
        print_batch_accuracy(inputs, labels, train=is_train)
    print("\n")


def _cycle_batches(loader):
    """Yield batches from ``loader`` indefinitely, starting a new pass whenever one ends.

    Returns (so ``next`` raises StopIteration) if a full pass yields nothing,
    instead of spinning forever on an empty loader.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def _time_summed_loss(mem_rec, labels):
    """Sum ``loss`` between each time step's output membrane potential and ``labels``."""
    total = torch.zeros((1), dtype=dtype, device=device)
    for mem_t in mem_rec:
        total += loss(mem_t, labels)
    return total


# One persistent iterator for the monitoring batches. The previous
# `next(iter(test_loader))` rebuilt the DataLoader iterator (and its shuffle
# permutation) on every training step.
test_batches = _cycle_batches(test_loader)

for epoch in range(num_epochs):
    iter_counter = 0

    # Minibatch training loop
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass; flatten with the batch's own size rather than the global
        # `batch_size`, so a short batch cannot break the reshape.
        net.train()
        spk_rec, mem_rec = net(data.view(data.size(0), -1))

        # Loss summed over all time steps, then one optimiser step
        loss_val = _time_summed_loss(mem_rec, targets)
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Monitoring loss on one test batch, without gradient tracking
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(test_batches)
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            test_spk, test_mem = net(test_data.view(test_data.size(0), -1))
            test_loss_hist.append(_time_summed_loss(test_mem, test_targets).item())

            # Report every 50 iterations; train_printer reads the loop
            # variables above as notebook globals.
            if counter % 50 == 0:
                train_printer()

        counter += 1
        iter_counter += 1


Epoch 0, Iteration 0
Train Set Loss: 47.01
Test Set Loss: 43.58
Train set accuracy for a single minibatch: 24.22%
Test set accuracy for a single minibatch: 21.09%


Epoch 0, Iteration 50
Train Set Loss: 11.60
Test Set Loss: 12.26
Train set accuracy for a single minibatch: 89.84%
Test set accuracy for a single minibatch: 85.94%


Epoch 0, Iteration 100
Train Set Loss: 8.23
Test Set Loss: 10.46
Train set accuracy for a single minibatch: 91.41%
Test set accuracy for a single minibatch: 89.84%


Epoch 0, Iteration 150
Train Set Loss: 7.80
Test Set Loss: 10.40
Train set accuracy for a single minibatch: 92.97%
Test set accuracy for a single minibatch: 87.50%


Epoch 0, Iteration 200
Train Set Loss: 6.31
Test Set Loss: 6.23
Train set accuracy for a single minibatch: 94.53%
Test set accuracy for a single minibatch: 92.19%


Epoch 0, Iteration 250
Train Set Loss: 6.34
Test Set Loss: 8.56
Train set accuracy for a single minibatch: 93.75%
Test set accuracy for a single minibatch: 91.41%


Epoch 0

In [None]:

# Full test-set evaluation. A dedicated loader (no dropped samples) is used
# instead of rebinding the global `test_loader`, so re-running the training
# cell afterwards still uses its original monitoring loader.
eval_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

total = 0
correct = 0

with torch.no_grad():
    net.eval()
    for data, targets in eval_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass; the prediction is the output neuron with most spikes over time
        test_spk, _ = net(data.view(data.size(0), -1))
        _, predicted = test_spk.sum(dim=0).max(1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()

print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

spk1
torch.Size([128, 256])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
spk2
torch.Size([128, 10])
tensor([[0., 0., 1.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
spk1
torch.Size([128, 256])
tensor([[1., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]])
spk2
torch.Size([128, 10])
tensor([[0., 0., 1.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        

In [6]:
# Firing threshold of the hidden LIF layer, displayed as a NumPy array
lif1_threshold = net.lif1.threshold.detach().cpu()
lif1_threshold.numpy()

array(1., dtype=float32)

In [9]:
# File paths for CSV output.
spike_save_path = "mnist_input_spikes.csv"
label_save_path = "mnist_labels.csv"

all_spikes = []
all_labels = []

# Loop over your test_loader.
for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)

    # Convert images to spike trains.
    # Assume spike_data has shape (num_steps, batch_size, vector_length)
    spike_data = spikegen.rate(data.view(batch_size, -1), num_steps=num_steps).cpu().numpy()
    # Remove the batch dimension (assumed to be 1)
    spike_data = np.squeeze(spike_data, axis=1)  # Now shape is (num_steps, vector_length)
    all_spikes.append(spike_data)
    
    # For labels, assume each batch yields one label.
    all_labels.append(targets.cpu().numpy())

# Concatenate all batches along the time dimension.
all_spikes = np.concatenate(all_spikes, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

# Save spike data and labels as CSV.
# The CSV file for spikes will have (total_time_steps x vector_length) entries.
np.savetxt(spike_save_path, all_spikes.astype(np.int8), delimiter=",", fmt="%d")
np.savetxt(label_save_path, all_labels.astype(np.int8), delimiter=",", fmt="%d")

print("Spike data and labels saved as CSV files.")

Spike data and labels saved as CSV files.


In [5]:
# Export trained fully connected parameters as plain-text files
# (weights_<layer>.txt and bias_<layer>.txt). `.cpu()` is required because
# `net` may live on a CUDA/MPS device, and Tensor.numpy() only accepts CPU tensors.
for layer_name in ("fc1", "fc2"):
    layer = getattr(net, layer_name)
    np.savetxt(f"weights_{layer_name}.txt", layer.weight.detach().cpu().numpy())
    np.savetxt(f"bias_{layer_name}.txt", layer.bias.detach().cpu().numpy())

NameError: name 'net' is not defined