## Imports ##

In [1]:
import os
import pickle
import json
import joblib

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import mne
# from mne import events_from_annotations, concatenate_raws

from sklearn.model_selection import train_test_split
from scipy.interpolate import CubicSpline
from scipy.integrate import simps
from sklearn.metrics import r2_score
import scipy.io
from scipy.io import loadmat
# from mat4py import loadmat, savemat

#from imblearn.pipeline import Pipeline
#from imblearn.over_sampling import SMOTE
#from imblearn.under_sampling import RandomUnderSampler

from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, StratifiedKFold
from sklearn.feature_selection import mutual_info_classif
from sklearn.feature_selection import SelectKBest
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process.kernels import RBF,WhiteKernel,Matern,RationalQuadratic,ExpSineSquared

from sklearn.metrics import classification_report, accuracy_score

from processing.Processing_EEG import process_eeg_raw, process_eeg_epochs
from processing.Processing_NIRS import process_nirs_raw, process_nirs_epochs

from utilities.Read_Data import read_subject_raw_nirs, read_subject_raw_eeg
from utilities.utilities import translate_channel_name_to_ch_id, find_sections, spatial_zscore

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

import csv
from sklearn.metrics import r2_score

import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.insert(1, '../../iTransformer/')

from iTransformer.iTransformerTranscoding import iTransformer

import gc



## Constants ##

In [2]:
# Absolute data root; every directory below is joined onto it.
BASE_PATH = '/scratch/mjm9724/stinger/data/'

# Sub-directories of BASE_PATH. MODEL_WEIGHTS is where the training cell
# writes and looks up ``.pth`` checkpoints and the per-model loss CSV files.
ROOT_DIRECTORY_EEG = os.path.join(BASE_PATH, 'raw/eeg/')
ROOT_DIRECTORY_NIRS = os.path.join(BASE_PATH, 'raw/nirs/')
MODEL_WEIGHTS = os.path.join(BASE_PATH, 'model_weights/')
OUTPUT_DIRECTORY = os.path.join(BASE_PATH, 'output/')

# Trial order: subject id -> task name -> the three trial identifiers for that
# task, each of the form '<date>_<running number>'.
TRIAL_TO_CHECK_NIRS = {
    'VP001': {
        'nback': ['2016-05-26_007', '2016-05-26_008', '2016-05-26_009'],
        'gonogo': ['2016-05-26_001', '2016-05-26_002', '2016-05-26_003'],
        'word': ['2016-05-26_004', '2016-05-26_005', '2016-05-26_006'],
    },
    'VP002': {
        'nback': ['2016-05-26_016', '2016-05-26_017', '2016-05-26_018'],
        'gonogo': ['2016-05-26_010', '2016-05-26_011', '2016-05-26_012'],
        'word': ['2016-05-26_013', '2016-05-26_014', '2016-05-26_015'],
    },
    'VP003': {
        'nback': ['2016-05-27_001', '2016-05-27_002', '2016-05-27_003'],
        'gonogo': ['2016-05-27_007', '2016-05-27_008', '2016-05-27_009'],
        'word': ['2016-05-27_004', '2016-05-27_005', '2016-05-27_006'],
    },
    'VP004': {
        'nback': ['2016-05-30_001', '2016-05-30_002', '2016-05-30_003'],
        'gonogo': ['2016-05-30_007', '2016-05-30_008', '2016-05-30_009'],
        'word': ['2016-05-30_004', '2016-05-30_005', '2016-05-30_006'],
    },
    'VP005': {
        'nback': ['2016-05-30_010', '2016-05-30_011', '2016-05-30_012'],
        'gonogo': ['2016-05-30_016', '2016-05-30_017', '2016-05-30_018'],
        'word': ['2016-05-30_013', '2016-05-30_014', '2016-05-30_015'],
    },
    'VP006': {
        'nback': ['2016-05-31_001', '2016-05-31_002', '2016-05-31_003'],
        'gonogo': ['2016-05-31_007', '2016-05-31_008', '2016-05-31_009'],
        'word': ['2016-05-31_004', '2016-05-31_005', '2016-05-31_006'],
    },
    'VP007': {
        'nback': ['2016-06-01_001', '2016-06-01_002', '2016-06-01_003'],
        'gonogo': ['2016-06-01_007', '2016-06-01_008', '2016-06-01_009'],
        'word': ['2016-06-01_004', '2016-06-01_005', '2016-06-01_006'],
    },
    'VP008': {
        'nback': ['2016-06-02_001', '2016-06-02_002', '2016-06-02_003'],
        'gonogo': ['2016-06-02_007', '2016-06-02_008', '2016-06-02_009'],
        'word': ['2016-06-02_004', '2016-06-02_005', '2016-06-02_006'],
    },
    'VP009': {
        'nback': ['2016-06-02_010', '2016-06-02_011', '2016-06-02_012'],
        'gonogo': ['2016-06-02_016', '2016-06-02_017', '2016-06-02_018'],
        'word': ['2016-06-02_013', '2016-06-02_014', '2016-06-02_015'],
    },
    'VP010': {
        'nback': ['2016-06-03_001', '2016-06-03_002', '2016-06-03_003'],
        'gonogo': ['2016-06-03_007', '2016-06-03_008', '2016-06-03_009'],
        'word': ['2016-06-03_004', '2016-06-03_005', '2016-06-03_006'],
    },
    'VP011': {
        'nback': ['2016-06-03_010', '2016-06-03_011', '2016-06-03_012'],
        'gonogo': ['2016-06-03_016', '2016-06-03_017', '2016-06-03_018'],
        'word': ['2016-06-03_013', '2016-06-03_014', '2016-06-03_015'],
    },
    'VP012': {
        'nback': ['2016-06-06_001', '2016-06-06_002', '2016-06-06_003'],
        'gonogo': ['2016-06-06_007', '2016-06-06_008', '2016-06-06_009'],
        'word': ['2016-06-06_004', '2016-06-06_005', '2016-06-06_006'],
    },
    'VP013': {
        'nback': ['2016-06-06_010', '2016-06-06_011', '2016-06-06_012'],
        'gonogo': ['2016-06-06_016', '2016-06-06_017', '2016-06-06_018'],
        'word': ['2016-06-06_013', '2016-06-06_014', '2016-06-06_015'],
    },
    'VP014': {
        'nback': ['2016-06-07_001', '2016-06-07_002', '2016-06-07_003'],
        'gonogo': ['2016-06-07_007', '2016-06-07_008', '2016-06-07_009'],
        'word': ['2016-06-07_004', '2016-06-07_005', '2016-06-07_006'],
    },
    'VP015': {
        'nback': ['2016-06-07_010', '2016-06-07_011', '2016-06-07_012'],
        'gonogo': ['2016-06-07_016', '2016-06-07_017', '2016-06-07_018'],
        'word': ['2016-06-07_013', '2016-06-07_014', '2016-06-07_015'],
    },
    'VP016': {
        'nback': ['2016-06-08_001', '2016-06-08_002', '2016-06-08_003'],
        'gonogo': ['2016-06-08_007', '2016-06-08_008', '2016-06-08_009'],
        'word': ['2016-06-08_004', '2016-06-08_005', '2016-06-08_006'],
    },
    'VP017': {
        'nback': ['2016-06-09_001', '2016-06-09_002', '2016-06-09_003'],
        'gonogo': ['2016-06-09_007', '2016-06-09_008', '2016-06-09_009'],
        'word': ['2016-06-09_004', '2016-06-09_005', '2016-06-09_006'],
    },
    'VP018': {
        'nback': ['2016-06-10_001', '2016-06-10_002', '2016-06-10_003'],
        'gonogo': ['2016-06-10_007', '2016-06-10_008', '2016-06-10_009'],
        'word': ['2016-06-10_004', '2016-06-10_005', '2016-06-10_006'],
    },
    'VP019': {
        'nback': ['2016-06-13_001', '2016-06-13_002', '2016-06-13_003'],
        'gonogo': ['2016-06-13_007', '2016-06-13_008', '2016-06-13_009'],
        'word': ['2016-06-13_004', '2016-06-13_005', '2016-06-13_006'],
    },
    'VP020': {
        'nback': ['2016-06-14_001', '2016-06-14_002', '2016-06-14_003'],
        'gonogo': ['2016-06-14_007', '2016-06-14_008', '2016-06-14_009'],
        'word': ['2016-06-14_004', '2016-06-14_005', '2016-06-14_006'],
    },
    'VP021': {
        'nback': ['2016-06-14_010', '2016-06-14_011', '2016-06-14_012'],
        'gonogo': ['2016-06-14_016', '2016-06-14_017', '2016-06-14_018'],
        'word': ['2016-06-14_013', '2016-06-14_014', '2016-06-14_015'],
    },
    'VP022': {
        'nback': ['2016-06-15_001', '2016-06-15_002', '2016-06-15_003'],
        'gonogo': ['2016-06-15_007', '2016-06-15_008', '2016-06-15_009'],
        'word': ['2016-06-15_004', '2016-06-15_005', '2016-06-15_006'],
    },
    'VP023': {
        'nback': ['2016-06-16_001', '2016-06-16_002', '2016-06-16_003'],
        'gonogo': ['2016-06-16_007', '2016-06-16_008', '2016-06-16_009'],
        'word': ['2016-06-16_004', '2016-06-16_005', '2016-06-16_006'],
    },
    'VP024': {
        'nback': ['2016-06-16_010', '2016-06-16_011', '2016-06-16_012'],
        'gonogo': ['2016-06-16_016', '2016-06-16_017', '2016-06-16_018'],
        'word': ['2016-06-16_013', '2016-06-16_014', '2016-06-16_015'],
    },
    'VP025': {
        'nback': ['2016-06-17_010', '2016-06-17_011', '2016-06-17_012'],
        'gonogo': ['2016-06-17_016', '2016-06-17_017', '2016-06-17_018'],
        'word': ['2016-06-17_013', '2016-06-17_014', '2016-06-17_015'],
    },
    'VP026': {
        'nback': ['2016-07-11_001', '2016-07-11_002', '2016-07-11_003'],
        'gonogo': ['2016-07-11_007', '2016-07-11_008', '2016-07-11_009'],
        'word': ['2016-07-11_004', '2016-07-11_005', '2016-07-11_006'],
    },
}

# Task translation dictionaries
# EEG: task name -> {annotation label -> readable event name}.
EEG_EVENT_TRANSLATIONS = {
    'nback': {
        'Stimulus/S 16': '0-back target',
        'Stimulus/S 48': '2-back target',
        'Stimulus/S 64': '2-back non-target',
        'Stimulus/S 80': '3-back target',
        'Stimulus/S 96': '3-back non-target',
        'Stimulus/S112': '0-back session',
        'Stimulus/S128': '2-back session',
        'Stimulus/S144': '3-back session',
    },
    'gonogo': {
        'Stimulus/S 16': 'go',
        'Stimulus/S 32': 'nogo',
        'Stimulus/S 48': 'gonogo session',
    },
    'word': {
        'Stimulus/S 16': 'verbal_fluency',
        'Stimulus/S 32': 'baseline',
    },
}
# NIRS: task name -> {marker code (as a string) -> readable event name}.
NIRS_EVENT_TRANSLATIONS = {
    'nback': {'7.0': '0-back session', '8.0': '2-back session', '9.0': '3-back session'},
    'gonogo': {'3.0': 'gonogo session'},
    'word': {'1.0': 'verbal_fluency', '2.0': 'baseline'},
}

# Per task, the sub-task (session-level) event names whose time spans are
# cropped so that recordings end up with the same length.
TASK_STIMULOUS_TO_CROP = {
    'nback': ['0-back session', '2-back session', '3-back session'],
    'gonogo': ['gonogo session'],
    'word': ['verbal_fluency', 'baseline'],
}

# EEG Coordinates: channel name -> (x, y, z).
# The last four entries repeat FP1/AFF5/FP2/AFF6 under their alternative
# spellings (Fp1/AFF5h/Fp2/AFF6h) so that either naming resolves.
EEG_COORDS = {
    'FP1': (-0.3090, 0.9511, 0.0001),  # Fp1
    'AFF5': (-0.5417, 0.7777, 0.3163),  # AFF5h
    'AFz': (0.0000, 0.9230, 0.3824),
    'F1': (-0.2888, 0.6979, 0.6542),
    'FC5': (-0.8709, 0.3373, 0.3549),
    'FC1': (-0.3581, 0.3770, 0.8532),
    'T7': (-1.0000, 0.0000, 0.0000),
    'C3': (-0.7066, 0.0001, 0.7066),
    'Cz': (0.0000, 0.0002, 1.0000),
    'CP5': (-0.8712, -0.3372, 0.3552),
    'CP1': (-0.3580, -0.3767, 0.8534),
    'P7': (-0.8090, -0.5878, -0.0001),
    'P3': (-0.5401, -0.6724, 0.5045),
    'Pz': (0.0000, -0.7063, 0.7065),
    'POz': (0.0000, -0.9230, 0.3824),
    'O1': (-0.3090, -0.9511, 0.0000),
    'FP2': (0.3091, 0.9511, 0.0000),  # Fp2
    'AFF6': (0.5417, 0.7777, 0.3163),  # AFF6h
    'F2': (0.2888, 0.6979, 0.6542),
    'FC2': (0.3581, 0.3770, 0.8532),
    'FC6': (0.8709, 0.3373, 0.3549),
    'C4': (0.7066, 0.0001, 0.7066),
    'T8': (1.0000, 0.0000, 0.0000),
    'CP2': (0.3580, -0.3767, 0.8534),
    'CP6': (0.8712, -0.3372, 0.3552),
    'P4': (0.5401, -0.6724, 0.5045),
    'P8': (0.8090, -0.5878, -0.0001),
    'O2': (0.3090, -0.9511, 0.0000),
    'TP9': (-0.8777, -0.2852, -0.3826),
    'TP10': (0.8777, -0.2853, -0.3826),

    # Alternative spellings of the four frontal channels above
    'Fp1': (-0.3090, 0.9511, 0.0001),
    'AFF5h': (-0.5417, 0.7777, 0.3163),
    'Fp2': (0.3091, 0.9511, 0.0000),
    'AFF6h': (0.5417, 0.7777, 0.3163),
}

# NIRS Coordinates: channel name -> (x, y, z).
# Insertion order matters: the training cell selects channels by slicing the
# first 16 keys of this dict.
NIRS_COORDS = {
    'AF7': (-0.5878, 0.809, 0),
    'AFF5': (-0.6149, 0.7564, 0.2206),
    'AFp7': (-0.454, 0.891, 0),
    'AF5h': (-0.4284, 0.875, 0.2213),
    'AFp3': (-0.2508, 0.9565, 0.1438),
    'AFF3h': (-0.352, 0.8111, 0.4658),
    'AF1': (-0.1857, 0.915, 0.3558),
    'AFFz': (0, 0.8312, 0.5554),
    'AFpz': (0, 0.9799, 0.1949),
    'AF2': (0.1857, 0.915, 0.3558),
    'AFp4': (0.2508, 0.9565, 0.1437),
    'FCC3': (-0.6957, 0.1838, 0.6933),
    'C3h': (-0.555, 0.0002, 0.8306),
    'C5h': (-0.8311, 0.0001, 0.5552),
    'CCP3': (-0.6959, -0.1836, 0.6936),
    'CPP3': (-0.6109, -0.5259, 0.5904),
    'P3h': (-0.4217, -0.6869, 0.5912),
    'P5h': (-0.6411, -0.6546, 0.3985),
    'PPO3': (-0.4537, -0.796, 0.3995),
    'AFF4h': (0.352, 0.8111, 0.4658),
    'AF6h': (0.4284, 0.875, 0.2212),
    'AFF6': (0.6149, 0.7564, 0.2206),
    'AFp8': (0.454, 0.891, 0),
    'AF8': (0.5878, 0.809, 0),
    'FCC4': (0.6957, 0.1838, 0.6933),
    'C6h': (0.8311, 0.0001, 0.5552),
    'C4h': (0.555, 0.0002, 0.8306),
    'CCP4': (0.6959, -0.1836, 0.6936),
    'CPP4': (0.6109, -0.5258, 0.5904),
    'P6h': (0.6411, -0.6546, 0.3985),
    'P4h': (0.4216, -0.687, 0.5912),
    'PPO4': (0.4537, -0.796, 0.3995),
    'PPOz': (0, -0.8306, 0.5551),
    'PO1': (-0.1858, -0.9151, 0.3559),
    'PO2': (0.1859, -0.9151, 0.3559),
    'POOz': (0, -0.9797, 0.1949),
}

# EEG channel names; a channel's position in this list is the index that
# find_indices() resolves it to in the training cell.
EEG_CHANNEL_NAMES = [
    'FP1', 'AFF5h', 'AFz', 'F1', 'FC5', 'FC1', 'T7', 'C3',
    'Cz', 'CP5', 'CP1', 'P7', 'P3', 'Pz', 'POz', 'O1',
    'FP2', 'AFF6h', 'F2', 'FC2', 'FC6', 'C4', 'T8', 'CP2',
    'CP6', 'P4', 'P8', 'O2',
    'HEOG', 'VEOG',
]

## Parameters - Raw ##

In [3]:
## Subject/Trial Parameters ##
# NOTE: np.arange excludes its stop value, so (1, 2) selects subject 1 only;
# np.arange(1, 27) would cover VP001 through VP026.
subject_ids = np.arange(1, 2)
# Subject labels in the 'VPnnn' form used as keys of TRIAL_TO_CHECK_NIRS.
subjects = [f'VP{subject_number:03d}' for subject_number in subject_ids]

tasks = ['nback', 'gonogo', 'word']

# NIRS sampling rate
fnirs_sample_rate = 10
# EEG downsampling rate
eeg_sample_rate = 10

# Whether to run the processing stage at all
do_processing = True

# Rebuild the preprocessing pickle files (TAKES A LONG TIME)
redo_preprocessing = False

# Rebuild the data-formatting pickle files (TAKES A LONG TIME)
redo_data_formatting = False

## Signal Prediction ##

### Parameters - Signal Prediction

In [4]:
# Time window (seconds), relative to each window's centre sample. The window
# helpers multiply these by the sampling rate to get sample offsets.
# NOTE: the training cell passes eeg_t_min=0 / eeg_t_max=1 as literals rather
# than reading these two EEG variables; only the NIRS bounds are read from here.
eeg_t_min = 0
eeg_t_max = 1
nirs_t_min = -10
nirs_t_max = 10

# Not referenced in the visible part of this file.
offset_t = 0

# Train/Test Size (not referenced in the visible part of this file)
train_size = 4000
test_size = 500

# training loop: upper bound of the epoch range; also part of the file name of
# the final checkpoint that decides whether a model is skipped.
num_epochs = 100

# do_load: resume from the latest matching checkpoint found in MODEL_WEIGHTS.
# do_train: draw random training windows and run the optimisation loop;
#           when False, training subjects are windowed in order instead.
do_load = False
do_train = True

# Number of time samples kept per window and handed to the model as
# lookback_len / target_lookback_len. With the sampling rate of 200 used in the
# training cell: 20 s * 200 = 4000 NIRS samples and 1 s * 200 = 200 EEG samples.
fnirs_lookback = 4000
eeg_lookback = 200

# Not referenced in the visible part of this file.
use_hbr = False

# Integer subject numbers 1..26 (stop value is exclusive). This rebinds
# `subjects`, replacing the list of 'VPnnn' strings built in the previous cell.
subjects = np.arange(1,27)

### Extract Data ###

In [5]:
def get_single_window(center_point, 
                      nirs_data, 
                      eeg_data, 
                      eeg_i_min, 
                      eeg_i_max, 
                      nirs_i_min, 
                      nirs_i_max):
    """Slice one EEG window and one NIRS window around the same centre sample.

    Both arrays are indexed as [channel, sample]. Each window covers the
    half-open sample range [center_point + i_min, center_point + i_max) of its
    own array. No bounds checking is done here; callers are expected to pick
    centres for which both ranges lie inside the data.

    Returns:
        (eeg_window, nirs_window), each keeping every channel (row).
    """
    eeg_span = slice(center_point + eeg_i_min, center_point + eeg_i_max)
    nirs_span = slice(center_point + nirs_i_min, center_point + nirs_i_max)
    return eeg_data[:, eeg_span], nirs_data[:, nirs_span]

def grab_ordered_windows(nirs_data,
                        eeg_data,
                        sampling_rate,
                        nirs_t_min, 
                        nirs_t_max,
                        eeg_t_min, 
                        eeg_t_max):
    """Cut consecutive EEG windows and the NIRS window around each of them.

    Window centres start at the first sample for which both windows fit and
    advance by one EEG window length, so successive EEG windows tile the
    recording without overlap.

    Args:
        nirs_data: 2-D array indexed as [channel, sample].
        eeg_data: 2-D array indexed as [channel, sample]. Both arrays are
            sliced around the same centre sample index.
        sampling_rate: Samples per second; converts the ``*_t_*`` bounds
            (seconds relative to the centre) into sample offsets.
        nirs_t_min: Start of the NIRS window, in seconds from the centre.
        nirs_t_max: End (exclusive) of the NIRS window, in seconds.
        eeg_t_min: Start of the EEG window, in seconds from the centre.
        eeg_t_max: End (exclusive) of the EEG window, in seconds.

    Returns:
        ``(eeg_windows, nirs_windows, centres)``: two arrays of shape
        (n_windows, n_channels, window_length) and the list of centre sample
        indices. When no window fits, the arrays have zero windows but keep
        their channel and length axes, so callers can still transpose/stack
        them.

    Raises:
        ValueError: If the EEG window is not at least one sample long (the
            window length is the step between centres).
    """
    nirs_i_min = int(nirs_t_min*sampling_rate)
    nirs_i_max = int(nirs_t_max*sampling_rate)
    eeg_i_min = int(eeg_t_min*sampling_rate)
    eeg_i_max = int(eeg_t_max*sampling_rate)

    eeg_window_size = eeg_i_max - eeg_i_min
    if eeg_window_size <= 0:
        raise ValueError(
            'EEG window must span at least one sample: '
            f'eeg_t_min={eeg_t_min}, eeg_t_max={eeg_t_max}, sampling_rate={sampling_rate}')

    # Largest centre bound (exclusive) for which both windows end inside the data.
    # NOTE(review): because range() excludes its stop value, a window ending
    # exactly on the last sample is never emitted; kept as in the original.
    max_center = min(eeg_data.shape[1] - eeg_i_max,
                     nirs_data.shape[1] - nirs_i_max)
    # Smallest centre for which both windows start at a non-negative index.
    min_center = max(abs(eeg_i_min), abs(nirs_i_min))

    meta_data = list(range(min_center, max_center, eeg_window_size))

    if not meta_data:
        # Keep the (window, channel, sample) layout even with zero windows.
        empty_eeg = np.empty((0, eeg_data.shape[0], eeg_window_size),
                             dtype=eeg_data.dtype)
        empty_nirs = np.empty((0, nirs_data.shape[0], max(nirs_i_max - nirs_i_min, 0)),
                              dtype=nirs_data.dtype)
        return empty_eeg, empty_nirs, meta_data

    eeg_full_windows = np.array(
        [eeg_data[:, center + eeg_i_min:center + eeg_i_max] for center in meta_data])
    nirs_full_windows = np.array(
        [nirs_data[:, center + nirs_i_min:center + nirs_i_max] for center in meta_data])

    return eeg_full_windows, nirs_full_windows, meta_data
    
def grab_random_windows(nirs_data, 
                        eeg_data,
                        sampling_rate,
                        nirs_t_min, 
                        nirs_t_max,
                        eeg_t_min, 
                        eeg_t_max,
                        number_of_windows=1000):
    """Cut EEG/NIRS window pairs around randomly drawn centre samples.

    Centres are drawn independently with ``np.random.randint`` (so windows may
    overlap or repeat) from the range in which both windows lie inside the
    data. Seed NumPy's global RNG for reproducible draws.

    Args:
        nirs_data: 2-D array indexed as [channel, sample].
        eeg_data: 2-D array indexed as [channel, sample]. Both arrays are
            sliced around the same centre sample index.
        sampling_rate: Samples per second; converts the ``*_t_*`` bounds
            (seconds relative to the centre) into sample offsets.
        nirs_t_min: Start of the NIRS window, in seconds from the centre.
        nirs_t_max: End (exclusive) of the NIRS window, in seconds.
        eeg_t_min: Start of the EEG window, in seconds from the centre.
        eeg_t_max: End (exclusive) of the EEG window, in seconds.
        number_of_windows: How many window pairs to draw.

    Returns:
        ``(eeg_windows, nirs_windows, centres)``: two arrays of shape
        (number_of_windows, n_channels, window_length) and the list of drawn
        centre sample indices. With ``number_of_windows <= 0`` the arrays have
        zero windows but keep their channel and length axes.

    Raises:
        ValueError: If windows were requested but the data is too short for
            any centre to fit both windows.
    """
    nirs_i_min = int(nirs_t_min*sampling_rate)
    nirs_i_max = int(nirs_t_max*sampling_rate)
    eeg_i_min = int(eeg_t_min*sampling_rate)
    eeg_i_max = int(eeg_t_max*sampling_rate)

    if number_of_windows <= 0:
        # Nothing to draw: return correctly shaped empties instead of 1-D arrays.
        empty_eeg = np.empty((0, eeg_data.shape[0], max(eeg_i_max - eeg_i_min, 0)),
                             dtype=eeg_data.dtype)
        empty_nirs = np.empty((0, nirs_data.shape[0], max(nirs_i_max - nirs_i_min, 0)),
                              dtype=nirs_data.dtype)
        return empty_eeg, empty_nirs, []

    # Centre bound (exclusive, as randint's `high`) for which both windows end
    # inside the data, and smallest centre for which both start at index >= 0.
    max_center = min(eeg_data.shape[1] - eeg_i_max,
                     nirs_data.shape[1] - nirs_i_max)
    min_center = max(abs(eeg_i_min), abs(nirs_i_min))

    if max_center <= min_center:
        raise ValueError(
            'Data too short for the requested windows: no valid centre in '
            f'[{min_center}, {max_center}) for EEG length {eeg_data.shape[1]} '
            f'and NIRS length {nirs_data.shape[1]}')

    # One randint call per window, as before, so seeded runs draw the same centres.
    meta_data = [np.random.randint(min_center, max_center)
                 for _ in range(number_of_windows)]

    eeg_full_windows = np.array(
        [eeg_data[:, center + eeg_i_min:center + eeg_i_max] for center in meta_data])
    nirs_full_windows = np.array(
        [nirs_data[:, center + nirs_i_min:center + nirs_i_max] for center in meta_data])

    return eeg_full_windows, nirs_full_windows, meta_data

class EEGfNIRSData(Dataset):
    """Paired dataset that yields ``(fnirs_sample, eeg_sample)`` for an index.

    Both containers are stored as given (no copy, no length check); the
    dataset length is taken from the EEG container.
    """

    def __init__(self, fnirs_data, eeg_data):
        self.eeg_data = eeg_data
        self.fnirs_data = fnirs_data

    def __len__(self):
        # The EEG container defines how many samples the dataset exposes.
        sample_count = len(self.eeg_data)
        return sample_count

    def __getitem__(self, idx):
        # Input (fNIRS) first, target (EEG) second - the order in which the
        # training loop unpacks each batch.
        fnirs_sample = self.fnirs_data[idx]
        eeg_sample = self.eeg_data[idx]
        return fnirs_sample, eeg_sample

def plot_series(target, output, epoch):
    """Plot a target series against the model output and show the figure.

    Args:
        target: Sequence of reference values (drawn as a solid line).
        output: Sequence of predicted values (drawn as a dashed line).
        epoch: Zero-based epoch index; the title displays ``epoch + 1``.
    """
    _, axis = plt.subplots(figsize=(10, 4))
    axis.plot(target, label='Target')
    axis.plot(output, label='Output', linestyle='--')
    axis.set_title(f'Epoch {epoch + 1}')
    axis.legend()
    axis.grid(True)
    plt.show()

def find_indices(x, y):
    """Return, for each item of ``y`` that occurs in ``x``, its position in ``x``.

    Results follow the order of ``y``; items missing from ``x`` are skipped
    silently, and for duplicates in ``x`` the first position is reported.
    ``x`` must support ``in`` and ``.index`` (e.g. a list).
    """
    return [x.index(wanted) for wanted in y if wanted in x]

In [None]:
# Define channels to use
# NIRS inputs: the first 16 channels in NIRS_COORDS order; because the names
# come from that same key order, their indices resolve to positions 0-15.
nirs_channels_to_use_base = list(NIRS_COORDS)[:16]
nirs_channel_index = find_indices(list(NIRS_COORDS), nirs_channels_to_use_base)

# EEG targets: every channel name; the training cell fits one model per channel.
eeg_channels_to_use_full = EEG_CHANNEL_NAMES

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Device: {device}')

for test_subject_id in subjects:
    test = [test_subject_id]
    for eeg_channel_name in eeg_channels_to_use_full:
        eeg_channels_to_use = [eeg_channel_name]
        eeg_channel_index = find_indices(EEG_CHANNEL_NAMES,eeg_channels_to_use)

        model_name = f'transformer_{test_subject_id:02d}_{eeg_channel_name}'
        final_model_path = os.path.join(MODEL_WEIGHTS, f'{model_name}_{num_epochs}.pth')
        if os.path.exists(final_model_path):
            print(f'Model name exists, skipping {model_name}')
        else:
            print(f'Starting {model_name}')
        
            
            model = iTransformer(
                num_variates = len(nirs_channels_to_use_base),
                lookback_len = fnirs_lookback,      # or the lookback length in the paper
                target_num_variates=len(eeg_channels_to_use),
                target_lookback_len=eeg_lookback,
                dim = 256,                          # model dimensions
                depth = 6,                          # depth
                heads = 8,                          # attention heads
                dim_head = 64,                      # head dimension
                attn_dropout=0.1,
                ff_mult=4,
                ff_dropout=0.1,
                num_mem_tokens=10,
                num_tokens_per_variate = 1,         # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
                use_reversible_instance_norm = True # use reversible instance normalization, proposed here https://openreview.net/forum?id=cGDAkQo1C0p . may be redundant given the layernorms within iTransformer (and whatever else attention learns emergently on the first layer, prior to the first layernorm). if i come across some time, i'll gather up all the statistics across variates, project them, and condition the transformer a bit further. that makes more sense
            )
            
            # Pre-allocate memory for training and testing data
            eeg_windowed_train = np.empty((0, eeg_lookback, len(eeg_channels_to_use)))
            nirs_windowed_train = np.empty((0, fnirs_lookback, len(nirs_channels_to_use_base)))
            eeg_windowed_test = np.empty((0, eeg_lookback, len(eeg_channels_to_use)))
            nirs_windowed_test = np.empty((0, fnirs_lookback, len(nirs_channels_to_use_base)))
            
            for i in subjects:
                subject_id = f'{i:02d}'
            
                subject_data = loadmat(os.path.join(BASE_PATH, 'matfiles', f'data_vp0{subject_id}.mat'))['subject_data_struct'][0]
                # # eeg subject_data[1][0]
                # eeg_data = []
                # for session_eeg_data in subject_data[1][0]:
                #     eeg_data.append(session_eeg_data.T)
                # eeg_data = np.hstack(eeg_data)
                # # fnirs subject_data[3][0]
                # nirs_data = []
                # for session_nirs_data in subject_data[3][0]:
                #     nirs_data.append(session_nirs_data.T)
                # nirs_data = np.hstack(nirs_data)
                # # mrk subject_data[5][0]
            
                eeg_data = subject_data[1][0][0].T
                nirs_data = subject_data[3][0][0].T
            
                assert eeg_data.shape[1] == nirs_data.shape[1]
                
                if i not in test and do_train:
                    single_eeg_windowed_train, single_nirs_windowed_train, meta_data = grab_random_windows(
                                 nirs_data=nirs_data, 
                                 eeg_data=eeg_data,
                                 sampling_rate=200,
                                 nirs_t_min=nirs_t_min, 
                                 nirs_t_max=nirs_t_max,
                                 eeg_t_min=0, 
                                 eeg_t_max=1,
                                 number_of_windows=1000)
            
                    # Append to the preallocated arrays
                    single_eeg_transposed = single_eeg_windowed_train.transpose(0,2,1)
                    single_nirs_transposed = single_nirs_windowed_train.transpose(0,2,1)
            
                    single_eeg_transposed = single_eeg_transposed[:,:eeg_lookback, eeg_channel_index]
                    single_nirs_transposed = single_nirs_transposed[:,:fnirs_lookback, nirs_channel_index]
            
                    # Stack new data into the existing array, avoiding list append
                    eeg_windowed_train = np.vstack((eeg_windowed_train, single_eeg_transposed))
                    nirs_windowed_train = np.vstack((nirs_windowed_train, single_nirs_transposed))
                elif i not in test and not do_train:
                    single_eeg_windowed_train, single_nirs_windowed_train, meta_data = grab_ordered_windows(
                         nirs_data=nirs_data, 
                         eeg_data=eeg_data,
                         sampling_rate=200,
                         nirs_t_min=nirs_t_min, 
                         nirs_t_max=nirs_t_max,
                         eeg_t_min=0, 
                         eeg_t_max=1)
                    
                    single_eeg_windowed_train = single_eeg_windowed_train.transpose(0,2,1)
                    single_nirs_windowed_train = single_nirs_windowed_train.transpose(0,2,1)
                
                    single_eeg_windowed_train = single_eeg_windowed_train[:,:eeg_lookback, eeg_channel_index]
                    single_nirs_windowed_train = single_nirs_windowed_train[:,:fnirs_lookback, nirs_channel_index]
                    
                    # For test data, direct stacking since no windowing
                    eeg_windowed_train = np.vstack((eeg_windowed_train, single_eeg_windowed_train))
                    nirs_windowed_train = np.vstack((nirs_windowed_train, single_nirs_windowed_train))
                else:
                    single_eeg_windowed_test, single_nirs_windowed_test, meta_data = grab_ordered_windows(
                         nirs_data=nirs_data, 
                         eeg_data=eeg_data,
                         sampling_rate=200,
                         nirs_t_min=nirs_t_min, 
                         nirs_t_max=nirs_t_max,
                         eeg_t_min=0, 
                         eeg_t_max=1)
            
                    single_eeg_windowed_test = single_eeg_windowed_test.transpose(0,2,1)
                    single_nirs_windowed_test = single_nirs_windowed_test.transpose(0,2,1)
                    
                    single_eeg_windowed_test = single_eeg_windowed_test[:,:eeg_lookback, eeg_channel_index]
                    single_nirs_windowed_test = single_nirs_windowed_test[:,:fnirs_lookback, nirs_channel_index]
                    
                    # For test data, direct stacking since no windowing
                    eeg_windowed_test = np.vstack((eeg_windowed_test, single_eeg_windowed_test))
                    nirs_windowed_test = np.vstack((nirs_windowed_test, single_nirs_windowed_test))
                    
                    print(f'Skipping {subject_id}')

            print(f'EEG Shape: {eeg_windowed_train.shape}')
            print(f'NIRS Shape: {nirs_windowed_train.shape}')
            
            if do_train:
                nirs_train_tensor = torch.from_numpy(nirs_windowed_train).float()
                eeg_train_tensor = torch.from_numpy(eeg_windowed_train).float()
                meta_data_tensor = torch.from_numpy(np.array(meta_data)).float()
                
                print(nirs_train_tensor.shape)
                print(eeg_train_tensor.shape)
                
                sequence_length = eeg_train_tensor.shape[1]
                eeg_number_of_features = eeg_train_tensor.shape[2]
                nirs_number_of_features = nirs_train_tensor.shape[2]
                
                dataset = EEGfNIRSData(nirs_train_tensor, eeg_train_tensor)
                dataloader = DataLoader(dataset, batch_size=500, shuffle=True)
            
                latest_epoch = 0
                loss_list = []
                if do_load:
                    model_path = f'{model_name}_epoch_1.pth'
            
                    # find the latest model
                    for file in os.listdir(MODEL_WEIGHTS):
                        if file.startswith(f'{model_name}_epoch_'):
                            epoch = int(file.split('_')[-1].split('.')[0])
                            if epoch > latest_epoch:
                                latest_epoch = epoch
                                model_path = file
                    print(f'Using Model Weights: {model_path}')
                    model.load_state_dict(torch.load(os.path.join(MODEL_WEIGHTS, model_path)))
                    
                    # load loss list
                    with open(os.path.join(MODEL_WEIGHTS, f'loss_{model_name}_{latest_epoch}.csv'), 'r') as file_ptr:
                        reader = csv.reader(file_ptr)
                        loss_list = list(reader)[0]
                    print(f'Last loss: {float(loss_list[-1])/len(dataloader):.4f}')
            
                # Train the model with MSE loss, checkpointing weights and the loss
                # history every 50 epochs.

                # Set correct device
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model.to(device)
            
                # Optimizer and loss function
                # The optimizer is created fresh here; this block does not restore any
                # optimizer state, only the starting epoch (latest_epoch) is carried over.
                optimizer = Adam(model.parameters(), lr=0.00001)
                loss_function = torch.nn.MSELoss()
                # Starts at latest_epoch so already-finished epochs are not repeated
                for epoch in range(latest_epoch, num_epochs):
                    model.train()
                    # Sum of the per-batch mean losses for this epoch
                    total_loss = 0
            
                    for batch_idx, (X_batch, y_batch) in enumerate(dataloader):
                        X_batch = X_batch.to(device).float()
                        y_batch = y_batch.to(device).float()
                        
                        # Forward pass
                        predictions = model(X_batch)
            
                        # Loss calculation
                        loss = loss_function(predictions, y_batch)
            
                        # Backpropagation
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
            
                        total_loss += loss.item()
                        # if (batch_idx+1) % 20 == 0 or batch_idx == 0:
                        #     print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {loss.item():.4f}')
                    
                    # The history stores the summed loss; it is divided by the number of
                    # batches only when printed.
                    loss_list.append(total_loss)
            
                    if (epoch+1) % 50 == 0:
                        # Save model weights
                        # Written as '<model_name>_<epoch>.pth' next to
                        # 'loss_<model_name>_<epoch>.csv' (the full loss history as one row).
                        torch.save(model.state_dict(), os.path.join(MODEL_WEIGHTS, f'{model_name}_{epoch+1}.pth'))
                        with open(os.path.join(MODEL_WEIGHTS,f'loss_{model_name}_{epoch+1}.csv'), 'w', newline='') as file_ptr:
                            wr = csv.writer(file_ptr, quoting=csv.QUOTE_ALL)
                            wr.writerow(loss_list)
                        
                    # Plotting target vs. output for the first example in the last batch
                    # single_actual = y_batch[0, :, 0].detach().cpu().numpy()
                    # single_predicted = predictions[0,:,0].detach().cpu().numpy()
                    # r2 = r2_score(single_actual, single_predicted)
                    # print(f'R-squared: {r2}')
                    # if (epoch+1) % 10 == 0:
                    #     plot_series(single_actual, single_predicted, epoch)
            
                    # Average of the per-batch losses for this epoch
                    print(f'Epoch: {epoch+1}, Average Loss: {total_loss / len(dataloader):.4f}')
            
            # Perform inference on test
            # For each saved checkpoint: predict every test sample, export targets and
            # predictions to a .mat file, print the R^2 score and save one
            # target-vs-output plot per EEG channel.

            nirs_test_tensor = torch.from_numpy(nirs_windowed_test).float()
            eeg_test_tensor = torch.from_numpy(eeg_windowed_test).float()

            # shuffle=False keeps the predictions in dataset order so that the
            # concatenated series below lines up with the targets
            test_dataset = EEGfNIRSData(nirs_test_tensor, eeg_test_tensor)
            test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
            n_channels = len(eeg_channels_to_use)

            # Get weights for specific epoch
            weight_epochs = [50,100]
            for weight_epoch in weight_epochs:
                model_path = f'{model_name}_{weight_epoch}.pth'
                checkpoint_file = os.path.join(MODEL_WEIGHTS, model_path)
                if not os.path.isfile(checkpoint_file):
                    # A checkpoint that was never written (e.g. fewer training epochs)
                    # should not abort the evaluation of the remaining ones.
                    print(f'Missing weights {model_path}, skipping evaluation for epoch {weight_epoch}')
                    continue
                model.load_state_dict(torch.load(checkpoint_file))
                model.to(device)
                # Set model to evaluation mode
                model.eval()

                # Perform inference on test data
                predictions = []
                targets = []
                # no_grad: gradients are never needed here, so skip building the
                # autograd graph for every forward pass
                with torch.no_grad():
                    for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                        X_batch = X_batch.to(device).float()
                        y_batch = y_batch.to(device).float()
                        predictions.append(model(X_batch).cpu().numpy())
                        targets.append(y_batch.cpu().numpy())

                predictions = np.array(predictions)
                targets = np.array(targets)

                # concatenate and plot
                predictions = predictions.reshape(-1, n_channels)
                targets = targets.reshape(-1, n_channels)

                scipy.io.savemat(os.path.join(OUTPUT_DIRECTORY, f'test_{model_name}_{weight_epoch}.mat'), {'X': targets,
                                                                     'XPred':predictions,
                                                                    'bins':10,
                                                                    'scale':10,
                                                                    'srate':200})

                # R2 score
                r2 = r2_score(targets, predictions)
                print(f'Epoch-{weight_epoch}: {r2}')

                # Plotting target vs. output on concatenated data
                for i, channel_name in enumerate(eeg_channels_to_use):
                    fig = plt.figure(figsize=(10, 4))
                    plt.plot(targets[:,i], label='Target')
                    plt.plot(predictions[:,i], label='Output', linestyle='--')
                    plt.title(f'Epoch-{weight_epoch} Channel {channel_name} : {r2}')
                    plt.legend()
                    plt.grid(True)
                    # A single channel keeps the original file name; with several
                    # channels each plot gets its own file instead of overwriting
                    # the previous one.
                    if n_channels == 1:
                        figure_name = f'test_{model_name}_{weight_epoch}.jpeg'
                    else:
                        figure_name = f'test_{model_name}_{weight_epoch}_{channel_name}.jpeg'
                    plt.savefig(os.path.join(OUTPUT_DIRECTORY, figure_name))
                    # The plot is persisted on disk; close it so figures do not pile
                    # up in memory across channels, checkpoints and models.
                    plt.close(fig)

            gc.collect()
            


Device: cuda
Model name exists, skipping transformer_01_FP1
Model name exists, skipping transformer_01_AFF5h
Starting transformer_01_AFz
A100 GPU detected, using flash attention if input tensor is on cuda
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.3798
Epoch: 2, Average Loss: 1.2834
Epoch: 3, Average Loss: 1.2481
Epoch: 4, Average Loss: 1.2284
Epoch: 5, Average Loss: 1.2159
Epoch: 6, Average Loss: 1.2068
Epoch: 7, Average Loss: 1.2004
Epoch: 8, Average Loss: 1.1955
Epoch: 9, Average Loss: 1.1912
Epoch: 10, Average Loss: 1.1881
Epoch: 11, Average Loss: 1.1849
Epoch: 12, Average Loss: 1.1823
Epoch: 13, Average Loss: 1.1803
Epoch: 14, Average Loss: 1.1780
Epoch: 15, Average Loss: 1.1762
Epoch: 16, Average Loss: 1.1739
Epoch: 17, Average Loss: 1.1723
Epoch: 18, Average Loss: 1.1704
Epoch: 19, Average Loss: 1.1685
Epoch: 20, Average Loss: 1.1669
Epoch: 21, Average Loss: 1.1648
Epoch: 22, Average Loss: 1.1632
Epoch: 23, Average Loss: 1.1611
Epoch: 24, Average Loss: 1.1592
Epoch: 25, Average Loss: 1.1574
Epoch: 26, Average Loss: 1.1552
Epoch: 27, Average Loss: 1.1529
Epoch: 28, Average Loss: 1.1511
Epoch: 29, Average Loss: 1.1487
Epoch: 30, Average Loss: 1.1465
Epoch: 31, Average Loss: 1.1443
Epoch: 32, Averag



Epoch-100: -0.14171630277094116
Starting transformer_01_F1
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.3152
Epoch: 2, Average Loss: 1.2160
Epoch: 3, Average Loss: 1.1792
Epoch: 4, Average Loss: 1.1583
Epoch: 5, Average Loss: 1.1447
Epoch: 6, Average Loss: 1.1350
Epoch: 7, Average Loss: 1.1281
Epoch: 8, Average Loss: 1.1223
Epoch: 9, Average Loss: 1.1178
Epoch: 10, Average Loss: 1.1137
Epoch: 11, Average Loss: 1.1104
Epoch: 12, Average Loss: 1.1072
Epoch: 13, Average Loss: 1.1043
Epoch: 14, Average Loss: 1.1012
Epoch: 15, Average Loss: 1.0987
Epoch: 16, Average Loss: 1.0963
Epoch: 17, Average Loss: 1.0934
Epoch: 18, Average Loss: 1.0905
Epoch: 19, Average Loss: 1.0879
Epoch: 20, Average Loss: 1.0850
Epoch: 21, Average Loss: 1.0823
Epoch: 22, Average Loss: 1.0791
Epoch: 23, Average Loss: 1.0765
Epoch: 24, Average Loss: 1.0732
Epoch: 25, Average Loss: 1.0702
Epoch: 26, Average Loss: 1.0667
Epoch: 27, Average Loss: 1.0632
Epoch: 28, Average Loss: 1.0600
Epoch: 29, Average Loss: 1.0558
Epoch: 30, Average Loss: 1.0524
Epoch: 31, Average Loss: 1.0480
Epoch: 32, Averag



Epoch-100: -0.19995375261421144
Starting transformer_01_FC5
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1987
Epoch: 2, Average Loss: 1.1041
Epoch: 3, Average Loss: 1.0693
Epoch: 4, Average Loss: 1.0497
Epoch: 5, Average Loss: 1.0365
Epoch: 6, Average Loss: 1.0275
Epoch: 7, Average Loss: 1.0210
Epoch: 8, Average Loss: 1.0158
Epoch: 9, Average Loss: 1.0116
Epoch: 10, Average Loss: 1.0083
Epoch: 11, Average Loss: 1.0052
Epoch: 12, Average Loss: 1.0024
Epoch: 13, Average Loss: 0.9998
Epoch: 14, Average Loss: 0.9977
Epoch: 15, Average Loss: 0.9951
Epoch: 16, Average Loss: 0.9929
Epoch: 17, Average Loss: 0.9907
Epoch: 18, Average Loss: 0.9884
Epoch: 19, Average Loss: 0.9864
Epoch: 20, Average Loss: 0.9839
Epoch: 21, Average Loss: 0.9814
Epoch: 22, Average Loss: 0.9791
Epoch: 23, Average Loss: 0.9770
Epoch: 24, Average Loss: 0.9744
Epoch: 25, Average Loss: 0.9715
Epoch: 26, Average Loss: 0.9692
Epoch: 27, Average Loss: 0.9662
Epoch: 28, Average Loss: 0.9629
Epoch: 29, Average Loss: 0.9600
Epoch: 30, Average Loss: 0.9572
Epoch: 31, Average Loss: 0.9532
Epoch: 32, Averag



Epoch-100: -0.1774066230175182
Starting transformer_01_FC1
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1657
Epoch: 2, Average Loss: 1.0656
Epoch: 3, Average Loss: 1.0291
Epoch: 4, Average Loss: 1.0090
Epoch: 5, Average Loss: 0.9958
Epoch: 6, Average Loss: 0.9864
Epoch: 7, Average Loss: 0.9797
Epoch: 8, Average Loss: 0.9739
Epoch: 9, Average Loss: 0.9693
Epoch: 10, Average Loss: 0.9650
Epoch: 11, Average Loss: 0.9613
Epoch: 12, Average Loss: 0.9581
Epoch: 13, Average Loss: 0.9553
Epoch: 14, Average Loss: 0.9524
Epoch: 15, Average Loss: 0.9492
Epoch: 16, Average Loss: 0.9467
Epoch: 17, Average Loss: 0.9442
Epoch: 18, Average Loss: 0.9413
Epoch: 19, Average Loss: 0.9385
Epoch: 20, Average Loss: 0.9367
Epoch: 21, Average Loss: 0.9336
Epoch: 22, Average Loss: 0.9306
Epoch: 23, Average Loss: 0.9277
Epoch: 24, Average Loss: 0.9255
Epoch: 25, Average Loss: 0.9212
Epoch: 26, Average Loss: 0.9190
Epoch: 27, Average Loss: 0.9154
Epoch: 28, Average Loss: 0.9117
Epoch: 29, Average Loss: 0.9082
Epoch: 30, Average Loss: 0.9037
Epoch: 31, Average Loss: 0.9005
Epoch: 32, Averag



Epoch-100: -0.17577198853456166
Starting transformer_01_T7
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1183
Epoch: 2, Average Loss: 1.0218
Epoch: 3, Average Loss: 0.9861
Epoch: 4, Average Loss: 0.9659
Epoch: 5, Average Loss: 0.9527
Epoch: 6, Average Loss: 0.9436
Epoch: 7, Average Loss: 0.9369
Epoch: 8, Average Loss: 0.9311
Epoch: 9, Average Loss: 0.9268
Epoch: 10, Average Loss: 0.9227
Epoch: 11, Average Loss: 0.9189
Epoch: 12, Average Loss: 0.9159
Epoch: 13, Average Loss: 0.9127
Epoch: 14, Average Loss: 0.9099
Epoch: 15, Average Loss: 0.9067
Epoch: 16, Average Loss: 0.9030
Epoch: 17, Average Loss: 0.9002
Epoch: 18, Average Loss: 0.8968
Epoch: 19, Average Loss: 0.8938
Epoch: 20, Average Loss: 0.8905
Epoch: 21, Average Loss: 0.8872
Epoch: 22, Average Loss: 0.8841
Epoch: 23, Average Loss: 0.8795
Epoch: 24, Average Loss: 0.8744
Epoch: 25, Average Loss: 0.8708
Epoch: 26, Average Loss: 0.8660
Epoch: 27, Average Loss: 0.8615
Epoch: 28, Average Loss: 0.8574
Epoch: 29, Average Loss: 0.8514
Epoch: 30, Average Loss: 0.8463
Epoch: 31, Average Loss: 0.8414
Epoch: 32, Averag



Epoch-100: -0.25005585437655453
Starting transformer_01_C3
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.2188
Epoch: 2, Average Loss: 1.1181
Epoch: 3, Average Loss: 1.0812
Epoch: 4, Average Loss: 1.0610
Epoch: 5, Average Loss: 1.0474
Epoch: 6, Average Loss: 1.0377
Epoch: 7, Average Loss: 1.0308
Epoch: 8, Average Loss: 1.0249
Epoch: 9, Average Loss: 1.0199
Epoch: 10, Average Loss: 1.0160
Epoch: 11, Average Loss: 1.0122
Epoch: 12, Average Loss: 1.0087
Epoch: 13, Average Loss: 1.0053
Epoch: 14, Average Loss: 1.0024
Epoch: 15, Average Loss: 0.9993
Epoch: 16, Average Loss: 0.9961
Epoch: 17, Average Loss: 0.9934
Epoch: 18, Average Loss: 0.9898
Epoch: 19, Average Loss: 0.9867
Epoch: 20, Average Loss: 0.9835
Epoch: 21, Average Loss: 0.9800
Epoch: 22, Average Loss: 0.9764
Epoch: 23, Average Loss: 0.9725
Epoch: 24, Average Loss: 0.9685
Epoch: 25, Average Loss: 0.9646
Epoch: 26, Average Loss: 0.9603
Epoch: 27, Average Loss: 0.9566
Epoch: 28, Average Loss: 0.9519
Epoch: 29, Average Loss: 0.9469
Epoch: 30, Average Loss: 0.9419
Epoch: 31, Average Loss: 0.9362
Epoch: 32, Averag



Epoch-100: -0.3032633549848134
Starting transformer_01_Cz
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 0.9757
Epoch: 2, Average Loss: 0.8744
Epoch: 3, Average Loss: 0.8381
Epoch: 4, Average Loss: 0.8182
Epoch: 5, Average Loss: 0.8054
Epoch: 6, Average Loss: 0.7965
Epoch: 7, Average Loss: 0.7902
Epoch: 8, Average Loss: 0.7850
Epoch: 9, Average Loss: 0.7808
Epoch: 10, Average Loss: 0.7768
Epoch: 11, Average Loss: 0.7739
Epoch: 12, Average Loss: 0.7706
Epoch: 13, Average Loss: 0.7676
Epoch: 14, Average Loss: 0.7649
Epoch: 15, Average Loss: 0.7623
Epoch: 16, Average Loss: 0.7597
Epoch: 17, Average Loss: 0.7573
Epoch: 18, Average Loss: 0.7543
Epoch: 19, Average Loss: 0.7516
Epoch: 20, Average Loss: 0.7491
Epoch: 21, Average Loss: 0.7459
Epoch: 22, Average Loss: 0.7434
Epoch: 23, Average Loss: 0.7403
Epoch: 24, Average Loss: 0.7366
Epoch: 25, Average Loss: 0.7339
Epoch: 26, Average Loss: 0.7309
Epoch: 27, Average Loss: 0.7275
Epoch: 28, Average Loss: 0.7236
Epoch: 29, Average Loss: 0.7204
Epoch: 30, Average Loss: 0.7165
Epoch: 31, Average Loss: 0.7123
Epoch: 32, Averag



Epoch-100: -0.28649734407444005
Starting transformer_01_CP5
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1686
Epoch: 2, Average Loss: 1.0694
Epoch: 3, Average Loss: 1.0331
Epoch: 4, Average Loss: 1.0126
Epoch: 5, Average Loss: 0.9997
Epoch: 6, Average Loss: 0.9899
Epoch: 7, Average Loss: 0.9831
Epoch: 8, Average Loss: 0.9772
Epoch: 9, Average Loss: 0.9723
Epoch: 10, Average Loss: 0.9681
Epoch: 11, Average Loss: 0.9643
Epoch: 12, Average Loss: 0.9614
Epoch: 13, Average Loss: 0.9574
Epoch: 14, Average Loss: 0.9535
Epoch: 15, Average Loss: 0.9501
Epoch: 16, Average Loss: 0.9469
Epoch: 17, Average Loss: 0.9439
Epoch: 18, Average Loss: 0.9403
Epoch: 19, Average Loss: 0.9368
Epoch: 20, Average Loss: 0.9329
Epoch: 21, Average Loss: 0.9294
Epoch: 22, Average Loss: 0.9266
Epoch: 23, Average Loss: 0.9221
Epoch: 24, Average Loss: 0.9181
Epoch: 25, Average Loss: 0.9138
Epoch: 26, Average Loss: 0.9098
Epoch: 27, Average Loss: 0.9053
Epoch: 28, Average Loss: 0.9011
Epoch: 29, Average Loss: 0.8957
Epoch: 30, Average Loss: 0.8907
Epoch: 31, Average Loss: 0.8865
Epoch: 32, Averag



Epoch-100: -0.25701135910683104
Starting transformer_01_CP1
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1618
Epoch: 2, Average Loss: 1.0554
Epoch: 3, Average Loss: 1.0169
Epoch: 4, Average Loss: 0.9954
Epoch: 5, Average Loss: 0.9812
Epoch: 6, Average Loss: 0.9714
Epoch: 7, Average Loss: 0.9636
Epoch: 8, Average Loss: 0.9572
Epoch: 9, Average Loss: 0.9519
Epoch: 10, Average Loss: 0.9470
Epoch: 11, Average Loss: 0.9428
Epoch: 12, Average Loss: 0.9380
Epoch: 13, Average Loss: 0.9340
Epoch: 14, Average Loss: 0.9304
Epoch: 15, Average Loss: 0.9265
Epoch: 16, Average Loss: 0.9230
Epoch: 17, Average Loss: 0.9194
Epoch: 18, Average Loss: 0.9158
Epoch: 19, Average Loss: 0.9121
Epoch: 20, Average Loss: 0.9082
Epoch: 21, Average Loss: 0.9042
Epoch: 22, Average Loss: 0.9006
Epoch: 23, Average Loss: 0.8962
Epoch: 24, Average Loss: 0.8925
Epoch: 25, Average Loss: 0.8874
Epoch: 26, Average Loss: 0.8834
Epoch: 27, Average Loss: 0.8788
Epoch: 28, Average Loss: 0.8748
Epoch: 29, Average Loss: 0.8686
Epoch: 30, Average Loss: 0.8630
Epoch: 31, Average Loss: 0.8575
Epoch: 32, Averag



Epoch-100: -0.25697586671699546
Starting transformer_01_P7
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1266
Epoch: 2, Average Loss: 1.0264
Epoch: 3, Average Loss: 0.9896
Epoch: 4, Average Loss: 0.9687
Epoch: 5, Average Loss: 0.9547
Epoch: 6, Average Loss: 0.9448
Epoch: 7, Average Loss: 0.9366
Epoch: 8, Average Loss: 0.9307
Epoch: 9, Average Loss: 0.9254
Epoch: 10, Average Loss: 0.9206
Epoch: 11, Average Loss: 0.9164
Epoch: 12, Average Loss: 0.9116
Epoch: 13, Average Loss: 0.9081
Epoch: 14, Average Loss: 0.9041
Epoch: 15, Average Loss: 0.9005
Epoch: 16, Average Loss: 0.8967
Epoch: 17, Average Loss: 0.8926
Epoch: 18, Average Loss: 0.8884
Epoch: 19, Average Loss: 0.8848
Epoch: 20, Average Loss: 0.8804
Epoch: 21, Average Loss: 0.8760
Epoch: 22, Average Loss: 0.8711
Epoch: 23, Average Loss: 0.8671
Epoch: 24, Average Loss: 0.8621
Epoch: 25, Average Loss: 0.8588
Epoch: 26, Average Loss: 0.8523
Epoch: 27, Average Loss: 0.8451
Epoch: 28, Average Loss: 0.8400
Epoch: 29, Average Loss: 0.8333
Epoch: 30, Average Loss: 0.8277
Epoch: 31, Average Loss: 0.8226
Epoch: 32, Averag



Epoch-100: -0.24241549907865068
Starting transformer_01_P3
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1502
Epoch: 2, Average Loss: 1.0445
Epoch: 3, Average Loss: 1.0047
Epoch: 4, Average Loss: 0.9819
Epoch: 5, Average Loss: 0.9674
Epoch: 6, Average Loss: 0.9570
Epoch: 7, Average Loss: 0.9491
Epoch: 8, Average Loss: 0.9424
Epoch: 9, Average Loss: 0.9367
Epoch: 10, Average Loss: 0.9319
Epoch: 11, Average Loss: 0.9273
Epoch: 12, Average Loss: 0.9230
Epoch: 13, Average Loss: 0.9189
Epoch: 14, Average Loss: 0.9147
Epoch: 15, Average Loss: 0.9110
Epoch: 16, Average Loss: 0.9074
Epoch: 17, Average Loss: 0.9034
Epoch: 18, Average Loss: 0.8997
Epoch: 19, Average Loss: 0.8958
Epoch: 20, Average Loss: 0.8922
Epoch: 21, Average Loss: 0.8882
Epoch: 22, Average Loss: 0.8845
Epoch: 23, Average Loss: 0.8799
Epoch: 24, Average Loss: 0.8751
Epoch: 25, Average Loss: 0.8716
Epoch: 26, Average Loss: 0.8662
Epoch: 27, Average Loss: 0.8619
Epoch: 28, Average Loss: 0.8560
Epoch: 29, Average Loss: 0.8505
Epoch: 30, Average Loss: 0.8457
Epoch: 31, Average Loss: 0.8389
Epoch: 32, Averag

  plt.figure(figsize=(10, 4))


Epoch-100: -0.20458324059294064
Starting transformer_01_Pz
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1021
Epoch: 2, Average Loss: 0.9970
Epoch: 3, Average Loss: 0.9583
Epoch: 4, Average Loss: 0.9366
Epoch: 5, Average Loss: 0.9226
Epoch: 6, Average Loss: 0.9126
Epoch: 7, Average Loss: 0.9048
Epoch: 8, Average Loss: 0.8986
Epoch: 9, Average Loss: 0.8930
Epoch: 10, Average Loss: 0.8879
Epoch: 11, Average Loss: 0.8836
Epoch: 12, Average Loss: 0.8790
Epoch: 13, Average Loss: 0.8751
Epoch: 14, Average Loss: 0.8709
Epoch: 15, Average Loss: 0.8670
Epoch: 16, Average Loss: 0.8631
Epoch: 17, Average Loss: 0.8590
Epoch: 18, Average Loss: 0.8554
Epoch: 19, Average Loss: 0.8510
Epoch: 20, Average Loss: 0.8463
Epoch: 21, Average Loss: 0.8426
Epoch: 22, Average Loss: 0.8386
Epoch: 23, Average Loss: 0.8340
Epoch: 24, Average Loss: 0.8296
Epoch: 25, Average Loss: 0.8250
Epoch: 26, Average Loss: 0.8206
Epoch: 27, Average Loss: 0.8150
Epoch: 28, Average Loss: 0.8110
Epoch: 29, Average Loss: 0.8056
Epoch: 30, Average Loss: 0.8003
Epoch: 31, Average Loss: 0.7946
Epoch: 32, Averag



Epoch-100: -0.3149634995853865
Starting transformer_01_POz
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.0791
Epoch: 2, Average Loss: 0.9815
Epoch: 3, Average Loss: 0.9457
Epoch: 4, Average Loss: 0.9255
Epoch: 5, Average Loss: 0.9120
Epoch: 6, Average Loss: 0.9032
Epoch: 7, Average Loss: 0.8958
Epoch: 8, Average Loss: 0.8902
Epoch: 9, Average Loss: 0.8857
Epoch: 10, Average Loss: 0.8818
Epoch: 11, Average Loss: 0.8779
Epoch: 12, Average Loss: 0.8750
Epoch: 13, Average Loss: 0.8712
Epoch: 14, Average Loss: 0.8679
Epoch: 15, Average Loss: 0.8651
Epoch: 16, Average Loss: 0.8614
Epoch: 17, Average Loss: 0.8581
Epoch: 18, Average Loss: 0.8547
Epoch: 19, Average Loss: 0.8514
Epoch: 20, Average Loss: 0.8477
Epoch: 21, Average Loss: 0.8440
Epoch: 22, Average Loss: 0.8400
Epoch: 23, Average Loss: 0.8369
Epoch: 24, Average Loss: 0.8320
Epoch: 25, Average Loss: 0.8276
Epoch: 26, Average Loss: 0.8229
Epoch: 27, Average Loss: 0.8179
Epoch: 28, Average Loss: 0.8138
Epoch: 29, Average Loss: 0.8076
Epoch: 30, Average Loss: 0.8028
Epoch: 31, Average Loss: 0.7970
Epoch: 32, Averag



Epoch-100: -0.25071853981386916
Starting transformer_01_O1
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.1097
Epoch: 2, Average Loss: 1.0104
Epoch: 3, Average Loss: 0.9739
Epoch: 4, Average Loss: 0.9533
Epoch: 5, Average Loss: 0.9399
Epoch: 6, Average Loss: 0.9302
Epoch: 7, Average Loss: 0.9230
Epoch: 8, Average Loss: 0.9172
Epoch: 9, Average Loss: 0.9125
Epoch: 10, Average Loss: 0.9080
Epoch: 11, Average Loss: 0.9041
Epoch: 12, Average Loss: 0.9004
Epoch: 13, Average Loss: 0.8969
Epoch: 14, Average Loss: 0.8935
Epoch: 15, Average Loss: 0.8903
Epoch: 16, Average Loss: 0.8869
Epoch: 17, Average Loss: 0.8834
Epoch: 18, Average Loss: 0.8798
Epoch: 19, Average Loss: 0.8762
Epoch: 20, Average Loss: 0.8724
Epoch: 21, Average Loss: 0.8681
Epoch: 22, Average Loss: 0.8639
Epoch: 23, Average Loss: 0.8597
Epoch: 24, Average Loss: 0.8549
Epoch: 25, Average Loss: 0.8504
Epoch: 26, Average Loss: 0.8452
Epoch: 27, Average Loss: 0.8398
Epoch: 28, Average Loss: 0.8346
Epoch: 29, Average Loss: 0.8289
Epoch: 30, Average Loss: 0.8218
Epoch: 31, Average Loss: 0.8162
Epoch: 32, Averag



Epoch-100: -0.29505702573490944
Starting transformer_01_FP2
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.3842
Epoch: 2, Average Loss: 1.2884
Epoch: 3, Average Loss: 1.2537
Epoch: 4, Average Loss: 1.2340
Epoch: 5, Average Loss: 1.2216
Epoch: 6, Average Loss: 1.2129
Epoch: 7, Average Loss: 1.2066
Epoch: 8, Average Loss: 1.2019
Epoch: 9, Average Loss: 1.1982
Epoch: 10, Average Loss: 1.1950
Epoch: 11, Average Loss: 1.1927
Epoch: 12, Average Loss: 1.1904
Epoch: 13, Average Loss: 1.1885
Epoch: 14, Average Loss: 1.1865
Epoch: 15, Average Loss: 1.1847
Epoch: 16, Average Loss: 1.1833
Epoch: 17, Average Loss: 1.1817
Epoch: 18, Average Loss: 1.1801
Epoch: 19, Average Loss: 1.1787
Epoch: 20, Average Loss: 1.1773
Epoch: 21, Average Loss: 1.1758
Epoch: 22, Average Loss: 1.1743
Epoch: 23, Average Loss: 1.1726
Epoch: 24, Average Loss: 1.1708
Epoch: 25, Average Loss: 1.1695
Epoch: 26, Average Loss: 1.1673
Epoch: 27, Average Loss: 1.1656
Epoch: 28, Average Loss: 1.1638
Epoch: 29, Average Loss: 1.1616
Epoch: 30, Average Loss: 1.1591
Epoch: 31, Average Loss: 1.1575
Epoch: 32, Averag



Epoch-100: -0.10078590852558822
Starting transformer_01_AFF6h
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.3010
Epoch: 2, Average Loss: 1.2064
Epoch: 3, Average Loss: 1.1715
Epoch: 4, Average Loss: 1.1515
Epoch: 5, Average Loss: 1.1389
Epoch: 6, Average Loss: 1.1302
Epoch: 7, Average Loss: 1.1240
Epoch: 8, Average Loss: 1.1193
Epoch: 9, Average Loss: 1.1152
Epoch: 10, Average Loss: 1.1120
Epoch: 11, Average Loss: 1.1092
Epoch: 12, Average Loss: 1.1068
Epoch: 13, Average Loss: 1.1047
Epoch: 14, Average Loss: 1.1027
Epoch: 15, Average Loss: 1.1006
Epoch: 16, Average Loss: 1.0991
Epoch: 17, Average Loss: 1.0969
Epoch: 18, Average Loss: 1.0953
Epoch: 19, Average Loss: 1.0934
Epoch: 20, Average Loss: 1.0913
Epoch: 21, Average Loss: 1.0899
Epoch: 22, Average Loss: 1.0883
Epoch: 23, Average Loss: 1.0861
Epoch: 24, Average Loss: 1.0841
Epoch: 25, Average Loss: 1.0820
Epoch: 26, Average Loss: 1.0798
Epoch: 27, Average Loss: 1.0776
Epoch: 28, Average Loss: 1.0756
Epoch: 29, Average Loss: 1.0734
Epoch: 30, Average Loss: 1.0714
Epoch: 31, Average Loss: 1.0683
Epoch: 32, Averag



Epoch-100: -0.14368634774012978
Starting transformer_01_F2
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.3103
Epoch: 2, Average Loss: 1.2120
Epoch: 3, Average Loss: 1.1760
Epoch: 4, Average Loss: 1.1559
Epoch: 5, Average Loss: 1.1432
Epoch: 6, Average Loss: 1.1340
Epoch: 7, Average Loss: 1.1275
Epoch: 8, Average Loss: 1.1223
Epoch: 9, Average Loss: 1.1181
Epoch: 10, Average Loss: 1.1143
Epoch: 11, Average Loss: 1.1114
Epoch: 12, Average Loss: 1.1085
Epoch: 13, Average Loss: 1.1060
Epoch: 14, Average Loss: 1.1032
Epoch: 15, Average Loss: 1.1010
Epoch: 16, Average Loss: 1.0988
Epoch: 17, Average Loss: 1.0965
Epoch: 18, Average Loss: 1.0944
Epoch: 19, Average Loss: 1.0922
Epoch: 20, Average Loss: 1.0900
Epoch: 21, Average Loss: 1.0880
Epoch: 22, Average Loss: 1.0852
Epoch: 23, Average Loss: 1.0826
Epoch: 24, Average Loss: 1.0807
Epoch: 25, Average Loss: 1.0777
Epoch: 26, Average Loss: 1.0754
Epoch: 27, Average Loss: 1.0730
Epoch: 28, Average Loss: 1.0703
Epoch: 29, Average Loss: 1.0681
Epoch: 30, Average Loss: 1.0649
Epoch: 31, Average Loss: 1.0619
Epoch: 32, Averag



Epoch-100: -0.12860007473334978
Starting transformer_01_FC2
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.2237
Epoch: 2, Average Loss: 1.1261
Epoch: 3, Average Loss: 1.0889
Epoch: 4, Average Loss: 1.0677
Epoch: 5, Average Loss: 1.0541
Epoch: 6, Average Loss: 1.0441
Epoch: 7, Average Loss: 1.0371
Epoch: 8, Average Loss: 1.0310
Epoch: 9, Average Loss: 1.0260
Epoch: 10, Average Loss: 1.0220
Epoch: 11, Average Loss: 1.0179
Epoch: 12, Average Loss: 1.0142
Epoch: 13, Average Loss: 1.0109
Epoch: 14, Average Loss: 1.0076
Epoch: 15, Average Loss: 1.0045
Epoch: 16, Average Loss: 1.0016
Epoch: 17, Average Loss: 0.9987
Epoch: 18, Average Loss: 0.9958
Epoch: 19, Average Loss: 0.9924
Epoch: 20, Average Loss: 0.9892
Epoch: 21, Average Loss: 0.9867
Epoch: 22, Average Loss: 0.9844
Epoch: 23, Average Loss: 0.9800
Epoch: 24, Average Loss: 0.9764
Epoch: 25, Average Loss: 0.9733
Epoch: 26, Average Loss: 0.9698
Epoch: 27, Average Loss: 0.9667
Epoch: 28, Average Loss: 0.9619
Epoch: 29, Average Loss: 0.9576
Epoch: 30, Average Loss: 0.9539
Epoch: 31, Average Loss: 0.9497
Epoch: 32, Averag



Epoch-100: -0.17092877526117212
Starting transformer_01_FC6
Skipping 01
EEG Shape: (25000, 200, 1)
NIRS Shape: (25000, 4000, 16)
torch.Size([25000, 4000, 16])
torch.Size([25000, 200, 1])




Epoch: 1, Average Loss: 1.2028
