<a href="https://colab.research.google.com/github/nmq443/cognitive-science-final-project/blob/quang-branch/torcheeg_atcnet-with_preprocessed_data_cwt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# if run on colab
!pip install torcheeg
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# if run on kaggle
!pip install torcheeg

Collecting torcheeg
  Downloading torcheeg-1.1.2.tar.gz (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.5/214.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Collecting xlrd>=2.0.1 (from torcheeg)
  Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting lmdb>=1.3.0 (from torcheeg)
  Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting einops>=0.4.1 (from torcheeg)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting xmltodict>=0.13.0 (from torcheeg)
  Downloading xmltodict-0.13.0-py2.py3-none-any.whl.metadata (7.7 kB)
Collecting spectrum>=0.8.1 (from torcheeg)
  Downloading spectrum-0.8.1.tar.gz (230 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m230.8/230.8 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Collecting mne

In [2]:
import torcheeg
from torcheeg import transforms
from torcheeg.datasets import BCICIV2aDataset
from torcheeg.model_selection import KFoldGroupbyTrial
from torch.utils.data import DataLoader
from torcheeg.models import ATCNet, EEGNet
import torch

from torcheeg.trainers import ClassifierTrainer

import pytorch_lightning as pl

In [3]:
# Directory holding the BCICIV-2a `.mat` files that the dataset loader reads.
# Exactly one of the assignments below should be active: uncomment the one that
# matches where the notebook is running and comment out the others.

# if run on kaggle
# root_data_path = '/kaggle/input/bci-competition-iv-dataset-2a-in-mat-format/BCICIV-2a-mat'

# if run on colab (the path lives under the Google Drive mount point /content/drive)
root_data_path = '/content/drive/MyDrive/BCICIV-2a-mat'

# if run on local machine
# root_data_path = './BCICIV-2a-mat/'

In [4]:
# Signal pipeline applied on the fly whenever a sample is fetched: a continuous
# wavelet transform spectrum of the chunk, then conversion to a torch tensor.
signal_pipeline = transforms.Compose([
    transforms.CWTSpectrum(),
    transforms.ToTensor(),
])

# Label pipeline: pick the 'label' field and subtract one from it
# (presumably mapping 1-based class ids to 0-based ones - confirm against the
# dataset's label encoding).
label_pipeline = transforms.Compose([
    transforms.Select('label'),
    transforms.Lambda(lambda label: label - 1),
])

dataset = BCICIV2aDataset(
    root_path=root_data_path,
    io_path='./examples_pipeline/bciciv-2a',  # on-disk cache of the processed records
    chunk_size=7 * 250,  # samples per chunk; must match the model's chunk_size
    num_worker=2,
    online_transform=signal_pipeline,
    label_transform=label_pipeline,
)

[2024-05-21 11:44:35] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m./examples_pipeline/bciciv-2a[0m.
INFO:torcheeg:🔍 | Processing EEG data. Processed EEG data has been cached to [92m./examples_pipeline/bciciv-2a[0m.
[2024-05-21 11:44:35] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
INFO:torcheeg:⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
  pid = os.fork()
[PROCESS]: 100%|██████████| 18/18 [01:08<00:00,  3.81s/it]
[2024-05-21 11:45:50] INFO (torcheeg/MainThread) ✅ | All processed EEG data has been cached to ./examples_pipeline/bciciv-2a.
INFO:torcheeg:✅ | All processed EEG data has been cached to ./examples_pipeline/bciciv-2a.
[2024-05-21 11:45:50] INFO (torcheeg/MainThread) 😊 | Please set [92mi

In [5]:
# Print a header line followed by the dataset's per-clip metadata table.
print("Dataset's info: ", dataset.info, sep='\n')

Dataset's info: 
      start_at  end_at   clip_id subject_id  trial_id session subject  run  \
0          251    2001    A04E_0        A04         0       E     A04    3   
1         2254    4004    A04E_1        A04         1       E     A04    3   
2         4172    5922    A04E_2        A04         2       E     A04    3   
3         6124    7874    A04E_3        A04         3       E     A04    3   
4         8132    9882    A04E_4        A04         4       E     A04    3   
...        ...     ...       ...        ...       ...     ...     ...  ...   
5179     86751   88501  A08T_283        A08        43       T     A08    8   
5180     88657   90407  A08T_284        A08        44       T     A08    8   
5181     90585   92335  A08T_285        A08        45       T     A08    8   
5182     92699   94449  A08T_286        A08        46       T     A08    8   
5183     94758   96508  A08T_287        A08        47       T     A08    8   

      label  _record_id  
0         4   _recor

In [6]:
# 10-fold cross-validation splitter. The fold assignment is stored under
# `split_path` and reused when it already exists; the fixed `random_state`
# keeps the shuffled assignment repeatable.
k_fold = KFoldGroupbyTrial(
    n_splits=10, shuffle=True, random_state=44,
    split_path='./examples_pipeline/split',
)

In [None]:
# Cross-validated training of ATCNet: for every fold a fresh model is trained
# on the training split and then scored on the held-out split.

# Lightning accelerator name: GPU when CUDA is available, otherwise CPU.
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'

# Tunable settings, gathered here so they are not scattered through the loop.
SEED = 44          # base RNG seed (same value as the splitter's random_state)
NUM_CLASSES = 4
BATCH_SIZE = 64
NUM_WORKERS = 8    # DataLoader worker processes; lower this on hosts with few CPU cores
MAX_EPOCHS = 50

# Held-out accuracy of every fold, kept so a cross-validation summary can be
# reported after the loop instead of only scattered per-fold prints.
fold_accuracies = []

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    # Seed Python, NumPy and torch (and the DataLoader workers) per fold so that
    # weight initialisation and batch shuffling are repeatable across runs.
    pl.seed_everything(SEED + i, workers=True)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # NOTE(review): the samples come from a CWT spectrum transform; confirm that
    # in_channels / num_electrodes below match the layout of the tensors it emits.
    # chunk_size must equal the chunk_size used when the dataset was built.
    model = ATCNet(
        num_classes=NUM_CLASSES,
        num_electrodes=22,
        in_channels=22,
        chunk_size=7*250,
    )

    trainer = ClassifierTrainer(
        model=model,
        num_classes=NUM_CLASSES,
        lr=1e-4,
        weight_decay=1e-4,
        accelerator=DEVICE
    )

    # limit_val_batches=0.0 disables validation during fitting; the held-out
    # split is only used by trainer.test() below. Checkpoints for each fold go
    # to their own directory.
    trainer.fit(
        train_loader,
        val_loader,
        max_epochs=MAX_EPOCHS,
        default_root_dir=f'./examples_pipeline/atcnet_model/{i}',
        callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
        enable_progress_bar=True,
        enable_model_summary=True,
        limit_val_batches=0.0
    )

    # trainer.test() returns a list of metric dicts; the first entry holds
    # this loader's metrics.
    score = trainer.test(
        val_loader,
        enable_progress_bar=True,
        enable_model_summary=True
    )[0]
    fold_accuracies.append(score['test_accuracy'])
    print(f"Fold {i} test accuracy: {score['test_accuracy']: .4f}")

# Cross-validation summary (skipped if the splitter produced no folds).
if fold_accuracies:
    mean_accuracy = sum(fold_accuracies) / len(fold_accuracies)
    print(f"Mean test accuracy over {len(fold_accuracies)} folds: {mean_accuracy:.4f}")

[2024-05-21 11:54:58] INFO (torcheeg/MainThread) 📊 | Detected existing split of train and test set, use existing split from ./examples_pipeline/split.
INFO:torcheeg:📊 | Detected existing split of train and test set, use existing split from ./examples_pipeline/split.
[2024-05-21 11:54:58] INFO (torcheeg/MainThread) 💡 | If the dataset is re-generated, you need to re-generate the split of the dataset instead of using the previous split.
INFO:torcheeg:💡 | If the dataset is re-generated, you need to re-generate the split of the dataset instead of using the previous split.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params
-------------

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
!mkdir weights

In [None]:
# Write the parameters of `model` (as left behind by the training cell) to disk.
atc_weights_path = './weights/atc_weights.pt'
torch.save(obj=model.state_dict(), f=atc_weights_path)
# To restore them later into a model built with the same arguments:
# model.load_state_dict(torch.load(atc_weights_path))