<a href="https://colab.research.google.com/github/nmq443/cognitive-science-final-project/blob/quang-branch/torcheeg_atcnet-with_preprocessed_data_cwt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# if run on colab
"""
!pip install torcheeg
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
"""

"\n!pip install torcheeg\nfrom google.colab import drive\ndrive.mount('/content/drive', force_remount=True)\n"

In [3]:
# if run on kaggle
!pip install torcheeg

Collecting torcheeg
  Downloading torcheeg-1.1.2.tar.gz (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.5/214.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting xlrd>=2.0.1 (from torcheeg)
  Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting lmdb>=1.3.0 (from torcheeg)
  Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting einops>=0.4.1 (from torcheeg)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting xmltodict>=0.13.0 (from torcheeg)
  Downloading xmltodict-0.13.0-py2.py3-none-any.whl.metadata (7.7 kB)
Collecting spectrum>=0.8.1 (from torcheeg)
  Downloading spectrum-0.8.1.tar.gz (230 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m230.8/230.8 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting mne_conne

In [1]:
import torcheeg
from torcheeg import transforms
from torcheeg.datasets import BCICIV2aDataset
from torcheeg.model_selection import KFoldGroupbyTrial
from torch.utils.data import DataLoader
from torcheeg.models import ATCNet, EEGNet
import torch

from torcheeg.trainers import ClassifierTrainer

import pytorch_lightning as pl

In [2]:
# if run on kaggle
# root_data_path = '/kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat'
# if run on colab
# root_data_path = '/content/drive/MyDrive/BCICIV-2a-mat'

# if run on local machine
root_data_path = './BCICIV-2a-mat/'

In [4]:
dataset = BCICIV2aDataset(
    root_path=root_data_path,
    io_path=f'./examples_pipeline/bciciv-2a',
    online_transform=transforms.Compose([
        transforms.To2d(),
        transforms.MinMaxNormalize(),
        transforms.ToTensor(),
        transforms.RandomNoise(p=0.3),
    ]),
    label_transform=transforms.Compose([
        transforms.Select('label'),
        transforms.Lambda(lambda x: x - 1)
    ]),
    chunk_size=7*250,
    num_worker=2
)

[2024-05-22 12:09:42] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ./examples_pipeline/bciciv-2a.


In [5]:
print("Dataset's info: ")
print(dataset.info)

Dataset's info: 
      start_at  end_at   clip_id subject_id  trial_id session subject  run  \
0          251    2001    A02E_0        A02         0       E     A02    3   
1         2254    4004    A02E_1        A02         1       E     A02    3   
2         4172    5922    A02E_2        A02         2       E     A02    3   
3         6124    7874    A02E_3        A02         3       E     A02    3   
4         8132    9882    A02E_4        A02         4       E     A02    3   
...        ...     ...       ...        ...       ...     ...     ...  ...   
5179     86751   88501  A07T_283        A07        43       T     A07    8   
5180     88657   90407  A07T_284        A07        44       T     A07    8   
5181     90585   92335  A07T_285        A07        45       T     A07    8   
5182     92699   94449  A07T_286        A07        46       T     A07    8   
5183     94758   96508  A07T_287        A07        47       T     A07    8   

      label  _record_id  
0         1   _recor

In [6]:
k_fold = KFoldGroupbyTrial(
    n_splits=10,
    split_path='./examples_pipeline/split',
    shuffle=True,
    random_state=44
)

In [7]:
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=8
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=8
    )

    model = ATCNet(
        num_classes=4,
        num_electrodes=22,
        in_channels=1,
        chunk_size=7*250,
    )

    trainer = ClassifierTrainer(
        model=model,
        num_classes=4,
        lr=1e-4,
        weight_decay=1e-4,
        accelerator=DEVICE
    )

    trainer.fit(
        train_loader,
        val_loader,
        max_epochs=50,
        default_root_dir=f'./examples_pipeline/atcnet_model/{i}',
        callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
        enable_progress_bar=True,
        enable_model_summary=True,
        limit_val_batches=0.0
    )

    score = trainer.test(
        val_loader,
        enable_progress_bar=True,
        enable_model_summary=True
    )[0]
    print(f"Fold {i} test accuracy: {score['test_accuracy']: .4f}")

[2024-05-22 12:09:42] INFO (torcheeg/MainThread) 📊 | Create the split of train and test set.
[2024-05-22 12:09:42] INFO (torcheeg/MainThread) 😊 | Please set [92msplit_path[0m to [92m./examples_pipeline/split[0m for the next run, if you want to use the same setting for the experiment.
  return F.conv2d(input, weight, bias, self.stride,
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/0/lightning_logs
2024-05-22 12:09:44.705399: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         

Training: |                                            | 0/? [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
  return F.conv1d(input, weight, bias, self.stride,
[2024-05-22 12:10:05] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.391 train_accuracy: 0.237 

[2024-05-22 12:10:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.390 train_accuracy: 0.256 

[2024-05-22 12:10:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.387 train_accuracy: 0.257 

[2024-05-22 12:10:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.386 train_accuracy: 0.279 

[2024-05-22 12:11:16] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.383 train_accuracy: 0.268 

[2024-05-22 12:11:34] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.380 train_accuracy: 0.286 

[2024-05-22 12:11:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.376 train_accuracy: 0.306 

[2024-05-22 12:12:11] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.371 train_accuracy: 0.317 

[2024-05-22 12:12:29] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.363 train_accuracy: 0.33

Testing: |                                             | 0/? [00:00<?, ?it/s]

[2024-05-22 12:24:31] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.254 test_accuracy: 0.468 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.46759259700775146
        test_loss           1.2540966272354126
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 0 test accuracy:  0.4676


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/1/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

[2024-05-22 12:24:50] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.391 train_accuracy: 0.259 

[2024-05-22 12:25:08] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.388 train_accuracy: 0.253 

[2024-05-22 12:25:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.386 train_accuracy: 0.260 

[2024-05-22 12:25:47] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.381 train_accuracy: 0.281 

[2024-05-22 12:26:05] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.378 train_accuracy: 0.292 

[2024-05-22 12:26:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.373 train_accuracy: 0.307 

[2024-05-22 12:26:40] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.369 train_accuracy: 0.319 

[2024-05-22 12:26:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.363 train_accuracy: 0.328 

[2024-05-22 12:27:17] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.357 train_accuracy: 0.338 

[2024-05-22 12:27:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.355 train_accuracy: 0.334 



Testing: |                                             | 0/? [00:00<?, ?it/s]

[2024-05-22 12:40:11] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.230 test_accuracy: 0.497 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.4965277910232544
        test_loss           1.2304927110671997
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 1 test accuracy:  0.4965


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/2/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

  return F.conv1d(input, weight, bias, self.stride,
[2024-05-22 12:40:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.388 train_accuracy: 0.256 

[2024-05-22 12:40:50] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.388 train_accuracy: 0.248 

[2024-05-22 12:41:11] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.382 train_accuracy: 0.274 

[2024-05-22 12:41:32] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.381 train_accuracy: 0.292 

[2024-05-22 12:41:52] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.377 train_accuracy: 0.296 

[2024-05-22 12:42:13] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.369 train_accuracy: 0.312 

[2024-05-22 12:42:34] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.364 train_accuracy: 0.332 

[2024-05-22 12:42:55] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.358 train_accuracy: 0.335 

[2024-05-22 12:43:15] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.352 train_accuracy: 0.344 

[2024-05-22 12:43:36] INFO (torcheeg/MainThread)

Testing: |                                             | 0/? [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
[2024-05-22 12:57:21] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.230 test_accuracy: 0.500 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy                 0.5
        test_loss           1.2298835515975952
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 2 test accuracy:  0.5000


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/3/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

[2024-05-22 12:57:42] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.389 train_accuracy: 0.253 

[2024-05-22 12:58:02] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.388 train_accuracy: 0.260 

[2024-05-22 12:58:23] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.271 

[2024-05-22 12:58:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.277 

[2024-05-22 12:59:04] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.379 train_accuracy: 0.297 

[2024-05-22 12:59:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.373 train_accuracy: 0.306 

[2024-05-22 12:59:45] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.365 train_accuracy: 0.326 

[2024-05-22 13:00:05] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.359 train_accuracy: 0.332 

[2024-05-22 13:00:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.354 train_accuracy: 0.346 

[2024-05-22 13:00:43] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.347 train_accuracy: 0.351 



Testing: |                                             | 0/? [00:00<?, ?it/s]

[2024-05-22 13:14:18] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.259 test_accuracy: 0.440 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.43981480598449707
        test_loss           1.2591882944107056
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 3 test accuracy:  0.4398


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/4/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

[2024-05-22 13:14:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.392 train_accuracy: 0.242 

[2024-05-22 13:14:59] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.385 train_accuracy: 0.269 

[2024-05-22 13:15:19] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.383 train_accuracy: 0.269 

[2024-05-22 13:15:40] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.381 train_accuracy: 0.290 

[2024-05-22 13:16:00] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.377 train_accuracy: 0.294 

[2024-05-22 13:16:21] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.373 train_accuracy: 0.304 

[2024-05-22 13:16:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.365 train_accuracy: 0.326 

[2024-05-22 13:17:02] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.359 train_accuracy: 0.328 

[2024-05-22 13:17:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.353 train_accuracy: 0.347 

[2024-05-22 13:17:43] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.349 train_accuracy: 0.349 



Testing: |                                             | 0/? [00:00<?, ?it/s]

[2024-05-22 13:31:21] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.272 test_accuracy: 0.454 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.45370370149612427
        test_loss           1.2720471620559692
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 4 test accuracy:  0.4537


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/5/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

[2024-05-22 13:31:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.391 train_accuracy: 0.245 

[2024-05-22 13:32:02] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.386 train_accuracy: 0.262 

[2024-05-22 13:32:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.387 train_accuracy: 0.266 

[2024-05-22 13:32:42] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.384 train_accuracy: 0.265 

[2024-05-22 13:33:03] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.380 train_accuracy: 0.284 

[2024-05-22 13:33:23] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.375 train_accuracy: 0.301 

[2024-05-22 13:33:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.370 train_accuracy: 0.308 

[2024-05-22 13:34:04] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.365 train_accuracy: 0.327 

[2024-05-22 13:34:25] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.361 train_accuracy: 0.334 

[2024-05-22 13:34:45] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.352 train_accuracy: 0.354 



Testing: |                                             | 0/? [00:00<?, ?it/s]

[2024-05-22 13:51:43] INFO (torcheeg/MainThread) 
[Test] test_loss: 1.231 test_accuracy: 0.475 



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.47453704476356506
        test_loss            1.231122612953186
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Fold 5 test accuracy:  0.4745


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: examples_pipeline/atcnet_model/6/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ATCNet           | 88.7 K
1 | ce_fn         | CrossEntropyLoss | 0     
2 | train_loss    | MeanMetric       | 0     
3 | val_loss      | MeanMetric       | 0     
4 | test_loss     | MeanMetric       | 0     
5 | train_metrics | MetricCollection | 0     
6 | val_metrics   | MetricCollection | 0     
7 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
88.7 K    Trainable params
0         Non-trainable params
88.7 K    Total params
0.355     Total estimated model params size (MB)


Training: |                                            | 0/? [00:00<?, ?it/s]

[2024-05-22 13:52:20] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.390 train_accuracy: 0.252 

[2024-05-22 13:52:55] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.385 train_accuracy: 0.275 

[2024-05-22 13:53:31] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.385 train_accuracy: 0.280 

[2024-05-22 13:54:06] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.378 train_accuracy: 0.293 

[2024-05-22 13:54:42] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.377 train_accuracy: 0.296 

[2024-05-22 13:55:18] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.371 train_accuracy: 0.318 

[2024-05-22 13:55:53] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.368 train_accuracy: 0.316 

[2024-05-22 13:56:28] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.363 train_accuracy: 0.322 

[2024-05-22 13:57:04] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.360 train_accuracy: 0.332 

[2024-05-22 13:57:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 1.357 train_accuracy: 0.341 



Testing: |                                             | 0/? [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
!mkdir weights

In [None]:
atc_weights_path = './weights/atc_weights.pt'
torch.save(model.state_dict(), atc_weights_path)
# to load weight:
# model.load_state_dict(torch.load(atc_weights_path))