In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
import torch.utils.data as data


2025-08-23 20:51:29.444847: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-23 20:51:29.472991: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-08-23 20:51:30.112648: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

In [3]:
X_train, y_train = load_mnist('/media/ubuntu/sda/SNN_train/MNIST_fashion', kind='train')
X_train = X_train.reshape(X_train.shape[0], 28, 28)
X_test, y_test = load_mnist('/media/ubuntu/sda/SNN_train/MNIST_fashion', kind='t10k')
X_test = X_test.reshape(X_test.shape[0], 28 , 28)

In [4]:
X_train_tensor = torch.from_numpy(X_train).float()
X_test_tensor = torch.from_numpy(X_test).float()
y_train_tensor = torch.from_numpy(y_train).long()
y_test_tensor = torch.from_numpy(y_test).long()

train_dataset = data.TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = data.TensorDataset(X_test_tensor, y_test_tensor)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

  X_train_tensor = torch.from_numpy(X_train).float()


In [5]:
class fashion_SNN(nn.Module):
    def __init__(self, T, channels):
        super().__init__()
        self.T = T
        self.channels = channels

        self.conv = nn.Sequential(
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.LIFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),
        )
        self.linear = nn.Sequential(
            layer.Linear(channels * 3 * 3, 10, bias=False),
            neuron.LIFNode(surrogate_function=surrogate.ATan())
        )

        functional.set_step_mode(self, step_mode='m')

    def forward(self, x):
        x_seq = x.unsqueeze(1).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        batch_size = x_seq.shape[1]

        x_seq = self.conv(x_seq)
        x_seq = x_seq.reshape(self.T, batch_size, self.channels * 3 * 3)
        x_seq = self.linear(x_seq)
        fr = x_seq.mean(0)
        return fr
    
net = fashion_SNN(T = 40, channels= 32)
net = net.to('cuda')
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

In [6]:
device = 'cuda'
for epoch in range(10):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        out_fr = net(img)

        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()
        optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            label_onehot = F.one_hot(label, 10).float()

            out_fr = net(img)
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print("-" * 60)
    print(f"Epoch {epoch}")
    print(f'Train acc: {train_acc}; Test acc: {test_acc}')

------------------------------------------------------------
Epoch 0
Train acc: 0.7004333333333334; Test acc: 0.7933
------------------------------------------------------------
Epoch 1
Train acc: 0.8208; Test acc: 0.8318
------------------------------------------------------------
Epoch 2
Train acc: 0.8522833333333333; Test acc: 0.848
------------------------------------------------------------
Epoch 3
Train acc: 0.8664666666666667; Test acc: 0.8629
------------------------------------------------------------
Epoch 4
Train acc: 0.8765333333333334; Test acc: 0.8704
------------------------------------------------------------
Epoch 5
Train acc: 0.8842333333333333; Test acc: 0.8749
------------------------------------------------------------
Epoch 6
Train acc: 0.8891333333333333; Test acc: 0.8829
------------------------------------------------------------
Epoch 7
Train acc: 0.893; Test acc: 0.8827
------------------------------------------------------------
Epoch 8
Train acc: 0.89628333