In [None]:
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import yaml

from pytorch.data_setup.DataModule import DataModule
from pytorch.models.EEGNet import EEGNetv4
from pytorch.models.TSception import TSception
from pytorch.models.EEGChannelNet import ChannelNet
from pytorch.models.Conformer import Conformer

def read_config(config_path: str):
    """Load and return the parsed contents of the YAML file at *config_path*.

    The file is decoded explicitly as UTF-8 so that parsing does not depend on
    the host's locale encoding.

    Args:
        config_path: Path to a YAML configuration file.

    Returns:
        Whatever ``yaml.load`` produces for the document (for the configs in
        this notebook, a nested mapping is expected).
    """
    with open(config_path, encoding="utf-8") as file:
        # NOTE(review): FullLoader can construct some Python-specific tags;
        # if these configs never use such tags, yaml.safe_load is the safer
        # choice — confirm before switching.
        config = yaml.load(file, Loader=yaml.FullLoader)
    return config

# Evaluate every trained classifier on every subject's test split.
# Each (subject, model) pair needs a test config at
# ./pytorch/configs/test/{sub}/{sub}_{model}_test.yaml and a checkpoint at
# ./pytorch/final_classification_ckpts/{sub}/{model}_{sub}.ckpt.
EVAL_SUBJECTS = ["P001", "P002", "P004", "P005", "P006", "P007", "P008", "P009"]

# Model name (as used in config/checkpoint file names) -> model class.
# A dict lookup fails loudly on an unknown name, whereas the previous if/elif
# chain had no `else` and would silently reuse the previous iteration's model.
EVAL_MODEL_CLASSES = {
    "EEGNet": EEGNetv4,
    "TSception": TSception,
    "ChannelNet": ChannelNet,
    "Conformer": Conformer,
}

for sub in EVAL_SUBJECTS:
    for model_name, model_cls in EVAL_MODEL_CLASSES.items():
        config = read_config(config_path = f"./pytorch/configs/test/{sub}/{sub}_{model_name}_test.yaml")
        # Each parameter in the YAML is stored as {"value": ...}; unwrap to plain kwargs.
        model_config = {
            key: spec["value"]
            for key, spec in config["parameters"]["model"]["parameters"].items()
        }
        dm_config = {
            key: spec["value"]
            for key, spec in config["parameters"]["datamodule"]["parameters"].items()
        }
        model = model_cls(**model_config)
        ckpt_path = f"./pytorch/final_classification_ckpts/{sub}/{model_name}_{sub}.ckpt"
        # Deserialize onto CPU so checkpoints written from a GPU run can be
        # opened on any host; the weights are copied into the freshly built
        # (CPU) model and the Trainer handles device placement for testing.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        model.freeze()
        trainer = pl.Trainer()
        dm = DataModule(**dm_config)
        print(f"Testing {model_name} on {sub}")
        trainer.test(model = model, datamodule = dm)