In [None]:
import copy
import os
import pickle
from datetime import datetime
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import yaml
from easydict import EasyDict
from pytorch_lightning import Trainer, seed_everything
from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader

from dataloader.bci_compet import get_dataset
from dataloader.bci_compet import BCICompet2bIV

from model.litmodel import LitModel
from model.litmodel import get_litmodel
from model.cat_conditioned import CatConditioned
from model.attn_conditioned import ATTNConditioned
from model.attn_conditioned_subj_avg import ATTNConditionedSubjAvg
from model.attn_conditioned_subj_ftr import ATTNConditionedSubjFtr

%load_ext autoreload
%autoreload 2

In [None]:
### Set configs ###
# Read configs/<config_name>.yaml and expose it with attribute access (args.X).
config_name = 'bcicompet2b_config'
with open(f'configs/{config_name}.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
args = EasyDict(config)
    

In [None]:
#### Set Device ####
# Uncomment to restrict the notebook to the GPU selected in the config.
# if torch.cuda.is_available():
#     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.GPU_NUM)

# Reproducible cuDNN: use deterministic kernels and disable the autotuner.
# With benchmark=True cuDNN may select a different algorithm on each run,
# which undermines deterministic=True.
cudnn.benchmark = False
cudnn.deterministic = True

#### Set SEED ####
seed_everything(args.SEED)

#### Update configs ####
# A non-zero `downsampling` value replaces the configured sampling rate.
if args.downsampling != 0:
    args['sampling_rate'] = args.downsampling


In [None]:
# Directory roots: the pickled dataset cache lives under CACHE_ROOT and the
# per-subject metric CSVs are written under TEST_ROOT.
CACHE_ROOT, TEST_ROOT = 'cache', 'test_results'

# Flag read by the dataset class; presumably selects the evaluation split
# in BCICompet2bIV (confirm there). Must be set before the datasets are built.
args.is_test = True

def load_dataset(args, return_subject_id=False, subject_ids=range(9)):
    """Build one BCICompet2bIV dataset per subject.

    Args:
        args: config mapping (an EasyDict in this notebook). It is
            shallow-copied per subject, so the caller's object is left unmodified.
        return_subject_id: unused; kept only for backward compatibility.
        subject_ids: subject indices to build; defaults to 0..8.

    Returns:
        dict mapping subject id -> BCICompet2bIV dataset built with
        ``target_subject`` set to that id.
    """
    datasets = {}
    for subject_id in subject_ids:
        # Use a fresh copy per subject. Writing into the shared `args` would
        # leave `target_subject` at the last subject for the caller, and for
        # any dataset that keeps a reference to the args object.
        subject_args = copy.copy(args)
        subject_args['target_subject'] = subject_id
        datasets[subject_id] = BCICompet2bIV(subject_args)
    return datasets

# The cache is keyed only by config name. Delete the pickle after changing the
# config or the dataloader code, otherwise stale datasets are silently reused.
path = os.path.join(CACHE_ROOT, f'{config_name}.pkl')

if not os.path.isfile(path):
    print('Cache miss, generating cache')
    datasets = load_dataset(args)
    # open(..., 'wb') fails if the cache directory does not exist yet.
    os.makedirs(CACHE_ROOT, exist_ok=True)
    # Write to a temp file and rename, so an interrupted dump cannot leave a
    # truncated pickle that would be "loaded" on every later run.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as file:
        pickle.dump(datasets, file)
    os.replace(tmp_path, path)
else:
    print('Loading cache')
    # NOTE: pickle.load can execute arbitrary code; only load caches written
    # by this notebook.
    with open(path, 'rb') as file:
        datasets = pickle.load(file)

In [None]:
# Select the trained run to evaluate. `subj_info` is the value assigned to each
# dataset's `return_subject_info` in the evaluation cell, so switch the two together.
version = '2023-12-14_12-06-11-baseline-3'
subj_info = ''

# version = '2023-12-14_12-08-50-cat-2'
# subj_info = 'id'

# version = '2023-12-14_12-11-05-attn-0'
# subj_info = 'id'

# version = '2023-12-14_12-12-57-avg-1'
# subj_info = 'avg'

# version = '2023-12-14_12-15-50-ftr-1'
# subj_info = 'ftr'

ckpt_dir = f'{args.CKPT_PATH}/{args.task}/{version}'
ckpt_paths = sorted(glob(f'{ckpt_dir}/*.ckpt'))
if not ckpt_paths:
    raise FileNotFoundError(f'No .ckpt files found in {ckpt_dir}')
# NOTE(review): this takes the last path in lexicographic order, so e.g.
# 'epoch=9-...' sorts after 'epoch=10-...'; confirm this picks the intended checkpoint.
ckpt_path = ckpt_paths[-1]
print(ckpt_path)

# TODO use args from optuna history
model = get_litmodel(args)
# Alternative constructions for the conditioned runs (kept for reference;
# hyperparameters presumably match the runs listed above — confirm).
#in_model = CatConditioned(args, embedding_dimension=48, combined_features_dimension=90 )
#in_model = ATTNConditioned(args,  eeg_normalization = 'LayerNorm', subject_normalization='LayerNorm',embedding_dimension=8, combined_features_dimension=108)
#in_model =  ATTNConditionedSubjAvg(args, eeg_normalization = 'LayerNorm', subject_normalization='LayerNorm', embedding_dimension=14, combined_features_dimension=69 )
#in_model = ATTNConditionedSubjFtr(args,  eeg_normalization = 'LayerNorm', subject_normalization='LayerNorm',embedding_dimension=51, combined_features_dimension=100, subj_dim=17 )
#model = LitModel(args, in_model)

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict copies the tensors into the model's own parameters.
checkpoint = torch.load(ckpt_path, map_location='cpu')
# strict=False tolerates key mismatches, so report them rather than silently
# evaluating a partially-initialised model.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('Missing keys:', load_result.missing_keys)
    print('Unexpected keys:', load_result.unexpected_keys)
trainer = Trainer()

model.eval()

In [None]:

# Evaluate the loaded checkpoint on every cached subject dataset and save
# per-subject accuracy and Cohen's kappa to TEST_ROOT/<task>/<version>.csv.
result_dict = {}

for subject_id in sorted(datasets):
    subject_dataset = datasets[subject_id]
    # Subject-info variant paired with `version` above ('' for the baseline).
    subject_dataset.return_subject_info = subj_info
    test_dataloader = DataLoader(subject_dataset,
                                 batch_size=args.batch_size,
                                 pin_memory=False,
                                 num_workers=args.num_workers)
    gt = subject_dataset.label
    batch_logits = trainer.predict(model, dataloaders=test_dataloader)
    pred = torch.cat(batch_logits, dim=0).argmax(dim=1).detach().cpu().numpy()
    # sklearn's documented argument order is (y_true, y_pred).
    result_dict[subject_id] = {'acc': accuracy_score(gt, pred),
                               'kappa': cohen_kappa_score(gt, pred)}

result_df = pd.DataFrame.from_dict(result_dict, orient='index')

# to_csv does not create missing directories.
output_dir = os.path.join(TEST_ROOT, args.task)
os.makedirs(output_dir, exist_ok=True)
result_df.to_csv(os.path.join(output_dir, f'{version}.csv'))
result_df