In [109]:
import os, yaml
from datetime import datetime
from easydict import EasyDict
from glob import glob
import pickle

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from dataloader.bci_compet import get_dataset
from dataloader.bci_compet import BCICompet2aIV

from model.litmodel import LitModel
from model.litmodel import get_litmodel
from model.cat_conditioned import CatConditioned
from model.attn_conditioned import ATTNConditioned
from model.attn_conditioned_subj_avg import ATTNConditionedSubjAvg
from model.attn_conditioned_subj_ftr import ATTNConditionedSubjFtr



from pytorch_lightning import Trainer, seed_everything

from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, cohen_kappa_score

In [3]:
### Set confings
config_name = 'bcicompet2a_config'
with open(f'configs/{config_name}.yaml') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
    args = EasyDict(config)
    

In [6]:
#### Set Device ####
if torch.cuda.is_available():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU_NUM
cudnn.benchmark = True
cudnn.fastest = True
cudnn.deterministic = True

#### Set SEED ####
seed_everything(args.SEED)

#### Update configs ####
if args.downsampling != 0: args['sampling_rate'] = args.downsampling


Seed set to 42


In [107]:
CACHE_ROOT = 'cache'
TEST_ROOT = 'test_results'

args.is_test = True

def load_dataset(args, return_subject_id=False):
    datasets = {}
    for subject_id in range(0,9):
        args['target_subject'] = subject_id
        datasets[subject_id] = BCICompet2aIV(args)
    return datasets

path = os.path.join(CACHE_ROOT, f'{config_name}_base_test.pkl')

if not os.path.isfile(path):
    print('Cache miss, generating cache')
    datasets = load_dataset(args)
    with open(path, 'wb') as file:
        pickle.dump(datasets, file)
else:
    print('Loading cache')
    with open(path, 'rb') as file:
        datasets = pickle.load(file)

Loading cache


In [111]:
# version = '2023-11-30_14-52-13-all_baseline'
# subj_info = ''

# version = '2023-11-30_14-47-59-all_cat_cond'
# subj_info = 'id'

# version = '2023-11-30_14-55-16-all_attn_cond'
# subj_info = 'id'

# version = '2023-11-30_14-59-13-all_attn_cond_avg'
# subj_info = 'avg'

version = '2023-11-30_15-04-02-all_attn_cond_ftr'
subj_info = 'ftr'

ckpt_path = sorted(glob(f'{args.CKPT_PATH}/{args.task}/{version}/*.ckpt'))[-1]
print(ckpt_path)

#model = get_litmodel(args)
#in_model = CatConditioned(args, n_subjects=9, subject_filters=16, final_features=4, n_classes=args['num_classes'] )
#in_model = ATTNConditioned(args, n_subjects=9, embed_dim=16,  n_classes=args['num_classes'] )
#in_model =  ATTNConditionedSubjAvg(args, n_subjects=9, embed_dim=16,  n_classes=args['num_classes'] )
in_model = ATTNConditionedSubjFtr(args, n_subjects=9, embed_dim=16,  n_classes=args['num_classes'] )
model = LitModel(args, in_model)
model.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)
trainer = Trainer()          
        
result_dict = {}

for subject_id in range(0,9):
    datasets[subject_id].return_subject_info = subj_info
    test_dataloader = DataLoader(datasets[subject_id],
                                    batch_size=args.batch_size,
                                    pin_memory=False,
                                    num_workers=args.num_workers)
    gt = datasets[0].label                                   
    logits = trainer.predict(model, dataloaders=test_dataloader)
    pred = torch.cat(logits, dim=0).argmax(axis=1).detach().cpu().numpy()
    acc = accuracy_score(pred, gt)
    kappa = cohen_kappa_score(pred, gt)
    result_dict[subject_id] = {'acc': acc, 'kappa': kappa}
result_df = pd.DataFrame(result_dict).T

result_df.to_csv(f'{TEST_ROOT}/{args.task}/{version}.csv')

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


./checkpoints/BCICompet2a/2023-11-30_15-04-02-all_attn_cond_ftr/epoch=481-val_loss=0.875.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [96]:
datasets[0][0]

{'data': array([[[-0.88996919, -1.27907056, -1.35011277, ...,  0.28252727,
           0.28288531,  0.16730833],
         [-0.96760044, -1.0159445 , -0.87678273, ...,  0.85993223,
           1.03264032,  0.93162256],
         [-0.6876539 , -0.89942583, -0.90780913, ...,  0.90229599,
           1.0303905 ,  0.93598802],
         ...,
         [-1.59265782, -1.45238073, -1.1846601 , ...,  1.03645961,
           1.32855268,  1.38238639],
         [-1.58725833, -1.51257941, -1.2675144 , ...,  0.86497084,
           1.16638995,  1.22026448],
         [-1.3676114 , -1.14787983, -1.01731697, ...,  0.79437571,
           1.11341729,  1.22393781]]]),
 'subject_info': 0,
 'label': 0}

In [89]:
result_df['acc'].mean(), result_df['kappa'].mean()

(0.6878858024691358, 0.5838477366255144)