In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

batch_size = 100
data_path='../data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

In [2]:
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))
            ])
# transform = transforms.Compose([
#             transforms.Resize((28, 28)),
#             transforms.ToTensor(),
#             ])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

num_inputs = 28*28
num_hidden = 256
num_outputs = 10

# Temporal Dynamics
num_steps = 20
beta = 0.95

In [6]:

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        spike_input = spikegen.rate(x, num_steps=num_steps) # Generate spike trains
        # print("spike_input")
        # print(spike_input.shape)
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            # print(spike_input[step].sum(axis=-1))
            # print(spike_input[step].shape)
            cur1 = self.fc1(spike_input[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            # print("spk1")
            # print(spk1.shape)
            # print(spk1)
            cur2 = self.fc2(spk1)
            # print(cur1)
            spk2, mem2 = self.lif2(cur2, mem2)
            # print("spk2")
            # print(spk2.shape)
            # print(spk2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [7]:
net = Net().to(device)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")


for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))
        # print(mem_rec.shape)
        # print(mem_rec)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1


Epoch 0, Iteration 0
Train Set Loss: 47.33
Test Set Loss: 46.41
Train set accuracy for a single minibatch: 23.00%
Test set accuracy for a single minibatch: 16.00%


Epoch 0, Iteration 50
Train Set Loss: 14.90
Test Set Loss: 11.93
Train set accuracy for a single minibatch: 84.00%
Test set accuracy for a single minibatch: 90.00%


Epoch 0, Iteration 100
Train Set Loss: 8.71
Test Set Loss: 10.67
Train set accuracy for a single minibatch: 91.00%
Test set accuracy for a single minibatch: 94.00%


Epoch 0, Iteration 150
Train Set Loss: 7.47
Test Set Loss: 5.93
Train set accuracy for a single minibatch: 92.00%
Test set accuracy for a single minibatch: 94.00%


Epoch 0, Iteration 200
Train Set Loss: 10.94
Test Set Loss: 9.19
Train set accuracy for a single minibatch: 91.00%
Test set accuracy for a single minibatch: 87.00%


Epoch 0, Iteration 250
Train Set Loss: 6.14
Test Set Loss: 10.06
Train set accuracy for a single minibatch: 95.00%
Test set accuracy for a single minibatch: 89.00%


Epoch 

In [8]:

total = 0
correct = 0

# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  net.eval()
  for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)

    # forward pass
    test_spk, _ = net(data.view(data.size(0), -1))

    # calculate total accuracy
    _, predicted = test_spk.sum(dim=0).max(1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()
print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

Total correctly classified test set images: 9369/10000
Test Set Accuracy: 93.69%


In [None]:
# File paths for CSV output.
spike_save_path = "mnist_input_spikes.csv"
label_save_path = "mnist_labels.csv"

all_spikes = []
all_labels = []

# Loop over your test_loader.
for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)

    # Convert images to spike trains.
    # Assume spike_data has shape (num_steps, batch_size, vector_length)
    spike_data = spikegen.rate(data.view(batch_size, -1), num_steps=num_steps).cpu().numpy()
    # Remove the batch dimension (assumed to be 1)
    spike_data = np.squeeze(spike_data, axis=1)  # Now shape is (num_steps, vector_length)
    all_spikes.append(spike_data)
    
    # For labels, assume each batch yields one label.
    all_labels.append(targets.cpu().numpy())

# Concatenate all batches along the time dimension.
all_spikes = np.concatenate(all_spikes, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

# Save spike data and labels as CSV.
# The CSV file for spikes will have (total_time_steps x vector_length) entries.
np.savetxt(spike_save_path, all_spikes.astype(np.int8), delimiter=",", fmt="%d")
np.savetxt(label_save_path, all_labels.astype(np.int8), delimiter=",", fmt="%d")

print("Spike data and labels saved as CSV files.")

In [None]:
np.savetxt("weights_fc1.txt", net.fc1.weight.detach().numpy())
np.savetxt("weights_fc2.txt", net.fc2.weight.detach().numpy())
np.savetxt("bias_fc1.txt", net.fc1.bias.detach().numpy())
np.savetxt("bias_fc2.txt", net.fc2.bias.detach().numpy())

In [None]:
for data, label in test_loader:
    break

print(label)

In [None]:
d =  data[0,0,:,:] * 255
f = d.cpu().numpy()
np.savetxt("input_image.txt", f.flatten())
print(f)

In [None]:
import matplotlib.pyplot as plt
# get the weights and bias of every layer in the network and put it in one array
w = net.fc1.weight.detach().cpu().numpy()
b = net.fc1.bias.detach().cpu().numpy()
w = np.concatenate((w.flatten(), b.flatten()), axis=0)
w2 = net.fc2.weight.detach().cpu().numpy()
b2 = net.fc2.bias.detach().cpu().numpy()
w2 = np.concatenate((w2.flatten(), b2.flatten()), axis=0)
a = np.concatenate((w, w2), axis=0)

# find absolute max and min of the weights and bias
max_val = np.max(a)
min_val = np.min(a)
print("Max weight: ", max_val)
print("Min weight: ", min_val)


plt.figure(figsize=(10, 5))
plt.title("Weight Distribution of fc1 Layer")
plt.xlabel("Weight Value")
plt.ylabel("Frequency")
plt.grid(False)

plt.hist(a.flatten(), bins=256, range=(-1, 1))

In [None]:
Q07_SCALE = 128.0
Q07_MAX_FLOAT = 127 / 128.0    # 0.9921875
Q07_MIN_FLOAT = -1.0
Q07_MAX_INT8 = 127
Q07_MIN_INT8 = -128

def quantize_q07(x: float) -> int:
    if x > Q07_MAX_FLOAT:
        x = Q07_MAX_FLOAT
    elif x < Q07_MIN_FLOAT:
        x = Q07_MIN_FLOAT

    scaled = int(x * Q07_SCALE + (0.5 if x >= 0 else -0.5))

    if scaled > Q07_MAX_INT8:
        scaled = Q07_MAX_INT8
    elif scaled < Q07_MIN_INT8:
        scaled = Q07_MIN_INT8

    return np.int8(scaled)

def quantize_tensor_q07(tensor: torch.Tensor):
    tensor_np = tensor.detach().cpu().numpy()
    quantized = np.vectorize(quantize_q07)(tensor_np).astype(np.int8)
    return quantized

def format_c_array(var_name: str, array: np.ndarray) -> str:
    shape = array.shape
    dims = ''.join([f"[{d}]" for d in shape])
    result = f"const int8_t {var_name}{dims} = {{\n"

    if array.ndim == 1:
        result += "    { " + ', '.join(str(v) for v in array) + " }\n"
    elif array.ndim == 2:
        for row in array:
            result += "    { " + ', '.join(str(v) for v in row) + " },\n"
    elif array.ndim == 3:
        for mat in array:
            result += "    {\n"
            for row in mat:
                result += "        { " + ', '.join(str(v) for v in row) + " },\n"
            result += "    },\n"
    else:
        raise ValueError(f"Unsupported tensor dimension: {array.ndim}")

    result += "};\n\n"
    return result

def write_c_header(model: torch.nn.Module, header_path: str):
    with open(header_path, 'w') as f:
        f.write("// Auto-generated Q0.7 quantized weights\n\n")
        f.write("#ifndef Q07_WEIGHTS_H\n#define Q07_WEIGHTS_H\n\n")
        f.write("#include <stdint.h>\n\n")

        for name, param in model.named_parameters():
            q_data = quantize_tensor_q07(param)
            c_name = name.replace('.', '_')
            f.write(f"// Shape: {q_data.shape}\n")
            f.write(format_c_array(c_name, q_data))

        f.write("#endif // Q07_WEIGHTS_H\n")

In [None]:
q = quantize_tensor_q07(net.fc1.weight)
b = quantize_tensor_q07(net.fc1.bias)
q2 = quantize_tensor_q07(net.fc2.weight)
b2 = quantize_tensor_q07(net.fc2.bias)
np.savetxt("weights_fc1_q07.txt", q)
np.savetxt("bias_fc1_q07.txt", b)
np.savetxt("weights_fc2_q07.txt", q2)
np.savetxt("bias_fc2_q07.txt", b2)


In [None]:
write_c_header(net, "q07_weights.h")

In [3]:
class ANN(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        fc1_out = self.fc1(x)
        fc2_out = self.fc2(fc1_out)
        return fc2_out

ann = ANN().to(device)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters(), lr=5e-4, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

def print_batch_accuracy(data, targets, train=False):
    output = ann(data.view(batch_size, -1))
    _, idx = output.max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")


for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        ann.train()
        out = ann(data.view(batch_size, -1))

        loss_val = torch.zeros((1), dtype=dtype, device=device)
        loss_val += loss(out, targets)


        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            ann.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            out = ann(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            test_loss += loss(out, test_targets)

            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1


Epoch 0, Iteration 0
Train Set Loss: 2.33
Test Set Loss: 2.23
Train set accuracy for a single minibatch: 28.00%
Test set accuracy for a single minibatch: 29.00%


Epoch 0, Iteration 50
Train Set Loss: 0.56
Test Set Loss: 0.75
Train set accuracy for a single minibatch: 89.00%
Test set accuracy for a single minibatch: 77.00%


Epoch 0, Iteration 100
Train Set Loss: 0.49
Test Set Loss: 0.55
Train set accuracy for a single minibatch: 88.00%
Test set accuracy for a single minibatch: 83.00%


Epoch 0, Iteration 150
Train Set Loss: 0.34
Test Set Loss: 0.39
Train set accuracy for a single minibatch: 93.00%
Test set accuracy for a single minibatch: 86.00%


Epoch 0, Iteration 200
Train Set Loss: 0.40
Test Set Loss: 0.37
Train set accuracy for a single minibatch: 89.00%
Test set accuracy for a single minibatch: 87.00%


Epoch 0, Iteration 250
Train Set Loss: 0.42
Test Set Loss: 0.19
Train set accuracy for a single minibatch: 84.00%
Test set accuracy for a single minibatch: 93.00%


Epoch 0, Iter

In [4]:
total = 0
correct = 0

# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  ann.eval()
  for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)

    # forward pass
    out = ann(data.view(data.size(0), -1))

    # calculate total accuracy
    _, predicted = out.max(1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()
print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

Total correctly classified test set images: 9195/10000
Test Set Accuracy: 91.95%


In [9]:
net.state_dict()

OrderedDict([('fc1.weight',
              tensor([[ 1.1317e-02, -3.3913e-02, -8.5571e-03,  ...,  6.2682e-03,
                        1.3216e-02,  1.6880e-02],
                      [ 1.2386e-02,  3.1120e-02, -2.2875e-02,  ..., -2.8876e-02,
                        2.7425e-02,  2.4119e-02],
                      [-2.4773e-04,  2.7669e-02, -2.6614e-02,  ..., -4.6994e-03,
                       -3.4317e-02, -3.2045e-02],
                      ...,
                      [-3.6028e-03, -3.1044e-02,  2.8029e-02,  ...,  3.0228e-02,
                        9.9771e-03,  3.1691e-02],
                      [ 3.8209e-03, -6.5553e-03, -2.5519e-02,  ...,  3.2070e-02,
                       -7.8906e-03, -1.5710e-02],
                      [ 2.7915e-02,  1.9001e-03,  4.1083e-03,  ...,  3.0904e-02,
                       -5.5127e-03, -7.2479e-05]])),
             ('fc1.bias',
              tensor([-2.6208e-02, -3.3698e-02,  1.5687e-02,  5.6119e-02,  2.8172e-02,
                       6.0481e-02,  5.2268e