In [1]:
# Install required libraries
!pip install tonic --quiet
!pip install spikingjelly --quiet
!pip install torch torchvision --quiet
!pip install scikit-learn --quiet
!pip install tqdm --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.2/106.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from spikingjelly.clock_driven import neuron, functional, surrogate
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import tonic
from tonic import transforms, datasets
import time
import os
from tqdm.notebook import tqdm

# Global seeds so weight initialisation and DataLoader shuffling repeat across
# Restart & Run All (some GPU/cuDNN kernels may still be nondeterministic).
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2025-07-07 18:14:18.584519: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751912058.794619      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751912058.852803      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Where tonic downloads/extracts N-MNIST, and how many temporal bins each event
# stream is binned into (30 matches the time depth hard-wired into the CNN
# baseline's default configuration).
DATA_DIR = "/kaggle/working"
N_TIME_BINS = 30

# Bin each event stream into N_TIME_BINS frames. The sensor geometry is read
# from the dataset class instead of being re-typed by hand.
transform = transforms.ToVoxelGrid(
    sensor_size=datasets.NMNIST.sensor_size,
    n_time_bins=N_TIME_BINS,
)

# Both splits share the same transform; the first run downloads the archives.
train_ds = datasets.NMNIST(save_to=DATA_DIR, train=True, transform=transform)
test_ds = datasets.NMNIST(save_to=DATA_DIR, train=False, transform=transform)

Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to /kaggle/working/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting /kaggle/working/NMNIST/train.zip to /kaggle/working/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to /kaggle/working/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting /kaggle/working/NMNIST/test.zip to /kaggle/working/NMNIST


In [4]:
class SpikingRNNROI(nn.Module):
    """Spiking region-of-interest predictor: one soft spatial mask per frame.

    Each frame passes through two conv -> batch-norm -> LIF stages
    (1 -> 16 -> 32 channels, 3x3 kernels with padding 1, so spatial size is
    kept). A 1x1 conv then collapses the features to a single channel, and a
    sigmoid maps it into (0, 1).

    Despite the name there is no explicit recurrent connection. Any memory
    across time steps comes from the LIF layers' internal state, which the
    training loop clears with ``functional.reset_net`` before each batch.
    """

    def __init__(self):
        super().__init__()
        # feature extractor (registration order kept, so parameter order and
        # initialisation are unchanged)
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.lif1 = neuron.LIFNode()
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.lif2 = neuron.LIFNode()
        # mask head: 32 feature channels -> 1 mask channel in (0, 1)
        self.mask_conv = nn.Conv2d(32, 1, 1)
        self.sig = nn.Sigmoid()

    def _mask_for_frame(self, frame):
        """Advance the spiking stack by one time step and return that step's mask."""
        spikes = self.lif1(self.bn1(self.conv1(frame)))
        spikes = self.lif2(self.bn2(self.conv2(spikes)))
        return self.sig(self.mask_conv(spikes))

    def forward(self, x_seq):
        """Return a list with one mask per element of ``x_seq``, in order.

        The order matters: the LIF layers are stateful, so frames must be
        processed first-to-last.
        """
        return [self._mask_for_frame(frame) for frame in x_seq]

In [5]:
class LeNetSNN(nn.Module):
    """LeNet-style spiking classifier whose logits are averaged over time.

    The flatten size 32 * 5 * 5 corresponds to a 34x34 single-channel frame:
    34 -conv5-> 30 -pool2-> 15 -conv5-> 11 -pool2-> 5.
    """

    def __init__(self):
        super().__init__()
        # per-time-step spiking feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=5),
            neuron.LIFNode(),
            nn.MaxPool2d(2),
            nn.Conv2d(12, 32, kernel_size=5),
            neuron.LIFNode(),
            nn.MaxPool2d(2),
        )
        # spiking hidden layer followed by a plain linear readout of 10 logits
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 5 * 5, 100),
            neuron.LIFNode(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Average the per-step logits over the leading (time) axis of ``x``.

        ``x`` is time-major: ``x[t]`` is the batch of frames for step t. Steps
        are processed in order because the LIF layers carry state between them.
        """
        running_total = 0
        for frame in x:
            running_total = running_total + self.fc_layers(self.conv_layers(frame))
        return running_total / x.size(0)

In [6]:
class CNNBaseline(nn.Module):
    """Non-spiking 3D-CNN baseline that sees the whole voxel-grid sequence at once.

    Input is a batch of voxel grids shaped (B, T, 1, H, W). The time axis is
    used as the depth dimension of ``nn.Conv3d``.

    Args:
        n_time_bins: number of time bins T per sample (default 30).
        height: frame height H in pixels (default 34).
        width: frame width W in pixels (default 34).
        num_classes: number of output logits (default 10).

    The defaults reproduce the original hard-wired layer sizes, so existing
    ``CNNBaseline()`` calls and state_dict keys are unchanged.
    """

    def __init__(self, n_time_bins=30, height=34, width=34, num_classes=10):
        super().__init__()
        self.conv_layers = nn.Sequential(
            # (3, 5, 5) kernels with (1, 2, 2) padding keep T, H and W unchanged
            nn.Conv3d(1, 16, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
            nn.ReLU(),
            # pool only spatially; the time axis is never downsampled
            nn.MaxPool3d((1, 2, 2)),
            nn.Conv3d(16, 32, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),
        )
        # two floor-halvings of H and W: floor(floor(h / 2) / 2) == h // 4
        flat_features = 32 * n_time_bins * (height // 4) * (width // 4)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        """Map voxel grids (B, T, 1, H, W) to class logits (B, num_classes).

        Raises:
            ValueError: if ``x`` is not 5-D with a singleton channel axis.
        """
        if x.dim() != 5 or x.size(2) != 1:
            raise ValueError(
                f"expected input of shape (B, T, 1, H, W), got {tuple(x.shape)}"
            )
        # (B, T, 1, H, W) -> (B, 1, T, H, W), the (N, C, D, H, W) layout Conv3d
        # expects. A plain permute (rather than a reshape with hard-coded sizes)
        # can never mix samples across the batch axis when T, H or W change.
        x = x.permute(0, 2, 1, 3, 4)
        return self.fc(self.conv_layers(x))

In [7]:
from torch.utils.data import DataLoader

# Settings shared by both loaders.
# persistent_workers=True keeps the two worker processes alive across epochs
# instead of re-spawning them each time a loader is iterated.
# NOTE(review): the per-epoch worker re-spawn is a likely source of the
# "AssertionError: can only test a child process" noise in the recorded
# training logs — confirm it disappears with this setting.
loader_kwargs = dict(
    batch_size=32,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)

# Training: reshuffled every epoch; drop_last=True keeps every batch full-sized
# (a trailing partial batch, if any, is skipped).
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)

# Evaluation: every sample, in a fixed order.
test_loader = DataLoader(test_ds, shuffle=False, drop_last=False, **loader_kwargs)

In [8]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from spikingjelly.clock_driven import functional

def train_one_epoch(model, predictor, train_loader, optimizer, criterion, device):
    """Jointly train the ROI predictor and the spiking classifier for one epoch.

    For each batch, the predictor produces one mask per time step. Each frame
    is multiplied by its mask, and the masked sequence is classified by
    ``model``. A single optimizer step updates whatever parameters
    ``optimizer`` holds.

    Args:
        model: time-major classifier returning logits of shape (B, num_classes).
        predictor: module mapping a list of frames to a list of masks.
        train_loader: yields ``(voxel_frames, labels)``; ``voxel_frames`` must be
            5-D with time on axis 1 (the permute below requires this).
        optimizer, criterion, device: standard PyTorch training objects.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed. This
        stays correct when the loader uses ``drop_last=True``, unlike dividing
        by ``len(dataset)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    predictor.train()
    total_loss = 0.0
    total_correct = 0
    n_seen = 0

    for voxel_frames, labels in tqdm(
        train_loader,
        desc='  Train',
        unit='batch',
        leave=False
    ):
        # clear spiking state so nothing carries over from the previous batch
        functional.reset_net(model)
        functional.reset_net(predictor)

        # (B, T, ...) -> (T, B, ...): iterating the tensor then walks time steps
        x = voxel_frames.permute(1, 0, 2, 3, 4).float().to(device)
        labels = labels.to(device)

        # one mask per frame, applied element-wise to that frame
        masks = predictor(list(x))
        filtered = torch.stack([frame * mask for frame, mask in zip(x, masks)], dim=0)

        optimizer.zero_grad()
        outputs = model(filtered)  # [B, num_classes]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / n_seen, total_correct / n_seen


def validate_and_get_f1(model, predictor, val_loader, device):
    """Evaluate the ROI predictor + classifier pipeline and score it with F1.

    Runs without gradients and with both networks in eval mode. Spiking state
    is reset before every batch. Masks are applied exactly as in training.

    Returns:
        ``(macro_f1, per_class_f1)``: the macro-averaged F1 and the per-class
        array from ``f1_score(..., average=None)``, both with ``zero_division=0``.
    """
    model.eval()
    predictor.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        batches = tqdm(val_loader, desc='  Val  ', unit='batch', leave=False)
        for voxel_frames, labels in batches:
            functional.reset_net(model)
            functional.reset_net(predictor)

            # time-major view of the batch: (T, B, ...)
            time_major = voxel_frames.permute(1, 0, 2, 3, 4).float().to(device)
            labels = labels.to(device)

            masks = predictor(list(time_major))
            masked_seq = torch.stack(
                [frame * mask for frame, mask in zip(time_major, masks)], dim=0
            )

            logits = model(masked_seq)
            y_pred += logits.argmax(1).cpu().tolist()
            y_true += labels.cpu().tolist()

    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    per_class_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
    return macro_f1, per_class_f1

In [9]:

import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

predictor = SpikingRNNROI().to(device)
model     = LeNetSNN().to(device)

# early stopping state and hyperparameters (defined first so the LR schedule
# can be tied to the epoch budget)
best_val_f1 = 0.0
stalled     = 0
patience    = 3
max_epochs  = 20

# one optimizer over both networks: the ROI masks are trained end-to-end
# through the classification loss
initial_lr = 1e-3
optimizer  = Adam(
    list(model.parameters()) + list(predictor.parameters()),
    lr=initial_lr
)
# T_max tied to max_epochs so the cosine decay completes within the budget
# (the previous T_max=25 never got past epoch 20 of its schedule)
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

# Inverse-frequency class weights. Iterating the dataset decodes and
# voxelises every sample just to read its label, so prefer a `targets`
# label list when the dataset provides one (presumably true for tonic
# datasets — the fallback keeps the old behaviour otherwise).
train_targets = getattr(train_loader.dataset, "targets", None)
if train_targets is None:
    train_targets = [lbl for _, lbl in train_loader.dataset]
all_labels = [int(lbl) for lbl in train_targets]
counts     = np.bincount(all_labels, minlength=10)
weights    = 1.0 / (counts + 1e-6)
class_weights = torch.tensor(weights, dtype=torch.float32).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)

In [10]:
from tqdm.auto import trange

import copy

# NOTE(review): early stopping monitors test_loader, so the best macro-F₁ is
# selected on the test set and is an optimistic estimate; a separate
# validation split would keep the test score unbiased.
best_state = None  # (model state_dict, predictor state_dict) of the best epoch
best_epoch = None

for epoch in trange(1, max_epochs+1, desc='Epochs', unit='epoch'):
    train_loss, train_acc = train_one_epoch(
        model, predictor,
        train_loader, optimizer,
        criterion, device
    )

    # one cosine-schedule step per epoch
    scheduler.step()

    val_macro_f1, val_perclass_f1 = validate_and_get_f1(
        model, predictor,
        test_loader, device
    )

    perclass_str = ", ".join(
        f"{i}:{f1:.2f}" for i, f1 in enumerate(val_perclass_f1)
    )

    tqdm.write(
        f"Epoch {epoch}/{max_epochs} | "
        f"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | "
        f"Val macro-F₁: {val_macro_f1:.4f} | "
        f"Per-class F₁: [{perclass_str}]"
    )

    # an improvement must exceed 1e-4 to reset the patience counter
    if val_macro_f1 > best_val_f1 + 1e-4:
        best_val_f1 = val_macro_f1
        stalled     = 0
        best_epoch  = epoch
        # deep copies: state_dict() returns references to the live tensors,
        # which later optimizer steps would overwrite in place
        best_state = (
            copy.deepcopy(model.state_dict()),
            copy.deepcopy(predictor.state_dict()),
        )
    else:
        stalled += 1
        if stalled >= patience:
            tqdm.write(f"Early stopping at epoch {epoch}")
            break

# Roll back to the best epoch; otherwise the networks would keep the weights of
# the last, non-improving epoch rather than the ones best_val_f1 describes.
if best_state is not None:
    model.load_state_dict(best_state[0])
    predictor.load_state_dict(best_state[1])
    tqdm.write(f"Restored weights from epoch {best_epoch} (val macro-F₁: {best_val_f1:.4f})")

Epochs:   0%|          | 0/20 [00:00<?, ?epoch/s]

  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 1/20 | Train loss: 0.3718, acc: 0.9017 | Val macro-F₁: 0.9728 | Per-class F₁: [0:0.99, 1:0.99, 2:0.97, 3:0.98, 4:0.97, 5:0.98, 6:0.98, 7:0.96, 8:0.96, 9:0.95]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 2/20 | Train loss: 0.0898, acc: 0.9750 | Val macro-F₁: 0.9813 | Per-class F₁: [0:0.98, 1:0.99, 2:0.98, 3:0.98, 4:0.99, 5:0.99, 6:0.98, 7:0.98, 8:0.98, 9:0.97]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
        self._shutdown_workers()self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():if w.is_alive():
 
             ^^^^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive

      File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
assert self._par

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 3/20 | Train loss: 0.0655, acc: 0.9809 | Val macro-F₁: 0.9842 | Per-class F₁: [0:0.99, 1:0.99, 2:0.98, 3:0.99, 4:0.98, 5:0.99, 6:0.99, 7:0.98, 8:0.98, 9:0.97]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>    
if w.is_alive():Traceback (most recent call last):
 Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1564, in _shutdown_workers
    self._pin_memory_thread.join()
  File "/usr/lib/python3.11/threading.py", line 1116, in join
    raise Runtime

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 4/20 | Train loss: 0.0542, acc: 0.9838 | Val macro-F₁: 0.9846 | Per-class F₁: [0:0.99, 1:0.99, 2:0.98, 3:0.98, 4:0.99, 5:0.99, 6:0.98, 7:0.98, 8:0.98, 9:0.98]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 5/20 | Train loss: 0.0459, acc: 0.9862 | Val macro-F₁: 0.9853 | Per-class F₁: [0:0.98, 1:0.99, 2:0.98, 3:0.99, 4:0.99, 5:0.99, 6:0.98, 7:0.99, 8:0.98, 9:0.98]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()    
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
self._shutdown_workers()    
if w.is_alive():  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

     if w.is_alive():  
        ^  ^ ^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'

    File "/usr/lib/pyth

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 6/20 | Train loss: 0.0398, acc: 0.9880 | Val macro-F₁: 0.9892 | Per-class F₁: [0:0.99, 1:0.99, 2:0.99, 3:0.99, 4:0.99, 5:0.99, 6:0.98, 7:0.99, 8:0.99, 9:0.99]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>    if w.is_alive():

Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
     if w.is_alive(): ^
^ ^ ^ ^^  ^^ ^ ^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^    assert self._parent_pid == os.getpid(), 'can only test a child process'^^^
^ ^ ^ ^ 
   File "/usr/lib/p

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 7/20 | Train loss: 0.0355, acc: 0.9891 | Val macro-F₁: 0.9878 | Per-class F₁: [0:0.99, 1:0.99, 2:0.99, 3:0.99, 4:0.99, 5:0.99, 6:0.99, 7:0.99, 8:0.98, 9:0.98]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 8/20 | Train loss: 0.0311, acc: 0.9904 | Val macro-F₁: 0.9878 | Per-class F₁: [0:0.99, 1:0.99, 2:0.98, 3:0.99, 4:0.99, 5:0.99, 6:0.99, 7:0.99, 8:0.98, 9:0.98]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
        if w.is_alive():
if w.is_alive():
            ^ ^ ^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    ^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^  
   File "/usr/lib/pyt

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 9/20 | Train loss: 0.0291, acc: 0.9912 | Val macro-F₁: 0.9888 | Per-class F₁: [0:0.99, 1:0.99, 2:0.99, 3:0.99, 4:0.99, 5:0.99, 6:0.99, 7:0.99, 8:0.98, 9:0.98]
Early stopping at epoch 9


In [11]:
def train_one_epoch_cnn(model, train_loader, optimizer, criterion, device):
    """Train the non-spiking CNN baseline for one epoch.

    Args:
        model: classifier taking a float batch and returning (B, num_classes) logits.
        train_loader: yields ``(inputs, labels)`` batches.
        optimizer, criterion, device: standard PyTorch training objects.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed. This
        stays correct when the loader uses ``drop_last=True``, unlike dividing
        by ``len(dataset)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    n_seen = 0
    for X, y in tqdm(
        train_loader,
        desc='  Train',
        unit='batch',
        leave=False
    ):
        X, y = X.float().to(device), y.to(device)
        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (preds.argmax(1) == y).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_loader yielded no batches")
    avg_loss = total_loss / n_seen
    acc = total_correct / n_seen
    return avg_loss, acc

def validate_cnn(model, test_loader, device):
    """Evaluate the CNN baseline on ``test_loader`` and score it with F1.

    Runs in eval mode without gradients.

    Returns:
        ``(macro_f1, per_class_f1)``: the macro-averaged F1 and the per-class
        array from ``f1_score(..., average=None)``, both with ``zero_division=0``.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        batches = tqdm(test_loader, desc='  Val  ', unit='batch', leave=False)
        for inputs, targets in batches:
            inputs = inputs.float().to(device)
            targets = targets.to(device)
            logits = model(inputs)
            y_pred += logits.argmax(1).cpu().tolist()
            y_true += targets.cpu().tolist()

    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    per_class_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
    return macro_f1, per_class_f1

In [12]:
import numpy as np
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# trained with train_loader and evaluated on test_loader (defined earlier)
cnn_model = CNNBaseline().to(device)

# early stopping state and hyperparameters (same budget as the SNN)
best_val_f1 = 0.0
stalled     = 0
patience    = 3
max_epochs  = 20

# NOTE(review): in the recorded run this baseline never left chance level
# (train loss ≈ ln 10 ≈ 2.303, val macro-F₁ ≈ 0.02). One plausible cause is
# ReLUs dying after Adam steps at lr=1e-3 on the 61,440-input Linear layer;
# try a lower LR or normalisation layers — unverified.
optimizer = Adam(cnn_model.parameters(), lr=1e-3)
# T_max tied to max_epochs so the cosine decay completes within the budget
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

# Inverse-frequency class weights, same scheme as the SNN. Prefer the
# dataset's `targets` label list when available (presumably true for tonic
# datasets) instead of decoding every training sample just to read labels.
train_targets = getattr(train_loader.dataset, "targets", None)
if train_targets is None:
    train_targets = [lbl for _, lbl in train_loader.dataset]
all_labels = [int(lbl) for lbl in train_targets]
counts     = np.bincount(all_labels, minlength=10)
weights    = 1.0 / (counts + 1e-6)
class_weights = torch.tensor(weights, dtype=torch.float32).to(device)
criterion      = nn.CrossEntropyLoss(weight=class_weights)

In [13]:
from tqdm.auto import trange

import copy

# NOTE(review): as with the SNN, early stopping monitors test_loader, so the
# selected best macro-F₁ is an optimistic estimate of test performance.
best_cnn_state = None  # state_dict of the best epoch so far
best_epoch = None

for epoch in trange(1, max_epochs+1, desc='Epochs', unit='epoch'):
    tr_loss, tr_acc = train_one_epoch_cnn(
        cnn_model, train_loader,
        optimizer, criterion, device
    )
    # one cosine-schedule step per epoch
    scheduler.step()

    val_macro_f1, val_perclass_f1 = validate_cnn(
        cnn_model, test_loader, device
    )

    perclass_str = ", ".join(
        f"{i}:{f1:.2f}" for i, f1 in enumerate(val_perclass_f1)
    )

    tqdm.write(
        f"Epoch {epoch}/{max_epochs} | "
        f"Train loss: {tr_loss:.4f}, acc: {tr_acc:.4f} | "
        f"Val macro-F₁: {val_macro_f1:.4f} | "
        f"Per-class F₁: [{perclass_str}]"
    )

    # an improvement must exceed 1e-4 to reset the patience counter
    if val_macro_f1 > best_val_f1 + 1e-4:
        best_val_f1 = val_macro_f1
        stalled     = 0
        best_epoch  = epoch
        # in-memory snapshot; deep copy because state_dict() returns
        # references to tensors that later optimizer steps modify in place
        best_cnn_state = copy.deepcopy(cnn_model.state_dict())
    else:
        stalled += 1
        if stalled >= patience:
            tqdm.write(f"Early stopping at epoch {epoch}")
            break

# Roll back to the best epoch instead of keeping the last epoch's weights.
if best_cnn_state is not None:
    cnn_model.load_state_dict(best_cnn_state)
    tqdm.write(f"Restored CNN weights from epoch {best_epoch} (val macro-F₁: {best_val_f1:.4f})")

Epochs:   0%|          | 0/20 [00:00<?, ?epoch/s]

  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 1/20 | Train loss: 2.3046, acc: 0.0990 | Val macro-F₁: 0.0183 | Per-class F₁: [0:0.00, 1:0.00, 2:0.00, 3:0.18, 4:0.00, 5:0.00, 6:0.00, 7:0.00, 8:0.00, 9:0.00]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 2/20 | Train loss: 2.3028, acc: 0.0991 | Val macro-F₁: 0.0179 | Per-class F₁: [0:0.18, 1:0.00, 2:0.00, 3:0.00, 4:0.00, 5:0.00, 6:0.00, 7:0.00, 8:0.00, 9:0.00]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 3/20 | Train loss: 2.3028, acc: 0.1010 | Val macro-F₁: 0.0178 | Per-class F₁: [0:0.00, 1:0.00, 2:0.00, 3:0.00, 4:0.00, 5:0.00, 6:0.00, 7:0.00, 8:0.18, 9:0.00]


  Train:   0%|          | 0/1875 [00:00<?, ?batch/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b58945527a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

  Val  :   0%|          | 0/313 [00:00<?, ?batch/s]

Epoch 4/20 | Train loss: 2.3028, acc: 0.0994 | Val macro-F₁: 0.0183 | Per-class F₁: [0:0.00, 1:0.00, 2:0.00, 3:0.18, 4:0.00, 5:0.00, 6:0.00, 7:0.00, 8:0.00, 9:0.00]
Early stopping at epoch 4
