# Finetuning Notebook for Thesis (Simple Classifiers)

### Imports

In [2]:
print("Hello world")

# Add custom path
import sys
sys.path.append("/home/maxihuber/eeg-foundation/")

# Standard library imports
import os
import gc
import glob
import json

# Third-party library imports
import numpy as np
import pandas as pd
import torch
import mne
import lightning.pytorch as L
import matplotlib.pyplot as plt
from natsort import natsorted
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.linear_model import LinearRegression
from sklearn.metrics import balanced_accuracy_score, mean_squared_error
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# MNE imports
from mne.preprocessing import Xdawn

# Custom imports
from src.utils.preloading.utils import load_edf_to_dataframe

# Set MNE log level
mne.set_log_level('warning')

# Seed everything
L.seed_everything(42)

print("Bye world")

Hello world


AttributeError: partially initialized module 'pandas' has no attribute '_pandas_parser_CAPI' (most likely due to a circular import)

## Data Loading

### Load Train/Val/Test Information

In [2]:
########################################################################################################################
# TUAB and Epilepsy

def _yc_task(task_name, json_name):
    """Return a two-class classification task spec whose index file lives in finetune_files."""
    return {
        "task_name": task_name,
        "task_type": "Classification",
        "json_path": "/itet-stor/maxihuber/net_scratch/finetune_files/" + json_name,
        "out_dim": 2,
    }


# Loader settings shared by the TUAB and Epilepsy tasks.
# load_mode 2 makes load_file_data read the first input file with load_edf_to_dataframe.
yc_class = dict(
    class_name="YC",
    time_col="Time in Seconds",
    prefix_filepath="/itet-stor/maxihuber/deepeye_storage/foundation/tueg/edf",
    load_mode=2,
)

tuab = _yc_task("TUAB", "tuab_light.json")
epilepsy = _yc_task("Epilepsy", "epilepsy_light.json")

# Every task that is loaded with the yc_class settings.
yc_tasks = [tuab, epilepsy]

########################################################################################################################
# Clinical JSONs

def _cli_task(task_name, json_name, task_type="Classification", out_dim=2):
    """Return a clinical task spec whose index file lives in finetune_files."""
    return {
        "task_name": task_name,
        "task_type": task_type,
        "json_path": "/itet-stor/maxihuber/net_scratch/finetune_files/" + json_name,
        "out_dim": out_dim,
    }


# Loader settings shared by all clinical tasks.
cli_class = dict(
    class_name="Clinical",
    time_col="Time in Seconds",
    prefix_filepath="/itet-stor/maxihuber/deepeye_storage/foundation_clinical_prepared/",
    load_mode=0,
)

# Age is the only regression task in this family; the rest are two-class.
age = _cli_task("Age", "age_light.json", task_type="Regression", out_dim=1)
depression = _cli_task("Depression", "cli_depression_light.json")
parkinsons = _cli_task("Parkinsons", "cli_parkinsons_light.json")
schizophrenia = _cli_task("Schizophrenia", "cli_schizophrenia_light.json")
sex = _cli_task("Sex", "sex_light.json")

cli_tasks = [age, depression, parkinsons, schizophrenia, sex]


########################################################################################################################
# Motor-Imagery JSONs

def _mi_task(task_name, json_name, outputs, out_dim=2, short_mode=False):
    """Return a motor-imagery classification task spec.

    `outputs` is any iterable of label strings; it is stored as a set.
    """
    return {
        "task_name": task_name,
        "task_type": "Classification",
        "json_path": "/itet-stor/maxihuber/net_scratch/finetune_files/" + json_name,
        "out_dim": out_dim,
        "outputs": set(outputs),
        "short_mode": short_mode,
    }


# Loader settings shared by all motor-imagery tasks.
mi_class = dict(
    class_name="Motor Imagery",
    time_col="time in seconds",
    prefix_filepath="/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    load_mode=0,
)

eye_open_closed = _mi_task(
    "EyeOpenClosed",
    "eye_open_closed_light.json",
    ("eye open", "eye closed"),
)

eye_vh = _mi_task(
    "EyeVH",
    "eye_vh_light.json",
    ("vertical", "horizontal"),
)

flexion_extension_imaginary = _mi_task(
    "FlexionExtensionImaginary",
    "flexion_extension_imaginary_light.json",
    (
        "hand movement imagined elbow flexion",
        "hand movement imagined elbow extension",
    ),
)

flexion_extension_real = _mi_task(
    "FlexionExtensionReal",
    "flexion_extension_real_light.json",
    ("hand movement elbow extension", "hand movement elbow flexion"),
)

grasp_imaginary = _mi_task(
    "GraspImaginary",
    "grasp_imaginary_light.json",
    ("imagined palmar grasp", "imagined lateral grasp"),
)

grasp_real = _mi_task(
    "GraspReal",
    "grasp_real_light.json",
    ("movement palmar grasp", "movement lateral grasp"),
)

lr_imaginary = _mi_task(
    "LRImaginary",
    "lr_imaginary_light.json",
    ("left hand imagined movement", "right hand imagined movement"),
    short_mode=True,
)

lr_real = _mi_task(
    "LRReal",
    "lr_real_light.json",
    ("right hand movement", "left hand movement"),
    short_mode=True,
)

mi_task_body_parts_imagined = _mi_task(
    "BodyPartsImagined",
    "mi_task_imagined_body_parts_light.json",
    (
        "rest",
        "right hand imagined movement",
        "foot imagined movement",
        "left hand imagined movement",
        "tongue imagined movement",
    ),
    out_dim=5,
    short_mode=True,
)

mi_task_body_parts_real = _mi_task(
    "BodyPartsReal",
    "mi_task_body_parts_light.json",
    ("rest", "right hand movement", "foot movement", "left hand movement"),
    out_dim=4,
    short_mode=True,
)

pronation_supination_imaginary = _mi_task(
    "PronationSupinationImaginary",
    "pronation_supination_imaginary_light.json",
    ("imagined supination", "imagined pronation"),
)

pronation_supination_real = _mi_task(
    "PronationSupinationReal",
    "pronation_supination_real_light.json",
    ("movement supination", "movement pronation"),
)

mi_tasks = [
    eye_open_closed,
    eye_vh,
    flexion_extension_imaginary,
    flexion_extension_real,
    grasp_imaginary,
    grasp_real,
    lr_imaginary,
    lr_real,
    mi_task_body_parts_imagined,
    mi_task_body_parts_real,
    pronation_supination_imaginary,
    pronation_supination_real,
]

########################################################################################################################
# ERP JSONs

def _erp_task(task_name, json_name, outputs):
    """Return a two-class ERP/ErrP task spec; `outputs` is stored as a set of label strings."""
    return {
        "task_name": task_name,
        "task_type": "Classification",
        "json_path": "/itet-stor/maxihuber/net_scratch/finetune_files/" + json_name,
        "out_dim": 2,
        "outputs": set(outputs),
    }


# Loader settings shared by the ERP and ERRP tasks.
erp_class = dict(
    class_name="Error-Related Potential",
    time_col="time in seconds",
    prefix_filepath="/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    load_mode=0,
)

erp = _erp_task(
    "ERP",
    "new_erp_light.json",
    ("with event-related potential", "without event-related potential"),
)

errp = _erp_task(
    "ERRP",
    "errp_all_light.json",
    ("without error-related potential", "with error-related potential"),
)

erp_tasks = [erp, errp]

########################################################################################################################
# EyeNet JSONs

def _eye_task(task_name, file_stem, task_type, out_dim):
    """Return an EEGEyeNet task spec whose json_path lists the train, val and test index files."""
    return {
        "task_name": task_name,
        "task_type": task_type,
        "json_path": [
            f"/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_{file_stem}_{split}.json"
            for split in ("train", "val", "test")
        ],
        "out_dim": out_dim,
    }


# Loader settings shared by all EEGEyeNet tasks.
# load_mode 1 selects the three-file train/val/test index loader (load_index1).
eye_class = dict(
    class_name="EyeNet",
    time_col="time",
    prefix_filepath="/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    load_mode=1,
)

eye_dir_amp = _eye_task("EyeNetDirectionAmp", "Direction_Amp", "Regression", 1)
eye_dir_ang = _eye_task("EyeNetDirectionAng", "Direction_Ang", "Regression", 1)
eye_lr = _eye_task("EyeNetLR", "LR", "Classification", 2)
eye_position = _eye_task("EyeNetPosition", "Position", "Regression", 2)

eye_tasks = [eye_dir_amp, eye_dir_ang, eye_lr, eye_position]

# Registry of dataset families: short key -> [loader settings, list of task specs].
classes = dict(
    YC=[yc_class, yc_tasks],
    Clinical=[cli_class, cli_tasks],
    MI=[mi_class, mi_tasks],
    ERP=[erp_class, erp_tasks],
    EyeNet=[eye_class, eye_tasks],
)

### Load data into memory

In [None]:
########################################################################################################################
# Select the class and task

# Exactly one `used_class` and one `used_task` line should be active; the task must
# belong to the chosen class (see the `classes` registry).
# used_class = yc_class
# used_class = cli_class
used_class = mi_class
# used_class = erp_class
# used_class = eye_class
#
# used_task = tuab
# used_task = epilepsy
# used_task = age
# used_task = depression
# used_task = parkinsons
# used_task = schizophrenia
# used_task = sex
#
used_task = eye_open_closed
# used_task = eye_vh
# used_task = flexion_extension_imaginary
# used_task = flexion_extension_real
# used_task = grasp_real
# used_task = lr_imaginary
# used_task = lr_real
# used_task = mi_task_body_parts_real
# used_task = mi_task_body_parts_imagined
# used_task = pronation_supination_real
# used_task = pronation_supination_imaginary
#
# used_task = erp
# used_task = errp
#
# used_task = eye_dir_amp
# used_task = eye_dir_ang
# used_task = eye_lr
# used_task = eye_position

# Unpack the selected settings into the globals used by the loading code below.
class_name = used_class["class_name"]
time_col = used_class["time_col"]
prefix_filepath = used_class["prefix_filepath"]
load_mode = used_class["load_mode"]
task_name = used_task["task_name"]
task_type = used_task["task_type"]
json_path = used_task["json_path"]
out_dim = used_task["out_dim"]
# Only some task specs define "short_mode"; default to False for the others.
short_mode = used_task.get("short_mode", False)

with open('/home/maxihuber/eeg-foundation/src/data/components/channels_to_id3.json', 'r') as f:
    chn_to_id = json.load(f)
# Keep the names of all channels with a positive id. The original line mixed list and
# dict comprehension syntax and did not parse; a set matches its `set(...)` wrapper.
task_channels = {chn for chn, chn_id in chn_to_id.items() if chn_id > 0}  # excludes "None"
print(f"Task channels: {task_channels}")

# When truncate is True, only the first and last `num_keep` index entries of each split are loaded.
truncate = True
num_keep = 100

def load_index0(data_index_path):
    """Read a single JSON index file and return its "train" and "test" entries.

    Args:
        data_index_path: Path to a JSON file containing the keys "train" and "test".

    Returns:
        A `(train_samples, test_samples)` tuple.

    Raises:
        KeyError: If either key is missing from the file.
    """
    with open(data_index_path, "r") as handle:
        index = json.load(handle)
    return index["train"], index["test"]

def load_index1(data_index_paths):
    """Read one JSON index file per split and return the first three splits.

    From each file only the first value of the top-level JSON object is used.

    Args:
        data_index_paths: Paths of the index files, in the order train, val, test.

    Returns:
        A `(train, val, test)` tuple built from the first three files.

    Raises:
        IndexError: If fewer than three paths are given or a file holds an empty object.
    """

    def _first_value(path):
        with open(path, "r") as handle:
            content = json.load(handle)
        return list(content.values())[0]

    per_file = [_first_value(path) for path in data_index_paths]
    return per_file[0], per_file[1], per_file[2]

def truncate0(train_index, test_index, num_keep, truncate=False):
    """Optionally shrink the train and test indices to their first and last `num_keep` entries.

    Args:
        train_index: List of training index entries.
        test_index: List of test index entries.
        num_keep: Number of entries kept from each end of a list.
        truncate: When False, both lists are returned unchanged.

    Returns:
        A `(train_index, test_index)` tuple. A list with at most `2 * num_keep`
        entries is returned unchanged, so the head and tail never overlap and no
        sample is duplicated.
    """

    def _head_and_tail(index):
        if not truncate or len(index) <= 2 * num_keep:
            return index
        # Slice the tail from an explicit position: `index[-0:]` would be the whole list.
        return index[:num_keep] + index[len(index) - num_keep:]

    return _head_and_tail(train_index), _head_and_tail(test_index)

def truncate1(train_index, val_index, test_index, num_keep, truncate=False):
    """Optionally shrink the train, val and test indices to their first and last `num_keep` entries.

    Args:
        train_index: List of training index entries.
        val_index: List of validation index entries.
        test_index: List of test index entries.
        num_keep: Number of entries kept from each end of a list.
        truncate: When False, all lists are returned unchanged.

    Returns:
        A `(train_index, val_index, test_index)` tuple. A list with at most
        `2 * num_keep` entries is returned unchanged, so the head and tail never
        overlap and no sample is duplicated.
    """

    def _head_and_tail(index):
        if not truncate or len(index) <= 2 * num_keep:
            return index
        # Slice the tail from an explicit position: `index[-0:]` would be the whole list.
        return index[:num_keep] + index[len(index) - num_keep:]

    return (
        _head_and_tail(train_index),
        _head_and_tail(val_index),
        _head_and_tail(test_index),
    )

def get_node_index(index_patterns):
    """Merge every JSON index file matching the given glob patterns into one dict.

    Args:
        index_patterns: Iterable of glob patterns pointing at JSON index files.

    Returns:
        A dict mapping a running integer (starting at 0) to each trial-info value
        found in the matched files, in the order the files were matched and read.
        The total count is printed to stderr.
    """
    matched_files = [path for pattern in index_patterns for path in glob.glob(pattern)]

    trial_info_index = {}
    for matched in matched_files:
        with open(matched, "r") as handle:
            for trial_info in json.load(handle).values():
                # The next free key is always the current number of collected trials.
                trial_info_index[len(trial_info_index)] = trial_info

    num_trials = len(trial_info_index)
    print(f"[get_node_index] # Trials = {num_trials}", file=sys.stderr)
    return trial_info_index

def get_full_paths(input_files, prefix_filepath, filename_to_nodepath):
    """Resolve every input file of a sample to the path it should be read from.

    Resolution order for each file:
      1. If its basename is a key of `filename_to_nodepath`, use the mapped path.
      2. Otherwise, if it already contains "/itet-stor", swap the "/itet-stor/kard"
         prefix for "/itet-stor/maxihuber".
      3. Otherwise prepend `prefix_filepath`.

    Args:
        input_files: Iterable of file paths taken from a sample's "input" entry.
        prefix_filepath: String prepended to paths that are not under "/itet-stor".
        filename_to_nodepath: Dict mapping file basenames to node-local paths.

    Returns:
        A list with one resolved path per input file, in input order.
    """
    adjusted_files = []
    for file in input_files:
        # The mapping is keyed by basename, so test and index with the same key.
        # (The original called the non-existent os.path.basepath and indexed with
        # the full path, which raised on every call.)
        basename = os.path.basename(file)
        if basename in filename_to_nodepath:
            adjusted_files.append(filename_to_nodepath[basename])
        elif "/itet-stor" in file:
            adjusted_files.append(file.replace("/itet-stor/kard", "/itet-stor/maxihuber"))
        else:
            adjusted_files.append(prefix_filepath + file)
    return adjusted_files

def get_generic_channel_name(channel_name):
    """Reduce a raw channel label to a lower-case generic name.

    The label is lower-cased and a leading "eeg " is dropped. If the remainder
    contains a dash, only the part before the first dash is kept; a label that
    ends with a dash yields the string "None".

    Args:
        channel_name: Raw channel label.

    Returns:
        The generic channel name, or the string "None".
    """
    name = channel_name.lower()
    if name.startswith("eeg "):
        name = name[len("eeg "):]

    if "-" not in name:
        return name
    if name.endswith("-"):
        return "None"
    head, _, _ = name.partition("-")
    return head

def load_file_data(data_index, task_channels, filename_to_nodepath, load_mode, task_name):
    """Load, crop and tensorize every sample listed in `data_index`.

    Reads the module-level globals `prefix_filepath` and `time_col`.

    Args:
        data_index: Iterable of sample dicts with the keys "input", "start",
            "length" or "end", and "output" or "label".
        task_channels: Collection of channel names to keep; its iteration order
            determines the channel order of each returned tensor.
        filename_to_nodepath: Dict mapping file basenames to node-local paths.
        load_mode: 2 reads the first input file with `load_edf_to_dataframe`;
            any other value unpickles and concatenates all input files.
            1 additionally crops by label (`.loc`) and reads labels from the
            values of the sample's "output" dict.
        task_name: Task name; "EyeNetPosition" keeps all output values, other
            load_mode-1 tasks keep only the first one.

    Returns:
        Five dicts `(data, outputs, srs, durs, channels)` keyed by a running
        sample number. `data[i]` is a float32 tensor of shape
        (selected channels, rows). Samples that raise are reported on stderr and
        skipped, so keys stay consecutive and all five dicts share the same keys.
    """
    num_samples = 0
    data = {}
    outputs = {}
    srs = {}
    durs = {}
    channels = {}
    failed_samples = []

    # Rank of each task channel (first occurrence), built once instead of
    # rebuilding `list(task_channels)` for every sort key.
    channel_rank = {}
    for rank, chn in enumerate(task_channels):
        channel_rank.setdefault(chn, rank)

    for sample in tqdm(data_index, desc="Loading data", position=0, leave=True):
        try:
            # Load the data of this sample
            input_files = get_full_paths(sample["input"], prefix_filepath, filename_to_nodepath)

            if load_mode == 2:
                df = load_edf_to_dataframe(input_files[0])
            else:
                # NOTE: pd.read_pickle must only be used on trusted files.
                df = pd.concat([pd.read_pickle(file) for file in input_files], axis=0)

            # Crop the data to the desired length. The original evaluated the
            # crop expression without assigning it, so nothing was cropped.
            start = int(sample["start"])
            length = int(sample["length"]) if "length" in sample else int(sample["end"])
            if load_mode == 1:
                df = df.loc[start : start + length, :]
            else:
                # NOTE(review): this treats `length` as an end position even when
                # it came from sample["length"] — confirm against the index format.
                df = df.iloc[start:length, :]
            if len(df) == 0:
                raise ValueError(f"Empty dataframe for sample: {sample}")

            # Label
            if load_mode != 1:
                output = sample.get("output", sample.get("label"))
            else:
                output = list(sample["output"].values()) if task_name == "EyeNetPosition" else list(sample["output"].values())[0]

            # Metadata: sampling rate from the first time step, rounded so that
            # float error (e.g. 1 / 0.004 -> 249.99...) does not truncate it.
            sr = int(round(1 / (df[time_col].iloc[1] - df[time_col].iloc[0])))
            dur = len(df) / sr

            sample_channels = sorted(set(df.columns) & set(channel_rank), key=channel_rank.__getitem__)
            signal = torch.tensor(df[sample_channels].astype(float).to_numpy(), dtype=torch.float32).T
        except Exception as e:
            print(f"Failed to process sample: {sample}. Error: {e}", file=sys.stderr)
            failed_samples.append(sample)
            continue

        # Commit all five entries together so a sample that fails midway can
        # never leave a label behind without its data.
        outputs[num_samples] = output
        srs[num_samples] = sr
        durs[num_samples] = dur
        channels[num_samples] = sample_channels
        data[num_samples] = signal
        num_samples += 1

        del df
        if num_samples % 100 == 0:
            gc.collect()

    if failed_samples:
        print(f"[load_file_data] {len(failed_samples)} sample(s) failed to load.", file=sys.stderr)

    return data, outputs, srs, durs, channels


# Map the basename of every file listed in the node-local index files to its node-local path.
print(f"Preparing local paths...")
index_patterns = ["/dev/shm/mae/index_*.json", "/scratch/mae/index_*.json"]
node_index = get_node_index(index_patterns=index_patterns)
filename_to_nodepath = {
    os.path.basename(entry["origin_path"]): entry["new_path"]
    for entry in node_index.values()
}
print(f"Prepared local paths. {len(filename_to_nodepath)} files found on node.")

# load_mode 1 uses three index files (train/val/test); every other mode uses one
# file holding a train and a test list.
if load_mode == 1:
    train_index, val_index, test_index = truncate1(*load_index1(json_path), num_keep, truncate)
else:
    train_index, test_index = truncate0(*load_index0(json_path), num_keep, truncate)

print("=" * 10 + "Load train data" + "=" * 100)
train_data, train_outputs, train_sr, train_dur, train_channels = load_file_data(
    train_index, task_channels, filename_to_nodepath, load_mode, task_name
)

if load_mode == 1:
    print("=" * 10 + "Load val data" + "=" * 100)
    val_data, val_outputs, val_sr, val_dur, val_channels = load_file_data(
        val_index, task_channels, filename_to_nodepath, load_mode, task_name
    )

print("=" * 10 + "Load test data" + "=" * 100)
test_data, test_outputs, test_sr, test_dur, test_channels = load_file_data(
    test_index, task_channels, filename_to_nodepath, load_mode, task_name
)

# Baseline Models

## Data Preprocessing

### Helpers

In [None]:
def pad_tensor(tensor, target_height, target_width):
    """Zero-pad or cut a 2-D tensor to exactly `(target_height, target_width)`.

    Rows are handled first, then columns. Padding is appended at the bottom and
    on the right; surplus rows and columns are dropped from the end.

    Args:
        tensor: 2-D tensor.
        target_height: Desired number of rows.
        target_width: Desired number of columns.

    Returns:
        A tensor of shape `(target_height, target_width)` with the input's dtype.
    """
    rows, cols = tensor.shape

    missing_rows = target_height - rows
    if missing_rows > 0:
        zero_rows = torch.zeros((missing_rows, cols), dtype=tensor.dtype)
        tensor = torch.cat((tensor, zero_rows), dim=0)
    else:
        tensor = tensor[:target_height, :]

    missing_cols = target_width - cols
    if missing_cols > 0:
        # Use the current row count: it may have changed in the step above.
        zero_cols = torch.zeros((tensor.shape[0], missing_cols), dtype=tensor.dtype)
        tensor = torch.cat((tensor, zero_cols), dim=1)
    else:
        tensor = tensor[:, :target_width]

    return tensor

def resample_signals(data, srs, target_sfreq):
    """Resample every signal to `target_sfreq` with `mne.filter.resample`.

    Args:
        data: Dict mapping sample keys to tensors.
        srs: Dict mapping the same keys to each signal's sampling rate.
        target_sfreq: Sampling rate to resample to.

    Returns:
        A new dict with the same keys holding the resampled float32 tensors.
    """
    resampled_data = {}
    progress = tqdm(data.items(), desc="Resampling signals")
    for key, tensor in progress:
        # Cast to float64 before handing the array to MNE.
        as_float64 = tensor.numpy().astype(np.float64)
        ratio = target_sfreq / srs[key]
        resampled = mne.filter.resample(as_float64, up=ratio)
        resampled_data[key] = torch.tensor(resampled, dtype=torch.float32)
        # Drop the intermediates right away and collect to keep memory down.
        del as_float64, resampled
        gc.collect()
    return resampled_data

def pad_or_truncate_signals(data, common_length):
    """Bring every signal to exactly `common_length` columns, in place.

    Shorter signals are zero-padded on the right; longer ones are cut.

    Args:
        data: Dict mapping sample keys to 2-D signals (tensors or arrays). The
            dict is modified in place.
        common_length: Target size of the second dimension.

    Returns:
        The same `data` dict, whose values are now float32 tensors with
        `common_length` columns.
    """
    # Only existing keys are reassigned, so the dict does not change size while
    # it is being iterated.
    for idx, signal in tqdm(data.items(), desc="Pad/Truncate signals"):
        # Work on tensors throughout: this avoids the numpy round trip of np.pad
        # and the copy-construct warning of torch.tensor(existing_tensor).
        signal = torch.as_tensor(signal, dtype=torch.float32)
        missing = common_length - signal.shape[1]
        if missing > 0:
            # (left, right) padding of the last dimension with zeros.
            adjusted = torch.nn.functional.pad(signal, (0, missing))
        else:
            # clone() so the stored tensor does not alias the longer input.
            adjusted = signal[:, :common_length].clone()
        data[idx] = adjusted
    # One collection after the loop instead of one per signal.
    gc.collect()
    return data

def create_epochs(data, outputs, channels, sfreq=1000, is_classification=True):
    """Wrap equally-shaped signals into a single `mne.EpochsArray`.

    Args:
        data: Dict mapping sample keys to tensors that all share one shape;
            axis 0 is used as the channel axis.
        outputs: Dict mapping the same keys to labels. Only read when
            `is_classification` is True.
        channels: Unused; kept so existing callers keep working. The channel
            count is taken from the data, because the signals may have been
            padded or cut since the channel lists were recorded.
        sfreq: Sampling frequency stored in the MNE info.
        is_classification: When True, each distinct label gets its own event id
            (numbered from 1 in order of first appearance); otherwise every
            epoch gets event id 1 and no `event_id` mapping is passed.

    Returns:
        The resulting `mne.EpochsArray`.

    Raises:
        ValueError: If `data` is empty.
    """
    events = []
    event_id = {}
    epochs_data = []
    for idx, signal in tqdm(data.items(), desc="Creating epochs"):
        epochs_data.append(signal.numpy())
        if is_classification:
            if outputs[idx] not in event_id:
                event_id[outputs[idx]] = len(event_id) + 1
            events.append([idx, 0, event_id[outputs[idx]]])
        else:
            events.append([idx, 0, 1])
    if not epochs_data:
        raise ValueError("create_epochs needs at least one signal")

    events = np.array(events, dtype=int)
    epochs_array = np.array(epochs_data)
    # The original read `channels[0].shape[0]`, which fails for the plain lists
    # produced by load_file_data; the stacked data is the reliable source.
    n_channels = epochs_array.shape[1]
    info = mne.create_info(
        ch_names=[f"EEG_{i}" for i in range(n_channels)], sfreq=sfreq, ch_types="eeg"
    )
    epochs = mne.EpochsArray(
        epochs_array,
        info,
        events=events,
        event_id=event_id if is_classification else None,
    )
    del events, info, epochs_data, epochs_array  # Free memory
    gc.collect()  # Explicitly invoke garbage collection
    return epochs

### Applying xDAWN Preprocessing

In [None]:
L.seed_everything(42)
sys.path.append("/home/maxihuber/eeg-foundation/src/models/components/Baselines")

os.makedirs(f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}", exist_ok=True)

# Signal lengths (axis 1) and channel counts (axis 0) over train and test together.
durs = [signal.shape[1] for signal in train_data.values()] + [signal.shape[1] for signal in test_data.values()]
n_chns = [signal.shape[0] for signal in train_data.values()] + [signal.shape[0] for signal in test_data.values()]
dur_90 = int(np.percentile(durs, 90))
# Fixed: the channel target was computed from `durs`, which padded every signal
# to as many rows as the 90th-percentile duration.
chn_90 = int(np.percentile(n_chns, 90))

# Bring every signal to the common (chn_90, dur_90) shape.
train_data_pad = {k: pad_tensor(signals, chn_90, dur_90) for k, signals in train_data.items()}
test_data_pad = {k: pad_tensor(signals, chn_90, dur_90) for k, signals in test_data.items()}
del train_data, test_data  # Free memory
gc.collect()  # Explicitly invoke garbage collection

# Sampling rate stored in the epochs' info; the signals themselves are not resampled here.
target_sfreq = int(max(train_sr.values()))

common_length = min(
    min(signal.shape[1] for signal in train_data_pad.values()),
    min(signal.shape[1] for signal in test_data_pad.values()),
)

train_data_padded = pad_or_truncate_signals(train_data_pad, common_length)
test_data_padded = pad_or_truncate_signals(test_data_pad, common_length)
del train_data_pad, test_data_pad  # Free memory
gc.collect()  # Explicitly invoke garbage collection

is_classification = task_type == "Classification"
epochs_train = create_epochs(
    train_data_padded, train_outputs, list(train_channels.values()), target_sfreq, is_classification
)
epochs_test = create_epochs(
    test_data_padded, test_outputs, list(test_channels.values()), target_sfreq, is_classification
)
del train_data_padded, test_data_padded  # Free memory
gc.collect()  # Explicitly invoke garbage collection

print("Start fitting Xdawn...", file=sys.stderr)

xdawn = Xdawn(n_components=2, correct_overlap=False, reg=0.1)
xdawn.fit(epochs_train)

## Finetuning

In [None]:
# Code source: Gaël Varoquaux
#              Andreas Müller
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause
# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html

def get_labels(train_outputs, test_outputs, is_classification):
    """Turn the train/test output dicts into label arrays.

    Args:
        train_outputs: Dict whose values are the training targets.
        test_outputs: Dict whose values are the test targets.
        is_classification: When True, targets are encoded with a `LabelEncoder`
            fitted on the training targets; otherwise they are returned as
            plain numpy arrays.

    Returns:
        A `(y_train, y_test)` tuple.
    """
    train_values = list(train_outputs.values())
    test_values = list(test_outputs.values())

    if not is_classification:
        return np.array(train_values), np.array(test_values)

    encoder = LabelEncoder()
    # Fit on the training targets first; the test targets reuse that mapping.
    return encoder.fit_transform(train_values), encoder.transform(test_values)

# The regressors below were used without ever being imported, so building the
# dict raised NameError regardless of the selected task type.
from sklearn.ensemble import AdaBoostRegressor, RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor

# TODO: check which classification & regression models we actually want
# Baseline estimators keyed by task type ("Classification" / "Regression") and display name.
# NOTE(review): SVR and AdaBoostRegressor look single-output only; confirm they are
# usable for EyeNetPosition (out_dim 2).
simple_models = {
    "Classification": {
        "LDA": LinearDiscriminantAnalysis(),
        "QDA": QuadraticDiscriminantAnalysis(),
        "Nearest Neighbors": KNeighborsClassifier(3),
        "Linear SVM": SVC(kernel="linear", C=0.025, random_state=42),
        "RBF SVM": SVC(gamma=2, C=1, random_state=42),
        "Decision Tree": DecisionTreeClassifier(max_depth=5, random_state=42),
        "Random Forest": RandomForestClassifier(
            max_depth=5, n_estimators=10, max_features=1, random_state=42
        ),
        "AdaBoost": AdaBoostClassifier(algorithm="SAMME", random_state=42),
    },
    "Regression": {
        "Linear Regression": LinearRegression(),
        "Ridge Regression": Ridge(alpha=1.0),
        "K-Nearest Neighbors Regressor": KNeighborsRegressor(n_neighbors=3),
        "Support Vector Regression (Linear)": SVR(kernel='linear', C=1.0),
        "Support Vector Regression (RBF)": SVR(kernel='rbf', C=1.0, gamma=0.1),
        "Decision Tree Regressor": DecisionTreeRegressor(max_depth=5, random_state=42),
        "Random Forest Regressor": RandomForestRegressor(
            max_depth=5, n_estimators=10, max_features=1, random_state=42
        ),
        "AdaBoost Regressor": AdaBoostRegressor(n_estimators=50, random_state=42),
    },
}

# TODO: more accuracy parameters, for both regression and classification
from sklearn.metrics import (
    balanced_accuracy_score, accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, mean_squared_error, mean_absolute_error,
    r2_score, mean_absolute_percentage_error
)


def _averaging_mode(y_true, y_pred):
    """Return "binary" when the labels fit scikit-learn's binary default, else "macro".

    The default `average="binary"` raises for multi-class targets, which several
    tasks in this notebook have (out_dim 4 and 5).
    """
    labels = np.union1d(y_true, y_pred)
    if labels.size < 2 or set(labels.tolist()) <= {0, 1}:
        return "binary"
    return "macro"


def _averaged(metric):
    """Wrap a precision/recall/F1-style metric so its `average` is chosen per call."""
    def score(y_true, y_pred):
        return metric(y_true, y_pred, average=_averaging_mode(y_true, y_pred))
    return score


def _rmse(y_true, y_pred):
    """Root mean squared error, averaged uniformly over outputs.

    Written without the `squared=False` keyword, which newer scikit-learn
    releases deprecate and then drop; the per-output square roots are averaged
    to reproduce what that keyword returned.
    """
    per_output_mse = mean_squared_error(y_true, y_pred, multioutput="raw_values")
    return float(np.mean(np.sqrt(per_output_mse)))


# Metric callables keyed by task type and display name; each takes (y_true, y_pred).
scores = {
    "Classification": {
        "Accuracy": accuracy_score,
        "Balanced Accuracy": balanced_accuracy_score,
        "Precision": _averaged(precision_score),
        "Recall": _averaged(recall_score),
        "F1 Score": _averaged(f1_score),
        # Computed from hard predictions; raises ValueError for multi-class
        # targets or when y_true holds a single class.
        "ROC AUC": roc_auc_score,
        "Confusion Matrix": confusion_matrix,
    },
    "Regression": {
        "MAE": mean_absolute_error,
        "RMSE": _rmse,
        "R-squared": r2_score,
        "MAPE": mean_absolute_percentage_error,
    }
}

# Transform the data using xDAWN
X_train_xdawn = xdawn.transform(epochs_train)
X_test_xdawn = xdawn.transform(epochs_test)

# Flatten each epoch's (components, times) matrix into one feature vector,
# since the scikit-learn estimators expect 2-D input.
n_epochs_train, n_components, n_times = X_train_xdawn.shape
X_train_xdawn = X_train_xdawn.reshape(n_epochs_train, n_components * n_times)
n_epochs_test, n_components, n_times = X_test_xdawn.shape
X_test_xdawn = X_test_xdawn.reshape(n_epochs_test, n_components * n_times)

y_train, y_test = get_labels(train_outputs, test_outputs, is_classification)

for name, clf in simple_models[task_type].items():
    print(f"Evaluating {name} model.", file=sys.stderr)

    # Standardize the features before every estimator.
    clf_pipeline = make_pipeline(StandardScaler(), clf)
    clf_pipeline.fit(X_train_xdawn, y_train)
    y_pred = clf_pipeline.predict(X_test_xdawn)

    for score_name, score_func in scores[task_type].items():
        # A metric that cannot be computed for this split (e.g. ROC AUC when the
        # truncated test set holds one class) is reported and skipped, so it does
        # not abort the remaining metrics and models.
        try:
            score = score_func(y_test, y_pred)
        except (ValueError, TypeError) as err:
            print(f"{score_name}: not computable ({err}) ({name})", file=sys.stderr)
            continue
        print(f"{score_name}: {score} ({name})", file=sys.stderr)

    # Clean up to save memory
    del clf_pipeline, y_pred
    gc.collect()