In [1]:
import os, yaml
from datetime import datetime
from easydict import EasyDict
from glob import glob
import pickle
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from dataloader.bci_compet import get_dataset
from dataloader.bci_compet import BCICompet2aIV

from model.litmodel import LitModel
from model.litmodel import get_litmodel
from model.cat_conditioned import CatConditioned
from model.attn_conditioned import ATTNConditioned
from model.attn_conditioned_subj_avg import ATTNConditionedSubjAvg
from model.attn_conditioned_subj_ftr import ATTNConditionedSubjFtr


from pytorch_lightning import Trainer, seed_everything

from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, cohen_kappa_score

%load_ext autoreload
%autoreload 2

In [2]:
### Load the experiment configuration ###
# The YAML file is parsed into a plain dict (`config`) and wrapped in an
# EasyDict (`args`) so later cells can use attribute access (args.SEED, ...).
config_name = 'bcicompet2a_config'
config_path = f'configs/{config_name}.yaml'
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
args = EasyDict(config)
    

In [3]:
#### Set Device ####
# To pin a single GPU, set os.environ['CUDA_VISIBLE_DEVICES'] = str(args.GPU_NUM)
# here, before anything initialises CUDA.

# Reproducibility: cuDNN benchmarking autotunes convolutions and may select a
# different algorithm on each run, so it is disabled; `deterministic` further
# restricts cuDNN to deterministic algorithms. Both are needed for repeatable
# evaluation numbers.
cudnn.benchmark = False
cudnn.deterministic = True

#### Set SEED ####
seed_everything(args.SEED)

#### Update configs ####
# A non-zero `downsampling` rate replaces the configured sampling rate.
if args.downsampling != 0:
    args['sampling_rate'] = args.downsampling


Seed set to 42


In [4]:
# Locations for the pickled per-subject datasets and the evaluation CSVs.
CACHE_ROOT, TEST_ROOT = 'cache', 'test_results'

# Mark the config as a test run before the datasets are built or loaded
# (presumably read by BCICompet2aIV to pick the evaluation split — confirm).
args['is_test'] = True

def load_dataset(args, return_subject_id=False, n_subjects=9):
    """Build one BCICompet2aIV dataset per subject.

    Args:
        args: experiment config (an EasyDict in this notebook). It is not
            modified; each subject's dataset receives its own shallow copy
            with ``target_subject`` set to that subject's id.
        return_subject_id: unused; kept so existing calls keep working.
        n_subjects: number of subjects to load, ids ``0 .. n_subjects - 1``.
            Defaults to 9, the count previously hard-coded here.

    Returns:
        dict mapping subject id -> dataset instance.
    """
    import copy

    datasets = {}
    for subject_id in range(n_subjects):
        # Give each dataset its own config. Mutating the shared `args` in the
        # loop would make every dataset that keeps a reference to it see the
        # last subject id.
        subject_args = copy.copy(args)
        subject_args['target_subject'] = subject_id
        datasets[subject_id] = BCICompet2aIV(subject_args)
    return datasets

# The per-subject datasets are pickled once and reloaded on later runs.
# NOTE: the cache file name depends only on `config_name`; delete the file
# after changing preprocessing options (e.g. `downsampling`) so it is rebuilt.
path = os.path.join(CACHE_ROOT, f'{config_name}_base_test.pkl')

if not os.path.isfile(path):
    print('Cache miss, generating cache')
    # The cache directory does not exist on a fresh checkout; create it before
    # the expensive build so a failure here costs nothing.
    os.makedirs(CACHE_ROOT, exist_ok=True)
    datasets = load_dataset(args)
    # Dump to a temporary file and rename it into place, so an interrupted
    # dump cannot leave a truncated cache that later runs would try to load.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as file:
        pickle.dump(datasets, file)
    os.replace(tmp_path, path)
else:
    print('Loading cache')
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(path, 'rb') as file:
        datasets = pickle.load(file)

Loading cache


In [7]:
# Choose which leave-one-subject-out (L1SO) training runs to evaluate.
method = 'avg'
subj_info = 'avg'  # assigned to each dataset's `return_subject_info` in the evaluation cell

# Run directories live under <CKPT_PATH>/<task>/ and end in '<method>_L1SO_<subject>'.
# glob order is arbitrary, so sort for a stable, reproducible run order.
trains = sorted(glob(f'{args.CKPT_PATH}/{args.task}/*{method}_L1SO_*'))
versions = [os.path.basename(os.path.normpath(train)) for train in trains]
print(versions)

# Evaluate only these runs; set to None to evaluate every run discovered above.
VERSION_OVERRIDE = ['2023-12-14_15-36-43-avg_L1SO_0']
if VERSION_OVERRIDE is not None:
    # Fail early with a clear message instead of an IndexError in the next cell.
    missing = [v for v in VERSION_OVERRIDE
               if not os.path.isdir(os.path.join(args.CKPT_PATH, args.task, v))]
    if missing:
        raise FileNotFoundError(
            f'Run directories not found under {args.CKPT_PATH}/{args.task}: {missing}')
    versions = list(VERSION_OVERRIDE)

['2023-12-14_11-33-59-avg_L1SO_8', '2023-12-14_10-59-34-avg_L1SO_0', '2023-12-14_11-16-46-avg_L1SO_4', '2023-12-14_11-21-04-avg_L1SO_5', '2023-12-14_11-12-28-avg_L1SO_3', '2023-12-14_15-36-43-avg_L1SO_0', '2023-12-14_11-03-52-avg_L1SO_1', '2023-12-14_11-25-21-avg_L1SO_6', '2023-12-14_11-08-10-avg_L1SO_2', '2023-12-14_11-29-39-avg_L1SO_7']


In [9]:

# Evaluate each selected leave-one-subject-out (L1SO) run on its held-out
# subject and save per-subject accuracy and Cohen's kappa to a timestamped CSV.
result_dict = {}
for version in versions:
    # The held-out subject index is the numeric suffix of the run name,
    # e.g. '..._L1SO_0' -> 0.
    LOS = int(version.split('_')[-1])
    if LOS in result_dict:
        # Results are keyed by subject, so a later run for the same subject
        # replaces the earlier one's row in the CSV.
        print(f'WARNING: more than one run for subject {LOS}; {version} overwrites the earlier result')

    ckpt_candidates = sorted(glob(f'{args.CKPT_PATH}/{args.task}/{version}/*.ckpt'))
    if not ckpt_candidates:
        raise FileNotFoundError(f'No .ckpt files found for run {version!r}')
    # NOTE(review): this is lexicographic order, so 'epoch=100-...' sorts
    # before 'epoch=91-...'; confirm the last entry is the intended checkpoint.
    ckpt_path = ckpt_candidates[-1]
    print()
    print(ckpt_path)
    print(LOS)
    print()

    # Alternative model for the subject-feature variant (presumably method 'ftr'):
    # ATTNConditionedSubjFtr(args, embedding_dimension=23, combined_features_dimension=43, subj_dim=26)
    in_model = ATTNConditionedSubjAvg(args, embedding_dimension=30, combined_features_dimension=100)
    model = LitModel(args, in_model)
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
    # trainer.predict moves the model to its own device afterwards.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    # strict=False would otherwise hide key mismatches and silently evaluate a
    # partly untrained model.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'WARNING: {version}: missing keys {load_result.missing_keys}, '
              f'unexpected keys {load_result.unexpected_keys}')
    trainer = Trainer()

    model.eval()

    # Presumably controls which subject information each sample carries — confirm in BCICompet2aIV.
    datasets[LOS].return_subject_info = subj_info
    # DataLoader does not shuffle by default, so prediction order matches `label`.
    test_dataloader = DataLoader(datasets[LOS],
                                 batch_size=args.batch_size,
                                 pin_memory=False,
                                 num_workers=args.num_workers)
    gt = datasets[LOS].label
    logits = trainer.predict(model, dataloaders=test_dataloader)
    pred = torch.cat(logits, dim=0).argmax(dim=1).detach().cpu().numpy()
    # sklearn metrics take (y_true, y_pred).
    acc = accuracy_score(gt, pred)
    kappa = cohen_kappa_score(gt, pred)
    result_dict[LOS] = {'acc': acc, 'kappa': kappa}
result_df = pd.DataFrame(result_dict).T

Path(f'{TEST_ROOT}/{args.task}/{method}').mkdir(parents=True, exist_ok=True)

result_df.to_csv(f'{TEST_ROOT}/{args.task}/{method}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}_L1SO.csv')
result_df

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



./checkpoints/BCICompet2a/2023-12-14_15-36-43-avg_L1SO_0/epoch=91-val_loss=0.937.ckpt
0



/home/devuser/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]