In [None]:
import os
# Pin CUDA device enumeration to PCI bus order and expose only device "0"
# to this process (see issue #152). These are set before torch is imported
# in the next cell.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
import os, yaml
from easydict import EasyDict
import pickle
from datetime import datetime

import torch
from torch.utils.data import DataLoader

from sklearn.model_selection import KFold

from dataloader.bci_compet import get_dataset
from dataloader.bci_compet import BCICompet2aIV

from model.litmodel import LitModel
from model.cat_conditioned import CatConditioned
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_lightning import Trainer, seed_everything


from utils.setup_utils import (
    get_device,
    get_log_name,
)
from utils.training_utils import get_callbacks

# Select the 'medium' float32 matmul precision setting; see the PyTorch
# documentation for the exact speed/precision trade-off this implies.
torch.set_float32_matmul_precision('medium')

# IPython magics (not plain Python): enable the autoreload extension in
# mode 2 so edits to imported project modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Directory holding the pickled dataset cache used further down.
CACHE_ROOT = 'cache'

# Base name (without extension) of the YAML experiment config in configs/.
config_name = 'bcicompet2a_config'

# NOTE(review): yaml.FullLoader can construct more than plain scalars and
# containers. If the config is plain data, yaml.safe_load would be the
# stricter choice -- confirm no custom YAML tags are used before switching.
with open(f'configs/{config_name}.yaml') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

# Wrap the parsed config so values are reachable both as keys and attributes.
args = EasyDict(config)


In [None]:
def load_dataset(args, return_subject_id=False, num_subjects=9):
    """Build one ``BCICompet2aIV`` dataset per subject.

    Parameters
    ----------
    args : mapping
        Experiment configuration. It is left untouched: every dataset is
        constructed from its own shallow copy in which ``target_subject``
        is set to that subject's id. (Previously the shared object was
        mutated in place, so the caller's ``target_subject`` ended up as the
        last subject id, and all datasets received the same config object.)
    return_subject_id : bool, optional
        Currently unused; kept so existing call sites remain valid.
    num_subjects : int, optional
        Number of subjects to load, using ids ``0 .. num_subjects - 1``.
        Defaults to 9, the value that used to be hard-coded.

    Returns
    -------
    dict
        Maps each subject id to its ``BCICompet2aIV`` instance.
    """
    import copy  # local import keeps this cell self-contained

    datasets = {}
    for subject_id in range(num_subjects):
        # Per-subject copy: avoids leaking `target_subject` into the caller's
        # config and avoids sharing one mutable config across all datasets.
        subject_args = copy.copy(args)
        subject_args['target_subject'] = subject_id
        datasets[subject_id] = BCICompet2aIV(subject_args)
    return datasets

# Location of the pickled {subject_id: dataset} cache for this config.
path = os.path.join(CACHE_ROOT, f'{config_name}_base.pkl')

if not os.path.isfile(path):
    print('Cache miss, generating cache')
    datasets = load_dataset(args)

    # On a fresh checkout the cache directory may not exist yet; without
    # this, opening the file for writing raises FileNotFoundError.
    cache_dir = os.path.dirname(path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Write to a temporary file and rename it into place so that a failed or
    # interrupted dump never leaves a truncated cache that the next run
    # would try (and fail) to load.
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'wb') as file:
        pickle.dump(datasets, file)
    os.replace(tmp_path, path)
else:
    print('Loading cache')
    # NOTE(review): unpickling can execute arbitrary code. Only load cache
    # files produced locally by this notebook.
    with open(path, 'rb') as file:
        datasets = pickle.load(file)

In [None]:
# Sanity check: show how many trials each subject's dataset reports.
for subject_id in datasets:
    print(f"Subject {subject_id} has {len(datasets[subject_id])} trials")

In [None]:
# Switch every per-subject dataset to the 'id' subject-info mode.
# NOTE(review): presumably this makes each sample carry its subject id for
# the subject-conditioned model -- confirm in BCICompet2aIV.
for subject_id, subject_dataset in datasets.items():
    subject_dataset.return_subject_info = 'id'

In [None]:
# Short tag identifying this experiment variant; appended to the timestamp
# to form the logger version label.
name = 'all_cat_cond'

args.VERSION = f'{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-{name}'


#### Set Log ####
# `current_time` is stored before get_log_name(args) is called.
args.current_time = datetime.now().strftime('%Y%m%d')
args.LOG_NAME = get_log_name(args)

#### Update configs ####
# A non-zero `downsampling` value replaces the configured sampling rate.
if args.downsampling != 0:
    args.sampling_rate = args.downsampling

# Seed the RNGs for a reproducible run.
seed_everything(args.SEED)


In [None]:
# Per-subject split sizes: the first `train_size` trials go to training and
# the following `val_size` trials to validation (sequential, not shuffled).
train_size = 240
val_size = 48


train_datasets = {}
val_datasets = {}
for subject_id, subject_dataset in datasets.items():
    # Subset does not validate its indices up front; a subject with too few
    # trials would otherwise only fail later, in the middle of training.
    n_trials = len(subject_dataset)
    if n_trials < train_size + val_size:
        raise ValueError(
            f"Subject {subject_id} has {n_trials} trials, "
            f"fewer than train_size + val_size = {train_size + val_size}"
        )
    train_datasets[subject_id] = torch.utils.data.Subset(
        subject_dataset, range(train_size)
    )
    val_datasets[subject_id] = torch.utils.data.Subset(
        subject_dataset, range(train_size, train_size + val_size)
    )


# Pool all subjects into a single training set and a single validation set.
train_dataset_all = torch.utils.data.ConcatDataset(list(train_datasets.values()))
val_dataset_all = torch.utils.data.ConcatDataset(list(val_datasets.values()))
len(train_dataset_all), len(val_dataset_all)

In [None]:
# Settings shared by both loaders: configured batch size, loading in the
# main process (no worker processes).
loader_kwargs = dict(
    batch_size=args['batch_size'],
    num_workers=0,
    persistent_workers=False,
)

# Only the training loader shuffles; validation order stays fixed.
train_dataloader_all = DataLoader(train_dataset_all, shuffle=True, **loader_kwargs)
val_dataloader_all = DataLoader(val_dataset_all, shuffle=False, **loader_kwargs)

In [None]:
# Build the subject-conditioned network and wrap it in the Lightning module.
# NOTE(review): num_subjects=9 matches the default number of datasets
# loaded above -- keep the two in sync if the subject count changes.
model = CatConditioned(
    args,
    num_subjects=9,
    embedding_dimension=16,
    combined_features_dimension=4,
    num_classes=args['num_classes'],
)
lit_model = LitModel(args, model)

In [None]:
# TensorBoard logger rooted at the configured log path; the timestamped
# version label is used as the experiment name.
logger = TensorBoardLogger(args.LOG_PATH, name=args.VERSION)

In [None]:
# Training callbacks from the project helper, monitoring 'val_loss'.
callbacks = get_callbacks(args=args, monitor='val_loss')

In [None]:
# Lightning trainer: epoch budget and checkpoint root come from the config;
# the progress bar is disabled to keep notebook output compact.
trainer = Trainer(
    logger=logger,
    callbacks=callbacks,
    max_epochs=args['EPOCHS'],
    default_root_dir=args.CKPT_PATH,
    enable_progress_bar=False,
)

In [None]:
# Train on the pooled training set, validating on the pooled validation set.
trainer.fit(
    lit_model,
    train_dataloaders=train_dataloader_all,
    val_dataloaders=val_dataloader_all,
)

# Ask PyTorch to release cached, currently unused GPU memory once training
# has finished.
torch.cuda.empty_cache()