In [29]:
import os
import sys
import json
import logging
from math import ceil
import numpy as np
import pandas as pd
from tqdm import tqdm

import re
import mne
mne.set_log_level("WARNING")

def get_generic_channel_name(self, channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

def create_raw(
    data,
    ch_names1,
    sr,
    ch_names2=None,
):
    if ch_names2 == None:
        ch_names2 = ch_names1
    ch_types = ["eeg" for _ in range(len(ch_names1))]
    info = mne.create_info(ch_names2, ch_types=ch_types, sfreq=sr)
    eeg_data = np.array(data[ch_names1].T, dtype="float") / 1_000_000
    raw = mne.io.RawArray(eeg_data, info)
    return raw

def avg_channel(raw):
    avg = raw.copy().add_reference_channels(ref_channels="AVG_REF")
    avg = avg.set_eeg_reference(ref_channels="average")
    return avg

def load_from_path(path, channels, sr):
    if path.endswith("edf"):
        eeg_data = mne.io.read_raw_edf(
            path,
            include=channels,
            preload=True,
        )
    elif path.endswith("pkl"):
        # Load DataFrame from pickle
        with open(path, "rb") as file:
            df = pd.read_pickle(file)
            eeg_data = create_raw(
                data=df,
                ch_names1=channels,
                sr=sr,
            )
    else:
        assert False, "Invalid path"

    # Add average reference
    eeg_data = avg_channel(eeg_data)

    # Datastructure to access data for each channel
    channel_data_dict = {}

    # Note: channel_data_dict also includes the AVG_REF channel
    for channel in eeg_data.ch_names:
        idx = eeg_data.ch_names.index(channel)
        data, times = eeg_data[idx, :]
        # Flatten the data to 1D if required
        channel_data_dict[channel] = data.flatten()

    df = pd.DataFrame(channel_data_dict)
    df["Time in Seconds"] = times.flatten()

    return df

import json
import matplotlib.pyplot as plt

tueg_path_new = "/itet-stor/maxihuber/deepeye_storage/index_files/new_tueg_index.json"

with open("/itet-stor/maxihuber/deepeye_storage/index_files/full_tueg_index.json", "r") as f:
    index = json.load(f)

prefix_path = "/itet-stor/maxihuber/deepeye_storage/foundation/tueg/edf"

for i, ie in enumerate(tqdm(index, "Recomputing durations", position=0, leave=True)):
    file = prefix_path + ie["path"]
    # load
    df = load_from_path(file, ie["channels"], ie["sr"])
    ie["duration"] = len(df) / ie["sr"]
    del df

    if i % 1_000 == 0:
        with open(tueg_path_new, 'w') as f:
            json.dump(index, f, indent=4)

with open(tueg_path_new, 'w') as f:
    json.dump(index, f, indent=4)

Recomputing durations:   0%|             | 178/69652 [02:11<14:14:52,  1.35it/s]

KeyboardInterrupt



In [60]:
from natsort import natsorted

with open("/itet-stor/maxihuber/deepeye_storage/index_files/full_tueg_index.json", "r") as f:
    tueg = json.load(f)

with open("/itet-stor/maxihuber/deepeye_storage/index_files/tuab_train_index.json", "r") as f:
    tuab = json.load(f)

full = tueg + tuab

chns = set()
for ie in full:
    chns.update([get_generic_channel_name(chn) for chn in ie["channels"]])

real_chns = ['f8', 't3', 't6', 'f3', 'fp2', 't1', 'f4', 'o1', 'fp1', 't4', 'p3', 'cz', 't5', 'c3', 'p4', 'oz', 'f7', 'fz', 'c4p', 'c3p', 'o2', 'pz', 't2', 'a1', 'a2']
real_chns = natsorted(real_chns)
real_chns = ['None'] + real_chns
real_chns = {v: k for k, v in enumerate(real_chns)}

all_real_chns = set(['105', 't6', '47', '44', '28', 't3', '103', '100', '82', '81', '45', 'p3', '78', '77', '50', 'p4', 'f4', 'f8', '20', '99', '98', 'f3', 'f7', '57', '56', '36', '35', '66', '65', '63', '62', '31', '30', '69', '68', '107', '70', '120', 'a2', 't1', 'a1', '79', '127', '128', '26', '24', '121', '22', '21', '123', '85', 't2', 'c4p', '89', '88', '87', '86', '94', '92', '91', '90', '109', '108', '111', '110', '113', '112', '115', '114', '117', '116', '119', '118', '42', '59', '58', '41', '40', '43', '42', 'o1', '48', 'c3', 't5', '46', '49', 'c3p', 'cz', 'c4', 't4', '73', 'c3', 't1', 't2', 't3', 't4', 't5', 't6', 'p3', 'p4', 'pz', 'f7', 'f8', 'f3', 'f4', 'fz', 'o1', 'o2', 'oz', 'fp1', 'fp2'])
all_real_chns = set(natsorted(list(all_real_chns)))
print(all_real_chns)

counts = {}
for ie in full:
    good_channels = []
    for chn in ie["channels"]:
        if get_generic_channel_name(chn) in all_real_chns:
            good_channels.append(chn)
        counts[chn] = counts[chn] + 1 if chn in counts else 1
    ie["good_channels"] = good_channels

high_counts = {'fp1': 72369, 'fp2': 72369, 'f3': 72368, 'f4': 72369, 
               'c3': 72371, 'c4': 72371, 'p3': 72367, 'p4': 72367, 
               'o1': 72367, 'o2': 72367, 'f7': 72369, 'f8': 72369, 
               't3': 72371, 't4': 72371, 't5': 72369, 't6': 72369, 
               'a1': 65950, 'a2': 65950, 'fz': 72130, 'cz': 72370, 
               'pz': 72128, 'roc': 26181, 'loc': 26180, 'ekg1': 56335, 
               't1': 61068, 't2': 61072, '26': 21088, '27': 20471, 
               '28': 22780, '29': 24374, '30': 25998, 'oz': 12864, 
               'pg1': 12240, 'pg2': 12240, 'ekg': 13358, '31': 29457, '32': 29457, 
               '23': 3179, '24': 3179, '20': 2599, '21': 2626, '22': 2626, 
               '25': 4790, 'c3p': 23616, 'c4p': 23613, 'sp1': 29445, 'sp2': 28601}
high_counts = list(high_counts.keys())
high_counts = list(set(high_counts) & set(real_chns.keys()))
high_counts = natsorted(high_counts)
high_counts = ['None'] + high_counts
high_counts = {v: k for k, v in enumerate(high_counts)}

print(real_chns)
print(high_counts)

assert False, "break before dumping"

with open('/home/maxihuber/eeg-foundation/src/data/components/channels_to_id2.json', 'w') as f:
    json.dump(high_counts, f, indent=4)
    print("Dumped file!")

['20', '21', '22', '24', '26', '28', '30', '31', '35', '36', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '56', '57', '58', '59', '62', '63', '65', '66', '68', '69', '70', '73', '77', '78', '79', '81', '82', '85', '86', '87', '88', '89', '90', '91', '92', '94', '98', '99', '100', '103', '105', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '123', '127', '128', 'a1', 'a2', 'c3', 'c3p', 'c4', 'c4p', 'cz', 'f3', 'f4', 'f7', 'f8', 'fp1', 'fp2', 'fz', 'o1', 'o2', 'oz', 'p3', 'p4', 'pz', 't1', 't2', 't3', 't4', 't5', 't6']
{'None': 0, 'a1': 1, 'a2': 2, 'c3': 3, 'c3p': 4, 'c4p': 5, 'cz': 6, 'f3': 7, 'f4': 8, 'f7': 9, 'f8': 10, 'fp1': 11, 'fp2': 12, 'fz': 13, 'o1': 14, 'o2': 15, 'oz': 16, 'p3': 17, 'p4': 18, 'pz': 19, 't1': 20, 't2': 21, 't3': 22, 't4': 23, 't5': 24, 't6': 25}
{'None': 0, 'a1': 1, 'a2': 2, 'c3': 3, 'c3p': 4, 'c4p': 5, 'cz': 6, 'f3': 7, 'f4': 8, 'f7': 9, 'f8': 10, 'fp1': 11, 'fp2': 12, 'fz': 13, 'o1'

AssertionError: break before dumping

In [71]:
with open("/itet-stor/maxihuber/deepeye_storage/index_files/311G_tueg_index.json", "r") as f:
    tueg = json.load(f)

with open("/itet-stor/maxihuber/deepeye_storage/index_files/tuab_train_index2.json", "r") as f:
    tuab_train = json.load(f)

with open("/itet-stor/maxihuber/deepeye_storage/index_files/tuab_test_index2.json", "r") as f:
    tuab_test = json.load(f)

index = tueg

for ie in index:
    good_channels = []
    for chn in ie["channels"]:
        if get_generic_channel_name(chn) in all_real_chns:
            good_channels.append(chn)
    ie["good_channels"] = good_channels

with open("/itet-stor/maxihuber/deepeye_storage/index_files/311G_tueg_index2.json", "w") as f:
    json.dump(index, f, indent=4)

In [75]:
with open('/itet-stor/maxihuber/deepeye_storage/index_files/full_tueg_index.json', 'r') as f:
    tueg = json.load(f)

with open('/itet-stor/maxihuber/deepeye_storage/index_files/311G_tueg_index.json', 'r') as f:
    index300 = json.load(f)

# Convert tueg from list to dictionary for faster searching access
tueg_dict = {entry['path']: entry for entry in tueg}  # Assuming each entry has a unique 'id' key

index300new = []
for ie in index300:
    # Find the same ie in tueg_dict
    ie_id = ie['path']  # Assuming each entry in index300 has an 'id' key
    if ie_id in tueg_dict:
        ie_tueg = tueg_dict[ie_id]
        index300new.append(ie_tueg)

with open('/itet-stor/maxihuber/deepeye_storage/index_files/311G_tueg_index2.json', 'w') as f:
    json.dump(index300new, f)

In [55]:
def get_subject_id(filepath):
    if filepath.endswith("pkl"):
        # Regular expression to match the UUID in the middle of the file path
        match = re.search(
            r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", filepath
        )
        if match:
            return match.group(0)
        return None
    elif filepath.endswith("edf"):
        parts = filepath.split("/")
        subject_id = parts[
            2
        ]  # The subject ID is the third element after splitting by '/'
        return subject_id
    else:
        assert False, f"invaliv file format for file {filepath}"

with open("/itet-stor/maxihuber/deepeye_storage/index_files/tuab_test_index.json", "r") as f:
    tuab = json.load(f)

for ie in tuab:
    ie["ref"] = None
    ie["SubjectID"] = get_subject_id(ie["path"])
    ie["Dataset"] = "TUAB"

assert False, "break before dumping"

with open("/itet-stor/maxihuber/deepeye_storage/index_files/tuab_test_index2.json", "w") as f:
    tuab = json.dump(tuab, f, indent=4)

In [50]:
print(len('/itet-stor/maxihuber/deepeye_storage/foundation/tueg/edf'))

56


In [20]:
def get_generic_channel_name(channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

hist = {}
for ie in index:
    for chn in ie["channels"]:
        chn = get_generic_channel_name(chn)
        if chn == '23':
            pass
            #print(ie["sr"])
        hist[chn] = hist[chn]+1 if chn in hist else 1

durs = [int(ie["duration"]) for ie in index if ie["duration"] > 10]
counts, bins, _ = plt.hist(durs, bins=range(min(durs), max(durs) + 3_600, 3_600))

for bin, count in zip(bins, counts):
    print(f"{bin}: {count}")
#print(hist)

KeyboardInterrupt: 

# Finetuning Notebook for NeuroBench

### Imports

In [25]:
def get_nr_x_patches(win_size, dur):
    win_shift = win_size * self_win_shift_factor
    x_datapoints_per_second = 1 / win_shift
    x_datapoints = dur * x_datapoints_per_second + 1
    return int(x_datapoints // self_patch_size)

get_nr_x_patches(.25, 600)

600

In [70]:
file = "/itet-stor/maxihuber/deepeye_storage/foundation/tueg_tasks/epilepsy/00_epilepsy/aaaaamoa/s004_2012/01_tcp_ar/aaaaamoa_s004_t000.edf"

def get_generic_channel_name(channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

def load_edf_to_dataframe(file_path):
    eeg_data = mne.io.read_raw_edf(file_path, preload=True)
    channel_data_dict = {}

    for channel in eeg_data.ch_names:
        idx = eeg_data.ch_names.index(channel)
        channel = get_generic_channel_name(channel)
        data, times = eeg_data[idx, :]
        channel_data_dict[channel] = data.flatten()

    df = pd.DataFrame(channel_data_dict)
    df['Time in Seconds'] = times.flatten()
    return df

df = load_edf_to_dataframe(file)
task_channels = set(df.columns) & set(['p4', 'c3', 'pz', 'fp2', 't6', 'fz', 'f4', 'f8', 't4', 't3', 'p3', 'a2', 'oz', 'a1', 't5', 't1', 'cz', 'c4', 'fp1', 'o1', 'o2', 'f3', 'f7']))

signal = df["cz"] * 1_000_000
signal = torch.tensor(signal)

In [50]:
task_channels = set(df.columns) & set(['p4', 'c3', 'pz', 'fp2', 't6', 'fz', 'f4', 'f8', 't4', 't3', 'p3', 'a2', 'oz', 'a1', 't5', 't1', 'cz', 'c4', 'fp1', 'o1', 'o2', 'f3', 'f7'])
print(len(task_channels))


22
{'t5', 'p4', 'o1', 't1', 'fp2', 'fz', 'cz', 't3', 'a1', 't6', 'f3', 'c4', 'f7', 'o2', 'a2', 'c3', 'p3', 'f8', 'f4', 't4', 'fp1', 'pz'}


In [79]:
montage = mne.channels.make_standard_montage('standard_1005')
print(sorted(montage.ch_names))

['A1', 'A2', 'AF1', 'AF10', 'AF10h', 'AF1h', 'AF2', 'AF2h', 'AF3', 'AF3h', 'AF4', 'AF4h', 'AF5', 'AF5h', 'AF6', 'AF6h', 'AF7', 'AF7h', 'AF8', 'AF8h', 'AF9', 'AF9h', 'AFF1', 'AFF10', 'AFF10h', 'AFF1h', 'AFF2', 'AFF2h', 'AFF3', 'AFF3h', 'AFF4', 'AFF4h', 'AFF5', 'AFF5h', 'AFF6', 'AFF6h', 'AFF7', 'AFF7h', 'AFF8', 'AFF8h', 'AFF9', 'AFF9h', 'AFFz', 'AFp1', 'AFp10', 'AFp10h', 'AFp1h', 'AFp2', 'AFp2h', 'AFp3', 'AFp3h', 'AFp4', 'AFp4h', 'AFp5', 'AFp5h', 'AFp6', 'AFp6h', 'AFp7', 'AFp7h', 'AFp8', 'AFp8h', 'AFp9', 'AFp9h', 'AFpz', 'AFz', 'C1', 'C1h', 'C2', 'C2h', 'C3', 'C3h', 'C4', 'C4h', 'C5', 'C5h', 'C6', 'C6h', 'CCP1', 'CCP1h', 'CCP2', 'CCP2h', 'CCP3', 'CCP3h', 'CCP4', 'CCP4h', 'CCP5', 'CCP5h', 'CCP6', 'CCP6h', 'CCPz', 'CP1', 'CP1h', 'CP2', 'CP2h', 'CP3', 'CP3h', 'CP4', 'CP4h', 'CP5', 'CP5h', 'CP6', 'CP6h', 'CPP1', 'CPP1h', 'CPP2', 'CPP2h', 'CPP3', 'CPP3h', 'CPP4', 'CPP4h', 'CPP5', 'CPP5h', 'CPP6', 'CPP6h', 'CPPz', 'CPz', 'Cz', 'F1', 'F10', 'F10h', 'F1h', 'F2', 'F2h', 'F3', 'F3h', 'F4', 'F4h', 

In [1]:
import sys
sys.path.append("/home/maxihuber/eeg-foundation/")

import json
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.preprocessing import StandardScaler
import os
import pickle
from tqdm import tqdm
import lightning as L
import torch.nn as nn
from torch.utils.data import random_split
from lightning.pytorch.callbacks import ModelCheckpoint
import random
from collections import Counter
from collections import defaultdict

import os
import numpy as np
import mne
import torch
from tqdm import tqdm
from mne.preprocessing import Xdawn
import pickle
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LinearRegression
from sklearn.metrics import balanced_accuracy_score, mean_squared_error

import torchaudio
from src.data.transforms import (
    crop_spg,
    custom_fft,
    normalize_spg,
)

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchmetrics

from functools import partial
from sklearn.metrics import balanced_accuracy_score
from src.models.mae_rope_encoder import EncoderViTRoPE
from src.models.components.vit_rope import (
    Flexible_RoPE_Layer_scale_init_Block,
    FlexibleRoPEAttention,
    compute_axial_cis,
    select_freqs_cis,
)
from timm.models.vision_transformer import Mlp as Mlp
from torch.nn import TransformerEncoderLayer
from src.models.components.SimpleTransformer import SimpleTransformer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import balanced_accuracy_score
from torchmetrics.functional import mean_squared_error as rmse
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from collections import defaultdict

mne.set_log_level('warning')

L.seed_everything(42)

[rank: 0] Seed set to 42


42

## Data Loading

### Load Train/Val/Test Information

In [2]:
########################################################################################################################
# TUAB and Epilepsy

yc_class = {
    "class_name": "YC",
    "time_col": "Time in Seconds",
    "prefix_filepath": "/itet-stor/maxihuber/deepeye_storage/foundation/tueg/edf",
    "load_mode": 2,
}

tuab = {
    "task_name": "TUAB",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/cli/tuab.json",
    "out_dim": 2,
}

epilepsy = {
    "task_name": "Epilepsy",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/cli/epilepsy.json",
    "out_dim": 2,
}

########################################################################################################################
# Clinical JSONs

cli_class = {
    "class_name": "Clinical",
    "time_col": "Time in Seconds",
    "prefix_filepath": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_prepared/",
    "load_mode": 0,
}

age = {
    "task_name": "Age",
    "task_type": "Regression",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/age.json",
    "out_dim": 1,
}

depression = {
    "task_name": "Depression",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_depression.json",
    "out_dim": 2,
}

parkinsons = {
    "task_name": "Parkinsons",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_parkinsons.json",
    "out_dim": 2,
}

schizophrenia = {
    "task_name": "Schizophrenia",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_schizophrenia.json",
    "out_dim": 2,
}

sex = {
    "task_name": "Sex",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/sex.json",
    "out_dim": 2,
}


########################################################################################################################
# Motor-Imagery JSONs

mi_class = {
    "class_name": "Motor Imagery",
    "time_col": "time in seconds",
    "prefix_filepath": "/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    "load_mode": 0,
}

eye_open_closed = {
    "task_name": "EyeOpenClosed",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/eye_open_closed.json",
    "out_dim": 2,
    "outputs": set(["eye open", "eye closed"]),
    "short_mode": False,
}

eye_vh = {
    "task_name": "EyeVH",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/eye_vh.json",
    "out_dim": 2,
    "outputs": set(["vertical", "horizontal"]),
    "short_mode": False,
}

flexion_extension_imaginary = {
    "task_name": "FlexionExtensionImaginary",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/flexion_extension_imaginary.json",
    "out_dim": 2,
    "outputs": set(
        [
            "hand movement imagined elbow flexion",
            "hand movement imagined elbow extension",
        ]
    ),
    "short_mode": False,
}

flexion_extension_real = {
    "task_name": "FlexionExtensionReal",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/flexion_extension_real.json",
    "out_dim": 2,
    "outputs": set(["hand movement elbow extension", "hand movement elbow flexion"]),
    "short_mode": False,
}

grasp_imaginary = {
    "task_name": "GraspImaginary",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/grasp_imaginary.json",
    "out_dim": 2,
    "outputs": set(["imagined palmar grasp", "imagined lateral grasp"]),
    "short_mode": False,
}

grasp_real = {
    "task_name": "GraspReal",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/grasp_real.json",
    "out_dim": 2,
    "outputs": set(["movement palmar grasp", "movement lateral grasp"]),
    "short_mode": False,
}

lr_imaginary = {
    "task_name": "LRImaginary",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/lr_imaginary.json",
    "out_dim": 2,
    "outputs": set(["left hand imagined movement", "right hand imagined movement"]),
    "short_mode": True,
}

lr_real = {
    "task_name": "LRReal",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/lr_real.json",
    "out_dim": 2,
    "outputs": set(["right hand movement", "left hand movement"]),
    "short_mode": True,
}

mi_task_body_parts_real = {
    "task_name": "BodyPartsReal",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/mi_task_body_parts.json",
    "out_dim": 4,
    "outputs": set(
        ["rest", "right hand movement", "foot movement", "left hand movement"]
    ),
    "short_mode": True,
}

mi_task_body_parts_imagined = {
    "task_name": "BodyPartsImagined",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/mi_task_body_parts.json",
    "out_dim": 4,
    "outputs": set(
        [
            "rest",
            "right hand imagined movement",
            "foot imagined movement",
            "left hand imagined movement",
            "tongue imagined movement",
        ]
    ),
    "short_mode": True,
}

pronation_supination_real = {
    "task_name": "PronationSupinationReal",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/pronation_supination_real.json",
    "out_dim": 2,
    "outputs": set(["movement supination", "movement pronation"]),
    "short_mode": False,
}

pronation_supination_imaginary = {
    "task_name": "PronationSupinationImaginary",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/mi/pronation_supination_imaginary.json",
    "out_dim": 2,
    "outputs": set(["imagined supination", "imagined pronation"]),
    "short_mode": False,
}

########################################################################################################################
# ERP JSONs

erp_class = {
    "class_name": "Error-Related Potential",
    "time_col": "time in seconds",
    "prefix_filepath": "/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    "load_mode": 0,
}

erp = {
    "task_name": "ERP",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/erp/erp_all.json",
    "out_dim": 5,
    "outputs": set(
        [
            "Participant is in resting state",
            "with event-related potential",
            "Participant is in interval between two flashes",
            "without event-related potential",
            "Participant keeps closing eyes",
        ]
    ),
}

errp = {
    "task_name": "ERRP",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_tasks/erp/errp_all.json",
    "out_dim": 7,
    "outputs": set(
        [
            "Target is located in the right",
            "without error-related potential",
            "The cursor moves to the left",
            "The feedback consisted in the selected item is presented on the screen",
            "The cursor moves to the right",
            "with error-related potential",
            "Target is located in the left",
        ]
    ),
}

########################################################################################################################
# EyeNet JSONs

eye_class = {
    "class_name": "EyeNet",
    "time_col": "time",
    "prefix_filepath": "/itet-stor/maxihuber/deepeye_storage/foundation_prepared/",
    "load_mode": 1,
}

eye_dir_amp = {
    "task_name": "EyeNetDirectionAmp",
    "task_type": "Regression",
    "json_path": [
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Amp_train.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Amp_val.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Amp_test.json",
    ],
    "out_dim": 1,
}

eye_dir_ang = {
    "task_name": "EyeNetDirectionAng",
    "task_type": "Regression",
    "json_path": [
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Ang_train.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Ang_val.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_Ang_test.json",
    ],
    "out_dim": 1,
}

eye_lr = {
    "task_name": "EyeNetLR",
    "task_type": "Classification",
    "json_path": [
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_LR_train.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_LR_val.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_LR_test.json",
    ],
    "out_dim": 2,
}

eye_position = {
    "task_name": "EyeNetPosition",
    "task_type": "Regression",
    "json_path": [
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Position_train.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Position_val.json",
        "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Position_test.json",
    ],
    "out_dim": 2,
}

### Load data into memory

In [4]:
########################################################################################################################
# Select the class and task

used_class = yc_class
# used_class = cli_class
# used_class = mi_class
# used_class = erp_class
# used_class = eye_class
#
used_task = tuab
# used_task = epilepsy
# used_task = age
# used_task = depression
# used_task = parkinsons
# used_task = schizophrenia
# used_task = sex
#
# used_task = eye_open_closed
# used_task = eye_vh
# used_task = flexion_extension_imaginary
# used_task = flexion_extension_real
# used_task = grasp_real
# used_task = lr_imaginary
# used_task = lr_real
# used_task = mi_task_body_parts_real
# used_task = mi_task_body_parts_imagined
# used_task = pronation_supination_real
# used_task = pronation_supination_imaginary
#
# used_task = erp
# used_task = errp
#
# used_task = eye_dir_amp
# used_task = eye_dir_ang
# used_task = eye_lr
# used_task = eye_position

class_name = used_class["class_name"]
time_col = used_class["time_col"]
prefix_filepath = used_class["prefix_filepath"]
load_mode = used_class["load_mode"]
task_name = used_task["task_name"]
task_type = used_task["task_type"]
json_path = used_task["json_path"]
out_dim = used_task["out_dim"]
short_mode = used_task["short_mode"] if "short_mode" in used_task else False

task_channels = set(['p4', 'c3', 'pz', 'fp2', 't6', 'fz', 'f4', 'f8', 't4', 't3', 'p3', 'a2', 'oz', 'a1', 't5', 't1', 'cz', 'c4', 'fp1', 'o1', 'o2', 'f3', 'f7'])


def load_index0(data_index_path):
    with open(data_index_path, "r") as f:
        train_test_dict = json.load(f)
    train_samples = train_test_dict["train"]
    test_samples = train_test_dict["test"]
    return train_samples, test_samples


def load_index1(data_index_paths):
    all_samples = []
    for data_index_path in data_index_paths:
        with open(data_index_path, "r") as f:
            subset_dict = json.load(f)
        all_samples.append(list(subset_dict.values())[0])
    return all_samples[0], all_samples[1], all_samples[2]


dataset_dict = {
    "ERP_ERP_ANA": 0,
    "RS_RS_ALPHA": 1,
    "ERP_ERP_BISC": 2,
    "ERP_ERP_BBI": 3,
    "ERP_ERP_BICF": 4,
    "ERP_ERP_BICD": 5,
    "RS_RS_SPIS": 6,
    "MI_MI_HGD": 7,
    "MI_MI_SCP": 8,
    "ErrP_ErrP_MERP": 9,
    "MI_MI_ULM": 10,
    "MI_MI_VEP": 11,
    "MI_MI_LR": 12,
    "MI_BBCI_IV_Graz_b": 13,
    "MI_MI_EB": 14,
    "MI_BBCI_IV_Graz_a": 15,
    "MI_MI_GVH_V": 16,
    "MI_MI_GAL": 17,
    "MI_MI_Two": 18,
    "MI_MI_GVH_H": 19,
    "MI_MI_II": 20,
    "ErrP_ErrP_BCI": 21,
    "MI_MI_GVH_G": 22,
    "MI_MI_Limb": 23,
    "MI_MI_SCI": 24,
    "MI_BBCI_IV_Berlin": 25,
    "MI_eegmmidb": 26,
    "ERP_ERP_FHD": 27,
    "RS_RS_EID": 28,
}


def extract_dataset_name(file_path, dataset_dict):
    for name in dataset_dict.keys():
        if name in file_path:
            return name
    return "Unknown"

def get_generic_channel_name(channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

def load_edf_to_dataframe(file_path):
    eeg_data = mne.io.read_raw_edf(file_path, preload=True)
    channel_data_dict = {}

    for channel in eeg_data.ch_names:
        idx = eeg_data.ch_names.index(channel)
        channel = get_generic_channel_name(channel)
        data, times = eeg_data[idx, :]
        channel_data_dict[channel] = data.flatten()

    df = pd.DataFrame(channel_data_dict)
    df['Time in Seconds'] = times.flatten()
    return df

def load_file_data(data_index, task_channels):
    num_samples = 0
    data = {}
    outputs = {}
    srs = {}
    durs = {}
    channels = {}
    datasets = {}
    failed_samples = []

    for sample in tqdm(data_index, desc="Loading data", position=0, leave=True):
        try:
            # Load the data of this sample
            input_files = sample["input"]

            if load_mode == 2:
                file = prefix_filepath + input_files[0] if "/itet-stor" not in input_files[0] else input_files[0]
                df = load_edf_to_dataframe(file)
                datasets[num_samples] = "TUEG"
            else:
                dataframes = [pd.read_pickle(filepath) for filepath in input_files]
                df = pd.concat(dataframes, axis=0)
                dataset_name = extract_dataset_name(file, dataset_dict)
                datasets[num_samples] = dataset_name

            # Crop the data to the desired length
            start = int(sample["start"])
            length = int(sample["length"]) if "length" in sample else int(sample["end"])
            df.loc[start : start + length, :] if load_mode==1 else df.iloc[start:length, :]
            assert len(df) > 0, f"Empty dataframe for sample: {sample}"
   
            sr = int(
                1 / float(float(df[time_col].iloc[1]) - float(df[time_col].iloc[0]))
            )
            if load_mode != 1:
                outputs[num_samples] = (
                    sample["output"] if "output" in sample else sample["label"]
                )
            else:
                if task_name == "EyeNetPosition":
                    outputs[num_samples] = list(sample["output"].values())
                else:
                    outputs[num_samples] = list(sample["output"].values())[0]
            srs[num_samples] = sr
            durs[num_samples] = len(df) / sr
            channels[num_samples] = sorted(list(set(df.columns) & set(task_channels)), key=lambda x: list(task_channels).index(x))
            df = df[channels[num_samples]].astype(float)
            signals = torch.tensor(df.to_numpy(), dtype=torch.float32).T
            data[num_samples] = signals
            num_samples += 1
            del df

            if num_samples % 100 == 0:
                gc.collect()

        except Exception as e:
            print(f"Failed to process sample: {sample}. Error: {e}", file=sys.stderr)
            failed_samples.append(sample)

    return data, outputs, srs, durs, channels, datasets


if load_mode != 1:
    print(json_path, file=sys.stderr)
    train_index, test_index = load_index0(json_path)
else:
    train_index, val_index, test_index = load_index1(json_path)

print(f"Full train size: {len(train_index)}", file=sys.stderr)
print(f"Full test size: {len(test_index)}", file=sys.stderr)

if load_mode != 1:
    train_index = train_index[:100] + train_index[-100:]
    test_index = test_index[:50] + test_index[-50:]
else:
    train_index = train_index
    val_index = val_index
    test_index = test_index

if load_mode == 0 or load_mode == 2:
    print("=" * 10 + "Load train data" + "=" * 100)
    train_data, train_outputs, train_sr, train_dur, train_channels, train_datasets = (
        load_file_data(train_index, task_channels)
    )
    print("=" * 10 + "Load test data" + "=" * 100)
    test_data, test_outputs, test_sr, test_dur, test_channels, test_datasets = (
        load_file_data(test_index, task_channels)
    )
elif load_mode == 1:
    train_data, train_outputs, train_sr, train_dur, train_channels, train_datasets = (
        load_file_data(train_index, task_channels)
    )
    val_data, val_outputs, val_sr, val_dur, val_channels, val_datasets = load_file_data(
        val_index, task_channels
    )
    test_data, test_outputs, test_sr, test_dur, test_channels, test_datasets = (
        load_file_data(test_index, task_channels)
    )
else:
    pass


# Label Encoder & Class Weights
from sklearn.preprocessing import LabelEncoder

if isinstance(list(train_outputs.values())[0], str):
    all_outputs = list(set(list(train_outputs.values()) + list(test_outputs.values())))
    label_encoder = LabelEncoder()
    label_encoder.fit(all_outputs)

    print(f"Train classes: {set(train_outputs.values())}", file=sys.stderr)
    print(f"Test classes: {set(test_outputs.values())}", file=sys.stderr)

    # Encode the train and test outputs
    encoded_train_outputs = {
        k: label_encoder.transform([v])[0] for k, v in train_outputs.items()
    }
    encoded_test_outputs = {
        k: label_encoder.transform([v])[0] for k, v in test_outputs.items()
    }

    # Create the output counts map
    train_output_counts = defaultdict(int)
    for output in encoded_train_outputs.values():
        train_output_counts[output] += 1

    test_output_counts = defaultdict(int)
    for output in encoded_test_outputs.values():
        test_output_counts[output] += 1

    full_output_counts = train_output_counts.copy()
    for output, count in test_output_counts.items():
        full_output_counts[output] += count

    print("Full Output Counts:", full_output_counts, file=sys.stderr)

    # Calculate class weights
    total_count = sum(full_output_counts.values())
    class_weights = {
        output: total_count / count for output, count in full_output_counts.items()
    }

    # Convert class weights to a tensor
    weight_tensor = torch.tensor(
        [class_weights[i] for i in range(len(class_weights))], dtype=torch.float
    )
else:
    label_encoder = None
    weight_tensor = None

/itet-stor/maxihuber/deepeye_storage/foundation_tasks/cli/tuab.json
Full train size: 2717
Full test size: 276




Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:49<00:00,  4.03it/s]




Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:19<00:00,  5.07it/s]
Train classes: {'normal', 'abnormal'}
Test classes: {'normal', 'abnormal'}
Full Output Counts: defaultdict(<class 'int'>, {1: 150, 0: 150})


# Pretrained Model

### Instantiate Model

In [8]:
L.seed_everything(42)

ckpt_path = '/itet-stor/maxihuber/net_scratch/checkpoints/980473/epoch=7-step=239317-val_loss=130.45-lr.ckpt'
ckpt_path = '/itet-stor/maxihuber/net_scratch/checkpoints/977598/epoch=0-step=32807-val_loss=133.55.ckpt'

#########################################################################################################
class FinetuneDataset(Dataset):
    def __init__(self, data, outputs, srs, durs, channels, datasets, task_type, label_encoder=None):
        self.data = data
        self.outputs = outputs
        self.srs = srs
        self.durs = durs
        self.channels = channels
        self.datasets = datasets
        self.task_type = task_type
        self.label_encoder = label_encoder
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        signals = self.data[idx]
        output = self.outputs[idx]
        sr = self.srs[idx]
        dur = self.durs[idx]
        channels = self.channels[idx]
        dataset = self.datasets[idx]

        if self.task_type == "Classification" and self.label_encoder is not None:
            output = self.label_encoder.transform([output])[0]  # Encode the output label
            output_tensor = torch.tensor(output, dtype=torch.long)
        else:
            if task_name == "EyeNetPosition":
                output_tensor = torch.tensor(output, dtype=torch.float32)
            else:
                output_tensor = torch.tensor([output], dtype=torch.float32)
        
        return {
            "signals": signals,
            "output": output_tensor,
            "sr": sr,
            "dur": dur,
            "channels": channels,
            "dataset": dataset
        }

if load_mode != 1:
    full_train_dataset = FinetuneDataset(train_data, train_outputs, train_sr, train_dur, train_channels, train_datasets, task_type=task_type, label_encoder=label_encoder)
    test_dataset = FinetuneDataset(test_data, test_outputs, test_sr, test_dur, test_channels, test_datasets, task_type=task_type, label_encoder=label_encoder)
    # Define the split ratio
    train_ratio = 0.85
    val_ratio = 0.15
    # Calculate lengths for train and validation sets
    total_size = len(full_train_dataset)
    train_size = int(train_ratio * total_size)
    val_size = total_size - train_size
    # Split the dataset
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])
elif load_mode == 1:
    train_dataset = FinetuneDataset(train_data, train_outputs, train_sr, train_dur, train_channels, train_datasets, task_type=task_type, label_encoder=label_encoder)
    val_dataset = FinetuneDataset(val_data, val_outputs, val_sr, val_dur, val_channels, val_datasets, task_type=task_type, label_encoder=label_encoder)
    test_dataset = FinetuneDataset(test_data, test_outputs, test_sr, test_dur, test_channels, test_datasets, task_type=task_type, label_encoder=label_encoder)
else:
    pass
    
#########################################################################################################
# DataLoaders
self_win_shifts = [.25, .5, 1, 2, 4, 8]
self_patch_size = 16
self_win_shift_factor = .25
self_max_win_shift = self_win_shifts[-1]
self_max_y_datapoints = 4_000
max_nr_patches = 8_500

def get_nr_y_patches(win_size, sr):
    return int((sr / 2 * win_size + 1) / self_patch_size)

def get_nr_x_patches(win_size, dur):
    win_shift = win_size * self_win_shift_factor
    x_datapoints_per_second = 1 / win_shift
    x_datapoints = dur * x_datapoints_per_second + 1
    return int(x_datapoints // self_patch_size)

channel_name_map_path = '/home/maxihuber/eeg-foundation/src/data/components/channels_to_id.json'
with open(channel_name_map_path, "r") as file:
    self_channel_name_map = json.load(file)

def self_get_generic_channel_name(channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

def self_encode_mean(mean, win_size):
    y_datapoints = mean.shape[0]
    encoded_mean = torch.zeros(self_max_y_datapoints)
    step_size = int(self_max_win_shift // win_size)
    end_idx = step_size * y_datapoints
    indices = torch.arange(0, end_idx, step_size)
    encoded_mean[indices] = mean.squeeze_().float()
    encoded_mean.unsqueeze_(1)
    return encoded_mean

#########################################################################################################
# collate_fn
# make batches as the pre-trained network expects (channel tokens, means, standard deviation etc.)
def sample_collate_fn(batch):

    signals, output, sr, dur, channels, dataset = batch[0]["signals"], batch[0]["output"], batch[0]["sr"], batch[0]["dur"], batch[0]["channels"], batch[0]["dataset"]

    if dur > 1_000:
        dur = 1_000
        signals = signals[:, :1_000*sr]
    
    valid_win_shifts = """
    # TODO: compute spectrograms for each win_size
    # gives a new dimension (S) in batch
    # need another extra transformer after the encoder
    # (B, 1, H, W) -> (S*B, 1, H, W)
    valid_win_shifts = [
        win_shift
        for win_shift in self_win_shifts
        if get_nr_y_patches(win_shift, sr) >= 1
        and get_nr_x_patches(win_shift, dur) >= 1
        and get_nr_y_patches(win_shift, sr) * get_nr_x_patches(win_shift, dur) < max_nr_patches
    ]

    assert valid_win_shifts != [], "no valid win_shift found"
    """
    valid_win_shifts = [1]

    # list holding assembled tensors for varying window shifts
    full_batch = {}   

    for win_size in valid_win_shifts:
        
        fft = torchaudio.transforms.Spectrogram(
            n_fft=int(sr * win_size),
            win_length=int(sr * win_size),
            hop_length=int(sr * win_size * self_win_shift_factor),
            normalized=True,
        )
    
        spg_list = []
        chn_list = []
        mean_list = []
        std_list = []
    
        for signal, channel in zip(signals, channels):
            
            # Channel information
            channel_name = self_get_generic_channel_name(channel)
            channel = self_channel_name_map[channel_name] if channel_name in self_channel_name_map else self_channel_name_map["None"]
    
            # Spectrogram Computation & Cropping
            spg = fft(signal)
            spg = spg**2
            spg = crop_spg(spg, self_patch_size)
            
            H_new, W_new = spg.shape[0], spg.shape[1]
            h_new, w_new = H_new // self_patch_size, W_new // self_patch_size
    
            # Prepare channel information (per-patch)
            channel = torch.full((h_new, w_new), channel, dtype=torch.float16)
            
            spg, mean, std = normalize_spg(spg)
            mean = self_encode_mean(mean, win_size)
            std = self_encode_mean(std, win_size)
            
            spg_list.append(spg)
            chn_list.append(channel)
            mean_list.append(mean)
            std_list.append(std)
        
        win_batch = torch.stack(spg_list)
        win_channels = torch.stack(chn_list)
        win_means = torch.stack(mean_list)
        win_stds = torch.stack(std_list)
        
        win_batch.unsqueeze_(1)
        win_channels = win_channels.flatten(1)
        win_means = win_means.transpose(1, 2)
        win_stds = win_stds.transpose(1, 2)
        
        full_batch[win_size] = {
            "batch": win_batch,
            "channels": win_channels,
            "means": win_means,
            "stds": win_stds
        }
        #print(f"[collate_fn] win_size={win_size}: {win_batch.shape}")
        
    # == Finished iterating over all possible window shifts
    print("collate_fn")
    return full_batch, output, dataset

train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sample_collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sample_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=sample_collate_fn)

print(len(train_loader), len(val_loader), len(test_loader))

#########################################################################################################
# Model
# == Metrics ==
def rmse(y_true, y_pred):
    return torch.sqrt(torch.mean((y_true - y_pred) ** 2))


def balanced_accuracy(y_true, y_pred):
    return balanced_accuracy_score(y_true, y_pred)

class SingleTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(SingleTransformerEncoderLayer, self).__init__()
        self.encoder_layer = TransformerEncoderLayer(d_model, nhead)

    def forward(self, src):
        return self.encoder_layer(src)

def mean_aggregation(tokens):
    return torch.mean(torch.stack(tokens), dim=0)

class FineTuningModel(L.LightningModule):
    def __init__(self, encoder, frozen_encoder, out_dim, task_name, task_type, learning_rate, mask_ratio):
        super(FineTuningModel, self).__init__()

        self.task_name = task_name
        self.task_type = task_type
        self.learning_rate = learning_rate
        self.mask_ratio = mask_ratio

        # Pretrained network
        self.encoder = encoder       
        if frozen_encoder:
            self.freeze_encoder()

        # Finetuning network
        self.finetune_time_transformer = SimpleTransformer(
            embed_size=384,
            max_len=8_5000
        )
        
        self.finetune_channel_transformer = SimpleTransformer(
            embed_size=384,
            max_len=200,
        )
        
        # Modular aggregation method on channel tokens
        self.win_shift_aggregation = mean_aggregation
        
        if task_type == "Regression":
            self.head = nn.Linear(encoder.encoder_embed_dim, 1)
            self.criterion = nn.MSELoss()
        else:
            self.head = nn.Linear(encoder.encoder_embed_dim, 1)
            self.criterion = nn.BCEWithLogitsLoss()

        self.train_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, full_x):
        x_embeds = {}
        H_W = {}

        print(f"[FT.forward] win_shifts: {full_x.keys()}")
        
        for win_size, x_win in full_x.items():
            spgs = x_win["batch"]
            channels = x_win["channels"]
            means = x_win["means"]
            stds = x_win["stds"]
            B, C, H, W = spgs.shape
            x_emb, _, _, nr_meta_patches = self.encoder(
                x=spgs,
                means=means,
                stds=stds,
                channels=channels,
                win_size=win_size,
                mask_ratio=self.mask_ratio,
            )
            x_embeds[win_size] = x_emb
            H_W[win_size] = (H, W)
            print(f"[FT.forward, after self.encoder] x_emb.shape: {x_emb.shape}")

        # Pass through time-transformer
        for win_size, x_emb in x_embeds.items():
            print(f"[FT.forward, before self.time_transformer] x_emb.shape: {x_emb.shape}")
            x_emb = self.finetune_time_transformer(x_emb)
            x_emb = x_emb[:, 0]
            print(f"[FT.forward, after time-token] x_emb.shape: {x_emb.shape}")
            x_embeds[win_size] = x_emb

        # Pass through channel-transformer
        tokens = []
        for win_size, x_emb in x_embeds.items():
            x_emb = x_emb.unsqueeze(0)
            print(f"[FT.forward, before channel-token] x_emb.shape: {x_emb.shape}")
            x_emb = self.finetune_channel_transformer(x_emb)
            x_emb = x_emb[0, 0]
            print(f"[FT.forward, after channel-token] x_emb.shape: {x_emb.shape}")
            tokens.append(x_emb)

        # Average over all window shifts
        smart_token = self.win_shift_aggregation(tokens)

        # Pass through head
        y_hat = self.head(smart_token).squeeze()
        
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y, dataset = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y.float())
        self.log('train_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = (torch.sigmoid(y_hat) >= 0.5).float()
            self.train_step_outputs.append((y.cpu(), y_pred.cpu(), dataset))
        elif self.task_type == "Regression":
            self.train_step_outputs.append((y.cpu(), y_hat.cpu(), dataset))
        
        return loss

    def on_train_epoch_end(self):
        self.compute_metrics(self.train_step_outputs, 'train')
        self.train_step_outputs.clear()
    
    def validation_step(self, batch, batch_idx):
        x, y, dataset = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y.float())
        self.log('val_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = (torch.sigmoid(y_hat) >= 0.5).float()
            self.validation_step_outputs.append((y.cpu(), y_pred.cpu(), dataset))
        elif self.task_type == "Regression":
            self.validation_step_outputs.append((y.cpu(), y_hat.cpu(), dataset))

        return loss

    def on_validation_epoch_end(self):
        self.compute_metrics(self.validation_step_outputs, 'val')
        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y, dataset = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y.float())
        self.log('test_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = (torch.sigmoid(y_hat) >= 0.5).float()
            self.test_step_outputs.append((y.cpu(), y_pred.cpu(), dataset))
        elif self.task_type == "Regression":
            self.test_step_outputs.append((y.cpu(), y_hat.cpu(), dataset))

        return loss

    def on_test_epoch_end(self):
        self.compute_metrics(self.test_step_outputs, 'test')
        self.test_step_outputs.clear()

    def compute_metrics(self, outputs, stage):
        y_true_all = defaultdict(list)
        y_pred_all = defaultdict(list)
        
        for y_true, y_pred, dataset in outputs:
            y_true_all[dataset].append(y_true)
            y_pred_all[dataset].append(y_pred)

        overall_y_true = []
        overall_y_pred = []

        for dataset in y_true_all.keys():
            y_true_cat = torch.stack(y_true_all[dataset])
            y_pred_cat = torch.stack(y_pred_all[dataset])

            overall_y_true.append(y_true_cat)
            overall_y_pred.append(y_pred_cat)

            if self.task_type == "Classification":
                balanced_acc = balanced_accuracy_score(y_true_cat, y_pred_cat)
                self.log(f'{stage}_balanced_accuracy_{dataset}', balanced_acc, prog_bar=True)
            elif self.task_type == "Regression":
                rmse_value = rmse(y_true_cat, y_pred_cat)
                self.log(f'{stage}_rmse_{dataset}', rmse_value, prog_bar=True)

        # Compute overall metrics
        overall_y_true = torch.cat(overall_y_true, dim=0)
        overall_y_pred = torch.cat(overall_y_pred, dim=0)

        if self.task_type == "Classification":
            balanced_acc = balanced_accuracy_score(overall_y_true, overall_y_pred)
            self.log(f'{stage}_balanced_accuracy', balanced_acc, prog_bar=True)
        elif self.task_type == "Regression":
            rmse_value = rmse(overall_y_true, overall_y_pred)
            self.log(f'{stage}_rmse', rmse_value, prog_bar=True)

    def configure_optimizers(self):
        return optim.Adam(self.head.parameters(), lr=self.learning_rate)

    def on_train_epoch_start(self):
        if self.trainer.current_epoch == 1:
            self.unfreeze_encoder()
            print(f"Unfroze encoder at epoch {self.trainer.current_epoch}")
        
    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = True

#########################################################################################################
# Load the checkpoint
chkpt_path = ckpt_path
checkpoint = torch.load(chkpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
state_dict = {k.replace("net.encoder.", ""): v for k, v in state_dict.items() if "net.encoder." in k}

# Initialize the encoder and load the state dict
encoder = EncoderViTRoPE(channel_name_map_path)
encoder.load_state_dict(state_dict)

# Instantiate the fine-tuning model
fine_tuning_model = FineTuningModel(encoder=encoder,
                                    frozen_encoder=True,
                                    out_dim=out_dim,
                                    task_name=task_name,
                                    task_type=task_type,
                                    learning_rate=0.01,
                                    mask_ratio=0)

#########################################################################################################
# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath=f"/itet-stor/maxihuber/deepeye_storage/finetune_ckpts/{task_name}",
    filename="{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
)

# Train the model
trainer = L.Trainer(
    max_epochs=5,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
    num_sanity_val_steps=0,
)

print(f"Class: {class_name}")
print(f"Task: {task_name} ({task_type})")
trainer.fit(fine_tuning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

trainer.test(model=fine_tuning_model, dataloaders=test_loader)
final_checkpoint_path = f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}/final_model.ckpt"
trainer.save_checkpoint(final_checkpoint_path)

[rank: 0] Seed set to 42


170 30 100


/itet-stor/maxihuber/net_scratch/conda_envs/fastenv/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.9 /itet-stor/maxihuber/net_scratch/conda_envs/faste ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                         | Type              | Params | Mode 
---------------------------------------------------------------------------
0 | encoder                      | EncoderViTRoPE    | 23.1 M | train
1 | finetune_time_transformer    | SimpleTransformer | 1.8 M  | train
2 | finetune_channel_transformer | SimpleTransformer | 1.8 M  | train
3 | head                         | Linear            | 770    | train
4 | criterion                    | CrossEntropy

Class: YC
Task: Epilepsy (Classification)


Training: |                                                                                                   …

collate_fn
[FT.forward] win_shifts: dict_keys([1])
[FT.forward, after self.encoder] x_emb.shape: torch.Size([22, 1226, 384])
[FT.forward, before self.time_transformer] x_emb.shape: torch.Size([22, 1226, 384])
[FT.forward, after time-token] x_emb.shape: torch.Size([22, 384])
[FT.forward, before channel-token] x_emb.shape: torch.Size([1, 22, 384])
[FT.forward, after channel-token] x_emb.shape: torch.Size([384])
collate_fn
[FT.forward] win_shifts: dict_keys([1])
[FT.forward, after self.encoder] x_emb.shape: torch.Size([22, 602, 384])
[FT.forward, before self.time_transformer] x_emb.shape: torch.Size([22, 602, 384])
[FT.forward, after time-token] x_emb.shape: torch.Size([22, 384])
[FT.forward, before channel-token] x_emb.shape: torch.Size([1, 22, 384])
[FT.forward, after channel-token] x_emb.shape: torch.Size([384])
collate_fn
[FT.forward] win_shifts: dict_keys([1])
[FT.forward, after self.encoder] x_emb.shape: torch.Size([20, 1202, 384])
[FT.forward, before self.time_transformer] x_emb.sh

OutOfMemoryError: CUDA out of memory. Tried to allocate 29.55 GiB. GPU 0 has a total capacty of 47.54 GiB of which 11.42 GiB is free. Including non-PyTorch memory, this process has 36.11 GiB memory in use. Of the allocated memory 32.40 GiB is allocated by PyTorch, and 3.38 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [20]:
with open('/home/maxihuber/eeg-foundation/runs/eps_trainchns.json', 'r') as file:
    train_chns = json.load(file)

with open('/home/maxihuber/eeg-foundation/runs/eps_testchns.json', 'r') as file:
    test_chns = json.load(file)

all_chns = train_chns + test_chns

chns = set()
for f_chns in all_chns:
    chns.update(f_chns)

print(chns)

{'luc', 'p4', 'c3', 'sp2', '21', 'pz', 'fp2', 't6', 'fz', 't2', '25', '27', '28', 'rlc', 'loc', 'f4', 'f8', '30', '31', 'roc', '29', 't4', 't3', 'p3', 'a2', 'oz', 'ekg', 'a1', 't5', 't1', 'pg1', 'c4p', 'cz', 'c4', 'fp1', 'o1', 'o2', '24', 'f3', '22', '23', '26', 'f7', '32', 'c3p', 'pg2', '20', 'ekg1', 'sp1'}


# Baseline Models

## Data Preparation & Datasets

In [5]:
L.seed_everything(42)
sys.path.append("/home/maxihuber/eeg-foundation/src/models/components/Baselines")

class SimpleDataset(Dataset):
    def __init__(self, data, outputs, datasets, task_type, label_encoder=None):
        self.data = data
        self.outputs = outputs
        self.datasets = datasets
        self.task_type = task_type
        self.label_encoder = label_encoder

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        signals = self.data[idx]
        output = self.outputs[idx]
        dataset = self.datasets[idx]

        if self.task_type == "Classification" and self.label_encoder is not None:
            output = self.label_encoder.transform([output])[
                0
            ]  # Encode the output label
            output_tensor = torch.tensor(output, dtype=torch.long)
        else:
            output_tensor = torch.tensor([output], dtype=torch.float32)

        return {
            "signals": signals,
            "output": output_tensor,
            "dataset": dataset,
        }


durs = [df.shape[1] for idx, df in train_data.items()] + [
    df.shape[1] for idx, df in test_data.items()
]
n_chns = [df.shape[0] for idx, df in train_data.items()] + [
    df.shape[0] for idx, df in test_data.items()
]
dur_90 = int(np.percentile(durs, 90))
chn_90 = 128  # int(np.percentile(n_chns, 90))


def pad_tensor(tensor, target_height, target_width):
    current_height, current_width = tensor.shape

    # Pad height if necessary
    if current_height < target_height:
        padding_height = target_height - current_height
        padding = torch.zeros((padding_height, current_width), dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding), dim=0)
    else:
        tensor = tensor[:target_height, :]

    # Pad width if necessary
    if current_width < target_width:
        padding_width = target_width - current_width
        padding = torch.zeros((tensor.shape[0], padding_width), dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding), dim=1)
    else:
        tensor = tensor[:, :target_width]

    return tensor


train_data_pad = {
    k: pad_tensor(signals, chn_90, dur_90) for k, signals in train_data.items()
}
test_data_pad = {
    k: pad_tensor(signals, chn_90, dur_90) for k, signals in test_data.items()
}

full_train_dataset = SimpleDataset(
    train_data_pad,
    train_outputs,
    train_datasets,
    task_type=task_type,
    label_encoder=label_encoder,
)
test_dataset = SimpleDataset(
    test_data_pad,
    test_outputs,
    train_datasets,
    task_type=task_type,
    label_encoder=label_encoder,
)

# Define the split ratio
train_ratio, val_ratio = 0.85, 0.15

# Calculate lengths for train and validation sets
total_size = len(full_train_dataset)
train_size = int(train_ratio * total_size)
val_size = total_size - train_size

# Split the dataset
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

[rank: 0] Seed set to 42


## XDawn + LDA ============================================

In [7]:
os.makedirs(
    f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}", exist_ok=True
)

# Function to resample signals
def resample_signals(data, srs, target_sfreq):
    resampled_data = {}
    for idx, signal in tqdm(data.items(), desc="Resampling signals"):
        signal_numpy = signal.numpy().astype(np.float64)  # Convert to float64
        signal_resampled = mne.filter.resample(signal_numpy, up=target_sfreq / srs[idx])
        resampled_data[idx] = torch.tensor(signal_resampled, dtype=torch.float32)
    return resampled_data


# Function to pad or truncate signals to a common length
def pad_or_truncate_signals(data, common_length):
    for idx, signal in tqdm(data.items(), desc="Pad/Truncate signals"):
        signal_length = signal.shape[1]
        if signal_length < common_length:
            pad_width = common_length - signal_length
            signal_padded = np.pad(signal, ((0, 0), (0, pad_width)), mode="constant")
        else:
            signal_padded = signal[:, :common_length]
        data[idx] = torch.tensor(signal_padded.clone().detach(), dtype=torch.float32)
    return data


# Function to create MNE Epochs object from data
def create_epochs(data, outputs, channels, sfreq=1000, is_classification=True):
    events = []
    event_id = {}
    epochs_data = []
    for idx, signal in tqdm(data.items(), desc="Creating epochs"):
        epochs_data.append(signal.numpy())
        if is_classification:
            if outputs[idx] not in event_id:
                event_id[outputs[idx]] = len(event_id) + 1
            events.append([idx, 0, event_id[outputs[idx]]])
        else:
            events.append([idx, 0, 1])  # Dummy event_id for regression
    events = np.array(events, dtype=int)
    info = mne.create_info(
        ch_names=[f"EEG_{i}" for i in range(chn_90)], sfreq=sfreq, ch_types="eeg"
    )
    epochs = mne.EpochsArray(
        np.array(epochs_data),
        info,
        events=events,
        event_id=event_id if is_classification else None,
    )
    return epochs


# Determine the target sampling frequency (e.g., the highest or mean sampling rate)
target_sfreq = int(max(train_sr.values()))  # or use statistics.mean(train_sr.values())

# Resample train and test data
# train_data_resampled = resample_signals(train_data_pad, train_sr, target_sfreq)
# test_data_resampled = resample_signals(test_data_pad, test_sr, target_sfreq)
train_data_resampled = train_data_pad
test_data_resampled = test_data_pad

# Determine the common length for all signals
common_length = min(
    min(signal.shape[1] for signal in train_data_resampled.values()),
    min(signal.shape[1] for signal in test_data_resampled.values()),
)

# Pad or truncate train and test data to the common length
train_data_padded = pad_or_truncate_signals(train_data_resampled, common_length)
test_data_padded = pad_or_truncate_signals(test_data_resampled, common_length)

# Convert train and test data to MNE Epochs
is_classification = isinstance(list(train_outputs.values())[0], str)
epochs_train = create_epochs(
    train_data_padded,
    train_outputs,
    list(train_channels.values()),
    target_sfreq,
    is_classification,
)
epochs_test = create_epochs(
    test_data_padded,
    test_outputs,
    list(test_channels.values()),
    target_sfreq,
    is_classification,
)

print("Start fitting Xdawn...", file=sys.stderr)

# Create an xDAWN instance and fit it to the training data
xdawn = Xdawn(n_components=2, correct_overlap=False, reg=0.1)  # Adding regularization
xdawn.fit(epochs_train)

# Transform the data using xDAWN
X_train_xdawn = xdawn.transform(epochs_train)
X_test_xdawn = xdawn.transform(epochs_test)

# Save xDAWN model parameters
with open(
    f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}/xdawn_model.pkl", "wb"
) as f:
    pickle.dump(xdawn, f)

# Flatten the transformed data for LDA input
n_epochs_train, n_components, n_times = X_train_xdawn.shape
X_train_xdawn = X_train_xdawn.reshape(n_epochs_train, n_components * n_times)
n_epochs_test, n_components, n_times = X_test_xdawn.shape
X_test_xdawn = X_test_xdawn.reshape(n_epochs_test, n_components * n_times)

if is_classification:
    # Encode labels if they are strings (for classification tasks)
    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(list(train_outputs.values()))
    y_test = label_encoder.transform(list(test_outputs.values()))

    # Create an LDA instance and fit it to the transformed training data
    lda = LinearDiscriminantAnalysis()
    lda.fit(X_train_xdawn, y_train)

    # Save LDA model parameters
    with open(
        f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}/lda_model.pkl",
        "wb",
    ) as f:
        pickle.dump(lda, f)

    # Predict the labels of the test set
    y_pred = lda.predict(X_test_xdawn)

    # Calculate metrics
    balanced_acc = balanced_accuracy_score(y_test, y_pred)
    print(f"Balanced Accuracy: {balanced_acc}", file=sys.stderr)
else:
    # For regression tasks
    y_train = np.array(list(train_outputs.values()))
    y_test = np.array(list(test_outputs.values()))

    # Create a linear regression model and fit it to the transformed training data
    lr = LinearRegression()
    lr.fit(X_train_xdawn, y_train)

    # Save Linear Regression model parameters
    with open(
        f"/itet-stor/maxihuber/net_scratch/finetune_ckpts/{task_name}/linear_regression_model.pkl",
        "wb",
    ) as f:
        pickle.dump(lr, f)

    # Predict the values of the test set
    y_pred = lr.predict(X_test_xdawn)

    # Calculate metrics
    rmse_value = np.sqrt(mean_squared_error(y_test, y_pred))
    print(f"RMSE: {rmse_value}", file=sys.stderr)

  data[idx] = torch.tensor(signal_padded.clone().detach(), dtype=torch.float32)
Pad/Truncate signals: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:17<00:00, 11.69it/s]
Pad/Truncate signals: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 11.02it/s]
Creating epochs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 238042.22it/s]
Creating epochs: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 1253.32it/s]
Balanced Accuracy: 0.51


## Classifiers from Sklearn  ===========================================

In [None]:
# Code source: Gaël Varoquaux
#              Andreas Müller
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause
# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

from sklearn.datasets import make_circles, make_classification, make_moons
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

names = [
    "Nearest Neighbors",
    "Linear SVM",
    "RBF SVM",
    # "Gaussian Process",
    "Decision Tree",
    "Random Forest",
    # "Neural Net (MLP)",
    "AdaBoost",
    "Naive Bayes",
    "QDA",
]

classifiers = [
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025, random_state=42),
    SVC(gamma=2, C=1, random_state=42),
    # GaussianProcessClassifier(1.0 * RBF(1.0), random_state=42),
    DecisionTreeClassifier(max_depth=5, random_state=42),
    RandomForestClassifier(
        max_depth=5, n_estimators=10, max_features=1, random_state=42
    ),
    # MLPClassifier(alpha=1, max_iter=100, random_state=42),
    AdaBoostClassifier(algorithm="SAMME", random_state=42),
    GaussianNB(),
    QuadraticDiscriminantAnalysis(),
]

# Flattening the data for all channels
def flatten_data(data_pad):
    flattened_data = [sample.flatten().numpy() for sample in data_pad.values()]
    return flattened_data

X_train = flatten_data(train_data_pad)
y_train = [output for output in train_outputs.values()]

X_test = flatten_data(test_data_pad)
y_test = [output for output in test_outputs.values()]

# iterate over classifiers
for name, clf in zip(names, classifiers):
    print(name, file=sys.stderr)
    clf = make_pipeline(StandardScaler(), clf)
    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)
    print(f"{name}: {score}", file=sys.stderr)

Nearest Neighbors
Nearest Neighbors: 0.51
Linear SVM
Linear SVM: 0.59
RBF SVM
RBF SVM: 0.5
Decision Tree
Decision Tree: 0.6
Random Forest
Random Forest: 0.56
AdaBoost
