In [1]:
# if run on colab
#!pip install torcheeg

In [1]:
import torcheeg
from torcheeg import transforms
from torcheeg.datasets import BCICIV2aDataset

In [2]:
dataset = BCICIV2aDataset(
    root_path='./BCICIV-2a-mat/',
    io_path=f'./examples_pipeline/bciciv-2a',
    skip_trial_with_artifacts=True,
    # offline_transform=transforms.Compose([
    #     transforms.BandDifferentialEntropy(apply_to_baseline=True),
    #     transforms.To2d(apply_to_baseline=True),
    #     transforms.ToTensor(apply_to_baseline=True)
    # ]),
    online_transform=transforms.Compose([
        # transforms.To2d(apply_to_baseline=True),
        # transforms.ToTensor(apply_to_baseline=True),
        transforms.To2d(),
        transforms.ToTensor(),
        # transforms.CWTSpectrum(apply_to_baseline=True),
        # transforms.BandDifferentialEntropy(apply_to_baseline=True),
        # transforms.BaselineRemoval(),
    ]),
    label_transform=transforms.Compose([
        transforms.Select('label'),
        transforms.Lambda(lambda x: x - 1)
    ]),
    chunk_size=1750,
    num_worker=2
)

[2024-05-13 15:20:23] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m./examples_pipeline/bciciv-2a[0m.
[2024-05-13 15:20:23] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]:   0%|                                       | 0/18 [00:00<?, ?it/s]
[RECORD ./BCICIV-2a-mat/A06E.mat]: 0it [00:00, ?it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 1it [00:00,  2.15it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 48it [00:00, 111.86it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 86it [00:00, 178.09it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 123it [00:00, 228.51it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 157it [00:00, 253.52it/s][A
[RECORD ./BCICIV-2a-mat/A06E.mat]: 195it [00:00, 287.36it/s][A
[PROCESS]: 100%|██████████████████████████████| 18/18 [00:10<00:00,  1.66it/s]
[2024-05-13 15:20:36] INFO (torcheeg/MainThread) ✅ | 

In [3]:
print("Dataset's info: ")
print(dataset.info)

Dataset's info: 
      start_at  end_at   clip_id subject_id  trial_id session subject  run  \
0          251    2001    A06E_0        A06         0       E     A06    3   
1         2254    4004    A06E_1        A06         1       E     A06    3   
2         4172    5922    A06E_2        A06         2       E     A06    3   
3         6124    7874    A06E_3        A06         3       E     A06    3   
4        10243   11993    A06E_4        A06         5       E     A06    3   
...        ...     ...       ...        ...       ...     ...     ...  ...   
4691     86751   88501  A05T_257        A05        43       T     A05    8   
4692     88657   90407  A05T_258        A05        44       T     A05    8   
4693     90585   92335  A05T_259        A05        45       T     A05    8   
4694     92699   94449  A05T_260        A05        46       T     A05    8   
4695     94758   96508  A05T_261        A05        47       T     A05    8   

      label  _record_id  
0         1   _recor

In [4]:
from torcheeg.model_selection import KFoldGroupbyTrial

k_fold = KFoldGroupbyTrial(
    n_splits=10,
    split_path='./examples_pipeline/split',
    shuffle=True,
    random_state=44
)

In [5]:
from torch.utils.data import DataLoader
from torcheeg.models import ATCNet, EEGNet

from torcheeg.trainers import ClassifierTrainer

import pytorch_lightning as pl

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=64,
        shuffle=False
    )

    # model = ATCNet(
    #     num_classes=4,
    #     num_electrodes=22,
    #     chunk_size=128,
    #     in_channels=22,
    #     num_windows=3
    # )
    model = EEGNet(
        chunk_size=1750,
        num_electrodes=22,
        num_classes=4,
        dropout=0.5
    )

    trainer = ClassifierTrainer(
        model=model,
        num_classes=4,
        lr=1e-4,
        weight_decay=1e-4,
        accelerator='gpu'
    )

    trainer.fit(
        train_loader,
        val_loader, 
        max_epochs=100,
        default_root_dir=f'./examples_pipeline/model/{i}',
        callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
        enable_progress_bar=True,
        enable_model_summary=True,
        limit_val_batches=0.0
    )

    score = trainer.test(
        val_loader,
        enable_progress_bar=True,
        enable_model_summary=True
    )[0]
    print(f"Fold {i} test accuracy: {score['test_accuracy']: .4f}")

[2024-05-13 15:21:11] INFO (torcheeg/MainThread) 📊 | Create the split of train and test set.
[2024-05-13 15:21:11] INFO (torcheeg/MainThread) 😊 | Please set [92msplit_path[0m to [92m./examples_pipeline/split[0m for the next run, if you want to use the same setting for the experiment.


ValueError: Cannot have number of splits n_splits=10 greater than the number of samples: n_samples=4.