In [1]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from snntorch import spikegen

In [12]:
data = np.loadtxt("../data.txt")          # shape (400, 2)
labels = np.loadtxt("../labels.txt")      # shape (400,)
labels = labels.astype(int)

# === Convert to PyTorch tensors ===
X = torch.tensor(data, dtype=torch.float32)
y = torch.tensor(labels, dtype=torch.long)


dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

num_steps = 20
# === Create DataLoader ===

In [13]:
class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        beta = 0.9  # Membrane decay rate
        self.fc1 = nn.Linear(2, 16)
        self.lif1 = snn.Leaky(beta=beta)

        self.fc2 = nn.Linear(16, 2)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        spk2_rec = []

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        spike_input = spikegen.rate(x, num_steps=num_steps)

        for step in range(num_steps):
            cur1 = self.fc1(spike_input[step])
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)

        # Sum spikes over time across timesteps
        return torch.stack(spk2_rec, dim=0).sum(dim=0)

In [14]:
net = SNN()
loss_fn = SF.ce_count_loss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [15]:
num_epochs = 20
for epoch in range(num_epochs):
    total_loss = 0
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()

        # Forward pass
        out = net(batch_x)

        # Loss: expects raw spike counts and class indices
        loss = loss_fn(out, batch_y)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

RuntimeError: size mismatch (got input: [2], target: [32])

In [20]:
import snntorch as snn
from snntorch import spikegen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np

batch_size = 32
num_steps = 20
beta = 0.95

# Load 2D dataset
data = np.loadtxt("../data.txt", dtype=np.float32)
labels = np.loadtxt("../labels.txt", dtype=np.int64)
X = torch.tensor(data)
y = torch.tensor(labels)

# Create train/test split
dataset = TensorDataset(X, y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define network adapted to 2D input
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 16)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(16, 2)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        spike_input = spikegen.rate(x, num_steps=num_steps)
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        spk_rec, mem_rec = [], []
        for step in range(num_steps):
            cur1 = self.fc1(spike_input[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk_rec.append(spk2)
            mem_rec.append(mem2)

        return torch.stack(spk_rec), torch.stack(mem_rec)

net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

num_epochs = 20

for epoch in range(num_epochs):
    total_loss = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        _, mem_rec = net(inputs)

        loss_val = sum(criterion(mem_rec[step], targets) for step in range(num_steps))

        loss_val.backward()
        optimizer.step()

        total_loss += loss_val.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

# Save spike data and labels from test set
spike_save_path = "input_spikes.csv"
label_save_path = "labels.csv"

all_spikes = []
all_labels = []

for data, targets in test_loader:
    data = data.to(device)
    targets = targets.to(device)

    spike_data = spikegen.rate(data, num_steps=num_steps).cpu().numpy()
    spike_data = np.squeeze(spike_data, axis=1)
    all_spikes.append(spike_data)

    all_labels.append(targets.cpu().numpy())

all_spikes = np.concatenate(all_spikes, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

np.savetxt(spike_save_path, all_spikes.astype(np.int8), delimiter=",", fmt="%d")
np.savetxt(label_save_path, all_labels.astype(np.int8), delimiter=",", fmt="%d")

print("Spike data and labels saved as CSV files.")


    * make sure the original data is stored as integers.
    * use the `converters=` keyword argument.  If you only use
      NumPy 1.23 or later, `converters=float` will normally work.
    * Use `np.loadtxt(...).astype(np.int64)` parsing the file as
      floating point and then convert it.  (On all NumPy versions.)
  (Deprecated NumPy 1.23)
  labels = np.loadtxt("../labels.txt", dtype=np.int64)


Epoch [1/20], Average Loss: 14.1056
Epoch [2/20], Average Loss: 14.1027
Epoch [3/20], Average Loss: 14.1128
Epoch [4/20], Average Loss: 14.0549
Epoch [5/20], Average Loss: 14.0444
Epoch [6/20], Average Loss: 14.0294
Epoch [7/20], Average Loss: 14.0265
Epoch [8/20], Average Loss: 13.9826
Epoch [9/20], Average Loss: 13.8664
Epoch [10/20], Average Loss: 13.9056
Epoch [11/20], Average Loss: 13.8852
Epoch [12/20], Average Loss: 13.8752
Epoch [13/20], Average Loss: 13.8700
Epoch [14/20], Average Loss: 13.7787
Epoch [15/20], Average Loss: 13.7929
Epoch [16/20], Average Loss: 13.7318
Epoch [17/20], Average Loss: 13.7233
Epoch [18/20], Average Loss: 13.6839
Epoch [19/20], Average Loss: 13.5925
Epoch [20/20], Average Loss: 13.5273
Spike data and labels saved as CSV files.


In [21]:
np.savetxt("2d_weights_fc1.txt", net.fc1.weight.detach().numpy())
np.savetxt("2d_weights_fc2.txt", net.fc2.weight.detach().numpy())
np.savetxt("2d_bias_fc1.txt", net.fc1.bias.detach().numpy())
np.savetxt("2d_bias_fc2.txt", net.fc2.bias.detach().numpy())