In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# from spikingjelly.clock_driven import functional, surrogate, neuron, layer
from spikingjelly.activation_based import spike_op as sn
from spikingjelly.activation_based import functional, surrogate, neuron, layer, encoding
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import cupy

  def forward(ctx, spike, weight, bias, stride, padding, dilation, groups):
  def backward(ctx, grad_output):
  def forward(ctx, spike, weight, bias=None):
  def backward(ctx, grad_output):


In [4]:
# Hyper-parameters shared by the cells below.
time_steps = 8
batch_size = 16
learning_rate = 1e-3
epochs = 100

# Training pipeline: grayscale 48x48 tensors with light random augmentation.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

# Evaluation pipeline: the same deterministic preprocessing but WITHOUT the
# random flip/rotation, so validation metrics are reproducible and measure
# the model on unmodified images.
test_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
])

# Load the dataset (ImageFolder: one sub-directory per class).
train_dataset = datasets.ImageFolder(
    root='./dataset/train', transform=transform)
test_dataset = datasets.ImageFolder(
    root='./dataset/test', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define device: first CUDA GPU when present, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
from ptflops import get_model_complexity_info

In [5]:
import numpy as np


class CSNN(nn.Module):
    """Convolutional spiking neural network built from SpikingJelly layers.

    Structure:
      * ``layer1`` - single-step Conv2d + BatchNorm2d stem applied once to the
        static input image.
      * ``layer2`` - multi-step IF neurons, max-pooling and a second
        conv/BN stage; the two 2x poolings reduce the spatial size by 4.
      * ``layer3`` - multi-step fully connected classifier with LIF neurons.

    The first ``Linear`` layer expects ``128 * 12 * 12`` features, which
    corresponds to 48x48 inputs after the two 2x poolings.

    Args:
        T: number of simulation time steps the stem output is repeated for.
        num_classes: width of the final ``Linear`` layer (defaults to 7, the
            value that was previously hard-coded).
    """

    def __init__(self, T=8, num_classes=7):
        super().__init__()
        self.T = T

        # Stem: runs once per image in single-step mode.
        self.layer1 = nn.Sequential(
            layer.Conv2d(1, 128, kernel_size=3, stride=1,
                         padding=1, step_mode='s'),
            layer.BatchNorm2d(128, step_mode='s'),
        )

        # Spiking feature extractor; switched to multi-step mode below.
        self.layer2 = nn.Sequential(
            neuron.IFNode(v_threshold=1.0,
                          surrogate_function=surrogate.ATan(), step_mode='m'),
            layer.MaxPool2d(kernel_size=2, stride=2),
            layer.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            layer.BatchNorm2d(128),
            neuron.IFNode(v_threshold=1.0,
                          surrogate_function=surrogate.ATan(), step_mode='m'),
            layer.MaxPool2d(kernel_size=2)
        )

        # Spiking classifier head.
        self.layer3 = nn.Sequential(
            layer.Flatten(),
            layer.Dropout(p=0.5),
            layer.Linear(128*12*12, 1152),
            neuron.LIFNode(tau=2.0, v_reset=0.0,
                           surrogate_function=surrogate.ATan(), step_mode='m'),
            layer.Dropout(p=0.5),
            layer.Linear(1152, 128),
            neuron.LIFNode(tau=2.0, v_reset=0.0,
                           surrogate_function=surrogate.ATan(), step_mode='m'),
            layer.Linear(128, num_classes),
            neuron.LIFNode(tau=2.0, v_reset=0.0,
                           surrogate_function=surrogate.ATan(), step_mode='m')
        )

        # Every module of layer2/layer3 consumes a leading time dimension.
        functional.set_step_mode(self.layer2, step_mode='m')
        functional.set_step_mode(self.layer3, step_mode='m')

        # The CuPy neuron kernels only work on CUDA tensors. Fall back to the
        # pure PyTorch backend when no GPU is available so the model still
        # runs on the CPU device the notebook may select.
        backend = 'cupy' if torch.cuda.is_available() else 'torch'
        functional.set_backend(self, backend=backend)

        # Poisson encoder is instantiated but currently not used by forward().
        self.pe = encoding.PoissonEncoder()

    def forward(self, x):
        """Run the network for ``self.T`` time steps on a static input.

        Args:
            x: input batch; presumably shaped (batch, 1, 48, 48) - TODO confirm
                against the data pipeline.

        Returns:
            The output of the last LIF layer averaged over the time dimension
            (dim 0).
        """
        # Direct (non-Poisson) encoding: the stem is evaluated once and its
        # output is replicated along a new leading time axis.
        x = self.layer1(x)
        x_step = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        out1 = self.layer2(x_step)
        out2 = self.layer3(out1)
        return out2.mean(dim=0)


# Restore the `np.int` alias (an alias of the builtin int that recent NumPy
# releases no longer provide) for code that still references it.
setattr(np, "int", int)


def train(model, loader, optimizer, criterion, mse=False):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: network to optimise; its spiking state is reset after every
            batch via ``functional.reset_net``.
        loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepping the model parameters.
        criterion: loss callable applied to ``(output, target)``.
        mse: when True the integer targets are one-hot encoded (as floats)
            before being passed to ``criterion``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed; ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        if mse:
            # Take the class count from the model output instead of
            # hard-coding it, so the helper works for any output width.
            target_onehot = torch.nn.functional.one_hot(
                target, num_classes=output.size(1)).float()
            loss = criterion(output, target_onehot)
        else:
            loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_n = data.size(0)
        # criterion is assumed to return a batch mean, hence the re-weighting
        # by batch size before the final division.
        total_loss += loss.item() * batch_n
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += batch_n
        functional.reset_net(model)  # Reset neuron states

    if seen == 0:
        return 0.0, 0.0
    # Normalise by the samples seen, which stays correct when the loader
    # drops the last batch or samples only part of the dataset.
    return total_loss / seen, correct / seen


def validate(model, loader, criterion, mse=False):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: network to evaluate; its spiking state is reset after every
            batch via ``functional.reset_net``.
        loader: iterable yielding ``(data, target)`` batches.
        criterion: loss callable applied to ``(output, target)``.
        mse: when True the integer targets are one-hot encoded (as floats)
            before being passed to ``criterion``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed; ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            if mse:
                # Take the class count from the model output instead of
                # hard-coding it, so the helper works for any output width.
                target_onehot = torch.nn.functional.one_hot(
                    target, num_classes=output.size(1)).float()
                loss = criterion(output, target_onehot)
            else:
                loss = criterion(output, target)

            batch_n = data.size(0)
            total_loss += loss.item() * batch_n
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += batch_n
            functional.reset_net(model)  # Reset neuron states

    if seen == 0:
        return 0.0, 0.0
    # Normalise by the samples seen, which stays correct when the loader
    # drops the last batch or samples only part of the dataset.
    return total_loss / seen, correct / seen


# Build the network and move it to the selected device.
# NOTE(review): `time_steps` is set to 8 in the configuration cell but the
# model is created with T=4 here - confirm which value is intended.
model = CSNN(T=4).to(device)

# Loss: MSE against one-hot targets (swap in the line below for cross-entropy).
# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss().to(device)

# train()/validate() one-hot encode the targets only when the loss is MSE.
mse = criterion._get_name() == 'MSELoss'

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Per-epoch metric histories, filled by the training loop below.
train_acc_hist, val_acc_hist = [], []
train_loss_hist, val_loss_hist = [], []

# Training loop (disabled; uncomment to train and validate for `epochs` epochs).
# for epoch in range(1, epochs + 1):
#     train_loss, train_acc = train(
#         model, train_loader, optimizer, criterion, mse=mse)
#     val_loss, val_acc = validate(model, test_loader, criterion, mse=mse)
#     train_acc_hist.append(train_acc)
#     val_acc_hist.append(val_acc)
#     train_loss_hist.append(train_loss)
#     val_loss_hist.append(val_loss)
#     print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.4f}%, "
#           f"Val Loss: {val_loss:.4f}, Val Acc: {100*val_acc:.4f}%")



In [6]:
# with torch.cuda.device(0):  # use CPU if needed
# Profile the model on one 1x48x48 input. Raw numbers (as_strings=False) are
# requested so they can be rescaled to millions in the report below.
macs, params = get_model_complexity_info(
    model,
    (1, 48, 48),
    as_strings=False,
    print_per_layer_stat=True,
    verbose=True,
)
# One print call, one report line per argument.
print(
    f"MACs: {macs/1e6:.2f} M",
    f"FLOPs (≈ 2 × MACs): {macs * 2 / 1e6:.2f} MFLOPs",
    f"Parameters: {params / 1e6:.2f} M",
    sep="\n",
)

CSNN(
  0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
  (layer1): Sequential(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (0): Conv2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=s)
    (1): BatchNorm2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
  )
  (layer2): Sequential(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (0): IFNode(
      0, 0.000% Params, 0.0 Mac, 0.000% MACs, v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
      (surrogate_function): ATan(0, 0.000% Params, 0.0 Mac, 0.000% MACs, alpha=2.0, spiking=True)
    )
    (1): MaxPool2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (2): Conv2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=m)
    (3): BatchNorm2d(0, 0.00

In [None]:
# Same profiling run, but with ptflops' human-readable string output.
macs, params = get_model_complexity_info(
    model, (1, 48, 48), as_strings=True, print_per_layer_stat=True)
# The first returned value is a multiply-accumulate (MAC) count, not FLOPs
# (FLOPs are roughly 2 x MACs), so label it accordingly.
print(f"MACs: {macs}")
print(f"Parameters: {params}")

CSNN(
  0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
  (layer1): Sequential(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (0): Conv2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=s)
    (1): BatchNorm2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
  )
  (layer2): Sequential(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (0): IFNode(
      0, 0.000% Params, 0.0 Mac, 0.000% MACs, v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
      (surrogate_function): ATan(0, 0.000% Params, 0.0 Mac, 0.000% MACs, alpha=2.0, spiking=True)
    )
    (1): MaxPool2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (2): Conv2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=m)
    (3): BatchNorm2d(0, 0.00