In [3]:
import time
import argparse
import sys
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
import torchvision
import numpy as np

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

In [4]:
class MNIST_SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()

        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 16 * 16, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
            layer.Linear(16 * 16, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        )
    def forward(self, x:torch.Tensor):
        return self.layer(x)

In [5]:
net = MNIST_SNN(tau=2.0)
net = net.to('cuda')

In [6]:
X_train = np.load("/media/ubuntu/sda/SNN_train/X_train.npy")
X_test = np.load("/media/ubuntu/sda/SNN_train/X_test.npy")
y_train = np.load('/media/ubuntu/sda/SNN_train/y_train.npy')
y_test = np.load("/media/ubuntu/sda/SNN_train/y_test.npy")

In [7]:
X_train_tensor = torch.from_numpy(X_train).float()
X_test_tensor = torch.from_numpy(X_test).float()
y_train_tensor = torch.from_numpy(y_train).long()
y_test_tensor = torch.from_numpy(y_test).long()

train_dataset = data.TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = data.TensorDataset(X_test_tensor, y_test_tensor)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [8]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [9]:
encoder = encoding.PoissonEncoder()
device = 'cuda'
for epoch in range(2):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()

        
        out_fr = 0.
        for t in range(100):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_fr = out_fr / 100
        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()
        optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            label_onehot = F.one_hot(label, 10).float()
            out_fr = 0.
            for t in range(100):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / 100
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print("-" * 60)
    print(f"Epoch {epoch}")
    print(f'Train acc: {train_acc}; Test acc: {test_acc}')

------------------------------------------------------------
Epoch 0
Train acc: 0.9044; Test acc: 0.949
------------------------------------------------------------
Epoch 1
Train acc: 0.9590666666666666; Test acc: 0.9654
