In [1]:
import os, yaml
from easydict import EasyDict

from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader

from sklearn.model_selection import KFold

from dataloader.bci_compet import get_dataset
from dataloader.bci_compet import BCICompet2aIV

from model.litmodel import LitModel
from model.cat_conditioned import CatConditioned
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_lightning import Trainer, seed_everything


from utils.setup_utils import (
    get_device,
    get_log_name,
)
from utils.training_utils import get_callbacks

%load_ext autoreload
%autoreload 2

  warn('datautil.preprocess module is deprecated and is now under '


In [2]:
# Experiment configuration is read from configs/<config_name>.yaml.
config_name = 'bcicompet2a_config'

with open(f'configs/{config_name}.yaml') as config_file:
    # NOTE(review): FullLoader resolves more YAML tags than the safe loader;
    # if this config only holds plain scalars/lists/maps, yaml.safe_load would
    # be the more conservative choice -- confirm before switching.
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# EasyDict allows both attribute-style and key-style access to the settings.
args = EasyDict(config)

#### Set Log ####
# Date stamp (YYYYMMDD) that get_log_name receives as part of `args`.
args.current_time = datetime.now().strftime('%Y%m%d')
args.LOG_NAME = get_log_name(args)

#### Update configs ####
# Force the learning rate to a float (presumably the YAML value, e.g. `2e-3`,
# is read as a string -- TODO confirm).
args.lr = float(args.lr)
# A non-zero `downsampling` value replaces the configured sampling rate.
if args.downsampling != 0:
    args.sampling_rate = args.downsampling

# Seed the RNGs for reproducibility; the returned seed is the cell's output.
seed_everything(args.SEED)


Seed set to 42


LOG >>> Log name: 
	20231127_task_BCICompet2a_batch_64_lr_2e-3_Baseline


42

In [3]:
# One dataset per subject, keyed by subject index. Only subject 0 is loaded.
datasets = {}
for subject_id in range(1):
    # The dataset class receives the subject to load through `args`.
    args.target_subject = subject_id
    datasets[subject_id] = BCICompet2aIV(args, return_subject_id=True)

  0%|          | 0/9 [00:00<?, ?it/s]

LOG >>> Filename: Datasets/BCI_Competition_IV_2a\A01T.gdf
Extracting EDF parameters from c:\Code\m-shallowconvnet\Datasets\BCI_Competition_IV_2a\A01T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Filtering raw data in 1 contiguous segment
Setting up low-pass filter at 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal lowpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Upper passband edge: 38.00 Hz
- Upper transit

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s


Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 751 original time points ...
0 bad epochs dropped


100%|██████████| 9/9 [00:01<00:00,  4.89it/s]


In [4]:
# Sanity check: report how many trials were loaded for each subject.
for subject_id, subject_dataset in datasets.items():
    print(f"Subject {subject_id} has {len(subject_dataset)} trials")

Subject 0 has 288 trials


In [5]:
# Index-ordered hold-out split (no shuffling): for every subject the first
# `train_size` trials form the training subset and the next `val_size`
# trials form the validation subset.
train_size = 240
val_size = 48

train_datasets = {}
val_datasets = {}
for subject_id, subject_dataset in datasets.items():
    n_trials = len(subject_dataset)
    # Subset stores the index range without validating it, so an undersized
    # dataset would only fail later with an IndexError during data loading.
    # Check up front to get a clear error instead.
    if train_size + val_size > n_trials:
        raise ValueError(
            f"Subject {subject_id} has {n_trials} trials, but the split needs "
            f"{train_size} train + {val_size} val = {train_size + val_size}"
        )
    train_datasets[subject_id] = torch.utils.data.Subset(
        subject_dataset, range(train_size)
    )
    val_datasets[subject_id] = torch.utils.data.Subset(
        subject_dataset, range(train_size, train_size + val_size)
    )

# Pool the per-subject subsets into a single train and a single val dataset.
train_dataset_all = torch.utils.data.ConcatDataset(list(train_datasets.values()))
val_dataset_all = torch.utils.data.ConcatDataset(list(val_datasets.values()))
len(train_dataset_all), len(val_dataset_all)

(240, 48)

In [6]:
# Mini-batch loaders over the pooled train / validation splits.
# NOTE(review): the batch size is hard-coded to 32 while the recorded log name
# reads "batch_64" -- confirm which value is intended.
train_dataloader_all = DataLoader(train_dataset_all, batch_size=32, shuffle=True)
# Validation data is not shuffled: sample order does not matter for
# evaluation, and PyTorch Lightning warns when a val dataloader shuffles.
val_dataloader_all = DataLoader(val_dataset_all, batch_size=32, shuffle=False)

In [7]:
# Build the subject-conditioned network and wrap it in the Lightning module
# that is handed to the trainer.
# NOTE(review): n_subjects=9 although only subject 0 is loaded in this
# notebook -- presumably sized for the full cohort; confirm.
model = CatConditioned(
    args,
    n_subjects=9,
    subject_filters=16,
    final_features=4,
    n_classes=args.num_classes,
)
lit_model = LitModel(args, model)


subject_filters eeg_dim
16 744


In [8]:
# Build a batch of one from the first training trial, for a manual forward
# pass through the model. The sample is fetched once and both fields are read
# from it, so the dataset is not indexed twice.
first_trial = train_dataset_all[0]
# Wrapping each field in a one-element list adds the leading batch dimension.
eeg_input = torch.from_numpy(np.array([first_trial['data']]))
subject_id = torch.from_numpy(np.array([first_trial['subject_id']]))
print(type(eeg_input), type(subject_id))
print(eeg_input.shape, subject_id.shape)
print(eeg_input.dtype, subject_id.dtype)

<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([1, 1, 22, 751]) torch.Size([1])
torch.float64 torch.int32


In [9]:
# Smoke test: run the single-trial batch through the untrained model and
# display the raw output as the cell result.
smoke_logits = model(eeg_input, subject_id)
smoke_logits

eeg_features subject_features
torch.Size([1, 744]) torch.Size([1, 16])
torch.Size([1, 760])
Linear(in_features=760, out_features=4, bias=True)


tensor([[ 0.1679, -1.4003, -0.3376, -0.1858]], grad_fn=<AddmmBackward0>)

In [10]:
# TensorBoard logger rooted at the configured log directory; runs are grouped
# under the experiment name below.
logger = TensorBoardLogger(
    args.LOG_PATH,
    name='my_train_sub0_cat_cond',
)

In [11]:
# Training callbacks from the project helper, monitoring validation loss.
callbacks = get_callbacks(
    monitor='val_loss',
    args=args,
)

In [12]:
# Lightning trainer: fixed 250-epoch budget, project callbacks, TensorBoard
# logging, and the configured checkpoint directory as the default root.
trainer = Trainer(
    max_epochs=250,
    callbacks=callbacks,
    default_root_dir=args.CKPT_PATH,
    logger=logger,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# Fit the model on the pooled training split, validating on the pooled
# validation split.
trainer.fit(
    lit_model,
    train_dataloaders=train_dataloader_all,
    val_dataloaders=val_dataloader_all,
)

# Release cached CUDA memory after training (the recorded run reports no GPU
# available).
torch.cuda.empty_cache()


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | CatConditioned   | 30.4 K
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
30.4 K    Trainable params
0         Non-trainable params
30.4 K    Total params
0.122     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 8/8 [02:23<00:00,  0.06it/s, v_num=5, val_loss=1.400, val_acc=0.292, train_loss=1.440, train_acc=0.304]
Epoch 249: 100%|██████████| 8/8 [00:00<00:00, 12.04it/s, v_num=5, val_loss=0.817, val_acc=0.771, train_loss=0.626, train_acc=0.846]

`Trainer.fit` stopped: `max_epochs=250` reached.


Epoch 249: 100%|██████████| 8/8 [00:00<00:00, 11.92it/s, v_num=5, val_loss=0.817, val_acc=0.771, train_loss=0.626, train_acc=0.846]
