In [1]:
# if run on colab
#!pip install torcheeg
#from google.colab import mount
#drive.mount('/content/drive', force_remount=True) 

In [2]:
# if run on kaggle
!pip install torcheeg

Collecting torcheeg
  Downloading torcheeg-1.1.2.tar.gz (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.5/214.5 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Collecting xlrd>=2.0.1 (from torcheeg)
  Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting lmdb>=1.3.0 (from torcheeg)
  Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting einops>=0.4.1 (from torcheeg)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting xmltodict>=0.13.0 (from torcheeg)
  Downloading xmltodict-0.13.0-py2.py3-none-any.whl.metadata (7.7 kB)
Collecting spectrum>=0.8.1 (from torcheeg)
  Downloading spectrum-0.8.1.tar.gz (230 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m230.8/230.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Collecting mne

In [3]:
import torcheeg
from torcheeg import transforms
from torcheeg.datasets import BCICIV2aDataset
from torcheeg.model_selection import KFoldGroupbyTrial
from torch.utils.data import DataLoader
from torcheeg.models import ATCNet, EEGNet
import torch

from torcheeg.trainers import ClassifierTrainer

import pytorch_lightning as pl

In [4]:
# if run on kaggle
root_data_path = '/kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat'

# if run on colab
# root_data_path = './BCICIV-2a-mat/'

In [5]:
dataset = BCICIV2aDataset(
    root_path=root_data_path,
    io_path=f'./examples_pipeline/bciciv-2a',
    # skip_trial_with_artifacts=True,
    # offline_transform=transforms.Compose([
    #     transforms.BandDifferentialEntropy(apply_to_baseline=True),
    #     transforms.To2d(apply_to_baseline=True),
    #     transforms.ToTensor(apply_to_baseline=True)
    # ]),
    online_transform=transforms.Compose([
        # transforms.To2d(apply_to_baseline=True),
        # transforms.ToTensor(apply_to_baseline=True),
        transforms.To2d(),
        transforms.ToTensor(),
        # transforms.CWTSpectrum(apply_to_baseline=True),
        # transforms.BandDifferentialEntropy(apply_to_baseline=True),
        # transforms.BaselineRemoval(),
    ]),
    label_transform=transforms.Compose([
        transforms.Select('label'),
        transforms.Lambda(lambda x: x - 1)
    ]),
    chunk_size=7*250,
    num_worker=2
)

[2024-05-23 02:30:24] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m./examples_pipeline/bciciv-2a[0m.
[2024-05-23 02:30:24] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]:   0%|          | 0/18 [00:00<?, ?it/s]
[RECORD /kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat/A04T.mat]: 0it [00:00, ?it/s][A
[RECORD /kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat/A04T.mat]: 1it [00:01,  1.01s/it][A
[RECORD /kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat/A04T.mat]: 16it [00:01, 19.26it/s][A
[RECORD /kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat/A04T.mat]: 30it [00:01, 36.70it/s][A
[RECORD /kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat/A04T.mat]: 42it [00:01, 50.38it/s][A
[RECORD /kagg

In [6]:
print("Dataset's info: ")
print(dataset.info)

Dataset's info: 
      start_at  end_at   clip_id subject_id  trial_id session subject  run  \
0          251    2001    A04T_0        A04         0       T     A04    1   
1         2254    4004    A04T_1        A04         1       T     A04    1   
2         4172    5922    A04T_2        A04         2       T     A04    1   
3         6124    7874    A04T_3        A04         3       T     A04    1   
4         8132    9882    A04T_4        A04         4       T     A04    1   
...        ...     ...       ...        ...       ...     ...     ...  ...   
5179     86751   88501  A04E_283        A04        43       E     A04    8   
5180     88657   90407  A04E_284        A04        44       E     A04    8   
5181     90585   92335  A04E_285        A04        45       E     A04    8   
5182     92699   94449  A04E_286        A04        46       E     A04    8   
5183     94758   96508  A04E_287        A04        47       E     A04    8   

      label  _record_id  
0         4   _recor

In [7]:
k_fold = KFoldGroupbyTrial(
    n_splits=10,
    split_path='./examples_pipeline/split',
    shuffle=True,
    random_state=44
)

In [8]:
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4
    )
    
    model = ATCNet(
        num_classes=4,
        num_electrodes=22,
        chunk_size=7*250,
    )
    
    trainer = ClassifierTrainer(
        model=model,
        num_classes=4,
        lr=1e-4,
        weight_decay=1e-4,
        accelerator=DEVICE
    )

    trainer.fit(
        train_loader,
        val_loader, 
        max_epochs=250,
        default_root_dir=f'./examples_pipeline/atcnet_model/{i}',
        callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
        enable_progress_bar=True,
        enable_model_summary=True,
        limit_val_batches=0.0
    )

    score = trainer.test(
        val_loader,
        enable_progress_bar=True,
        enable_model_summary=True
    )[0]
    print(f"Fold {i} test accuracy: {score['test_accuracy']: .4f}")

[2024-05-23 02:31:07] INFO (torcheeg/MainThread) 📊 | Create the split of train and test set.
[2024-05-23 02:31:07] INFO (torcheeg/MainThread) 😊 | Please set [92msplit_path[0m to [92m./examples_pipeline/split[0m for the next run, if you want to use the same setting for the experiment.
  return F.conv2d(input, weight, bias, self.stride,
2024-05-23 02:31:18.531782: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-23 02:31:18.531917: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-23 02:31:18.688601: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 02:31:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.277 

[2024-05-23 02:31:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.366 train_accuracy: 0.333 

[2024-05-23 02:31:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.352 train_accuracy: 0.351 

[2024-05-23 02:32:01] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.337 train_accuracy: 0.370 

[2024-05-23 02:32:10] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.323 train_accuracy: 0.400 

[2024-05-23 02:32:19] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.304 train_accuracy: 0.433 

[2024-05-23 02:32:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.286 train_accuracy: 0.458 

[2024-05-23 02:32:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.268 train_accuracy: 0.480 

[2024-05-23 02:32:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.249 train_accuracy: 0.510 

[2024-05-23 02:32:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.238 train_accuracy: 0.521 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 03:07:17] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.088 test_accuracy: 0.628 



Fold 0 test accuracy:  0.6285


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 03:07:29] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.380 train_accuracy: 0.287 

[2024-05-23 03:07:38] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.362 train_accuracy: 0.336 

[2024-05-23 03:07:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.339 train_accuracy: 0.390 

[2024-05-23 03:07:55] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.315 train_accuracy: 0.432 

[2024-05-23 03:08:04] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.292 train_accuracy: 0.455 

[2024-05-23 03:08:12] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.275 train_accuracy: 0.478 

[2024-05-23 03:08:21] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.255 train_accuracy: 0.501 

[2024-05-23 03:08:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.240 train_accuracy: 0.511 

[2024-05-23 03:08:38] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.228 train_accuracy: 0.523 

[2024-05-23 03:08:47] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.218 train_accuracy: 0.532 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 03:43:29] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.087 test_accuracy: 0.652 



Fold 1 test accuracy:  0.6516


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 03:43:42] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.269 

[2024-05-23 03:43:52] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.369 train_accuracy: 0.320 

[2024-05-23 03:44:01] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.356 train_accuracy: 0.350 

[2024-05-23 03:44:11] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.343 train_accuracy: 0.378 

[2024-05-23 03:44:20] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.322 train_accuracy: 0.416 

[2024-05-23 03:44:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.306 train_accuracy: 0.438 

[2024-05-23 03:44:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.286 train_accuracy: 0.476 

[2024-05-23 03:44:49] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.264 train_accuracy: 0.497 

[2024-05-23 03:44:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.249 train_accuracy: 0.515 

[2024-05-23 03:45:08] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.232 train_accuracy: 0.535 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 04:23:16] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.051 test_accuracy: 0.678 



Fold 2 test accuracy:  0.6782


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 04:23:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.273 

[2024-05-23 04:23:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.364 train_accuracy: 0.335 

[2024-05-23 04:23:49] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.338 train_accuracy: 0.384 

[2024-05-23 04:23:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.317 train_accuracy: 0.421 

[2024-05-23 04:24:08] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.297 train_accuracy: 0.439 

[2024-05-23 04:24:17] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.282 train_accuracy: 0.462 

[2024-05-23 04:24:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.266 train_accuracy: 0.478 

[2024-05-23 04:24:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.253 train_accuracy: 0.498 

[2024-05-23 04:24:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.244 train_accuracy: 0.507 

[2024-05-23 04:24:55] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.232 train_accuracy: 0.518 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 05:03:07] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.115 test_accuracy: 0.616 



Fold 3 test accuracy:  0.6157


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 05:03:20] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.383 train_accuracy: 0.283 

[2024-05-23 05:03:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.369 train_accuracy: 0.311 

[2024-05-23 05:03:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.349 train_accuracy: 0.364 

[2024-05-23 05:03:49] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.328 train_accuracy: 0.406 

[2024-05-23 05:03:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.301 train_accuracy: 0.451 

[2024-05-23 05:04:08] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.278 train_accuracy: 0.473 

[2024-05-23 05:04:17] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.257 train_accuracy: 0.496 

[2024-05-23 05:04:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.243 train_accuracy: 0.511 

[2024-05-23 05:04:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.233 train_accuracy: 0.516 

[2024-05-23 05:04:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.218 train_accuracy: 0.536 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 05:42:59] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.127 test_accuracy: 0.616 



Fold 4 test accuracy:  0.6157


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 05:43:12] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.385 train_accuracy: 0.267 

[2024-05-23 05:43:21] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.366 train_accuracy: 0.323 

[2024-05-23 05:43:32] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.347 train_accuracy: 0.364 

[2024-05-23 05:43:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.330 train_accuracy: 0.388 

[2024-05-23 05:43:51] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.310 train_accuracy: 0.414 

[2024-05-23 05:44:00] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.297 train_accuracy: 0.430 

[2024-05-23 05:44:10] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.282 train_accuracy: 0.450 

[2024-05-23 05:44:19] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.272 train_accuracy: 0.466 

[2024-05-23 05:44:29] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.260 train_accuracy: 0.481 

[2024-05-23 05:44:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.245 train_accuracy: 0.503 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 06:22:52] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.054 test_accuracy: 0.685 



Fold 5 test accuracy:  0.6852


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 06:23:05] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.385 train_accuracy: 0.269 

[2024-05-23 06:23:15] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.366 train_accuracy: 0.331 

[2024-05-23 06:23:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.349 train_accuracy: 0.359 

[2024-05-23 06:23:34] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.331 train_accuracy: 0.383 

[2024-05-23 06:23:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.311 train_accuracy: 0.426 

[2024-05-23 06:23:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.289 train_accuracy: 0.450 

[2024-05-23 06:24:03] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.272 train_accuracy: 0.484 

[2024-05-23 06:24:12] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.262 train_accuracy: 0.486 

[2024-05-23 06:24:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.245 train_accuracy: 0.507 

[2024-05-23 06:24:31] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.234 train_accuracy: 0.517 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 07:02:45] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.224 test_accuracy: 0.498 



Fold 6 test accuracy:  0.4977


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 07:02:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.383 train_accuracy: 0.275 

[2024-05-23 07:03:08] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.366 train_accuracy: 0.334 

[2024-05-23 07:03:17] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.348 train_accuracy: 0.372 

[2024-05-23 07:03:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.329 train_accuracy: 0.396 

[2024-05-23 07:03:37] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.316 train_accuracy: 0.416 

[2024-05-23 07:03:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.298 train_accuracy: 0.459 

[2024-05-23 07:03:56] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.282 train_accuracy: 0.469 

[2024-05-23 07:04:05] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.268 train_accuracy: 0.481 

[2024-05-23 07:04:15] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.249 train_accuracy: 0.519 

[2024-05-23 07:04:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.238 train_accuracy: 0.523 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 07:42:40] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.044 test_accuracy: 0.683 



Fold 7 test accuracy:  0.6829


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 07:42:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.277 

[2024-05-23 07:43:03] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.367 train_accuracy: 0.330 

[2024-05-23 07:43:12] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.350 train_accuracy: 0.367 

[2024-05-23 07:43:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.330 train_accuracy: 0.399 

[2024-05-23 07:43:31] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.309 train_accuracy: 0.438 

[2024-05-23 07:43:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.289 train_accuracy: 0.465 

[2024-05-23 07:43:51] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.271 train_accuracy: 0.486 

[2024-05-23 07:44:00] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.251 train_accuracy: 0.506 

[2024-05-23 07:44:10] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.243 train_accuracy: 0.511 

[2024-05-23 07:44:19] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.223 train_accuracy: 0.530 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 08:22:36] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.382 test_accuracy: 0.354 



Fold 8 test accuracy:  0.3542


Training: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 08:22:50] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.279 

[2024-05-23 08:22:59] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.361 train_accuracy: 0.338 

[2024-05-23 08:23:09] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.345 train_accuracy: 0.373 

[2024-05-23 08:23:18] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.326 train_accuracy: 0.404 

[2024-05-23 08:23:28] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.306 train_accuracy: 0.435 

[2024-05-23 08:23:38] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.289 train_accuracy: 0.462 

[2024-05-23 08:23:47] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.272 train_accuracy: 0.483 

[2024-05-23 08:23:57] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.256 train_accuracy: 0.502 

[2024-05-23 08:24:06] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.250 train_accuracy: 0.504 

[2024-05-23 08:24:16] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.233 train_accuracy: 0.525 



Testing: |          | 0/? [00:00<?, ?it/s]

[2024-05-23 09:02:36] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.049 test_accuracy: 0.681 



Fold 9 test accuracy:  0.6806


In [9]:
!mkdir weights

In [10]:
atc_weights_path = './weights/atc_weights.pt'
torch.save(model.state_dict(), atc_weights_path)
# to load weight:
# model.load_state_dict(torch.load(atc_weights_path))