## 1. Loading the config and the subject data

In [1]:
import torch
import wandb, os, json
import time

from muon import Muon
from model_model import TransformerModel
from model_electrode_embedding import ElectrodeEmbedding_Learned, ElectrodeEmbedding_Learned_FixedVocabulary
from model_electrode_embedding import ElectrodeDataEmbeddingFFT, ElectrodeDataEmbedding

from dataset import load_dataloaders, load_subjects
from evaluation_btbench import FrozenModelEvaluation_SS_SM
from train_utils import log, update_dir_name, update_random_seed, convert_dtypes, parse_configs_from_args, get_default_configs, get_shared_memory_info

training_config, model_config, cluster_config = get_default_configs(random_string="TEMP", wandb_project="")

# Interactive overrides: the values below are hardcoded for notebook use.
# A script entry point would instead fill them from the command line via
#     parse_configs_from_args(training_config, model_config, cluster_config)
# Note: wandb_project="" above means W&B gets disabled further down.
training_config['train_subject_trials'] = [('mgh1', 3)]  # alternative: [('mgh1', 3), ('mgh1', 2)] — presumably (subject, trial) pairs
training_config['eval_subject_trials'] = []
cluster_config['eval_model_every_n_epochs'] = 3
cluster_config['cache_subjects'] = True
model_config['name'] = 'EEG-IEEG'

# Scalp-EEG channel labels. The names follow standard 10-20 nomenclature;
# region labels below are per that convention — TODO confirm against the
# actual recording montage. Order is significant and must not change.
EEG_channels = [
    'Fp1', 'Fp2',              # frontopolar
    'F3', 'Fz', 'F4',          # frontal
    'C3', 'Cz', 'C4',          # central
    'P3', 'Pz', 'P4',          # parietal
    'O1', 'O2',                # occipital
    'F7', 'F8',                # lateral frontal
    'T3', 'T4', 'T5', 'T6',    # temporal
]

# Derive the run directory name from the three configs, seed the RNGs, and
# reuse the directory name as the W&B run name.
# NOTE(review): line 3 below reads cluster_config['dir_name'], which this
# notebook never assigns directly — presumably update_dir_name() stores it as a
# side effect (and returns the same value as dir_name). Confirm; this ordering
# dependency breaks silently if the calls are reshuffled.
dir_name = update_dir_name(model_config, training_config, cluster_config)
update_random_seed(training_config)  # presumably seeds from a value in training_config — TODO confirm which RNGs
cluster_config['wandb_name'] = cluster_config['dir_name']
log(f"Directory name: {dir_name}", priority=0)

# An empty wandb_project disables Weights & Biases logging.
# NOTE(review): this rebinds the imported `wandb` module name to False, so any
# later `wandb.<attr>` access raises; downstream cells presumably gate on
# `if wandb:`. A dedicated flag (e.g. use_wandb) would be clearer — kept as-is
# so later cells keep working.
if len(cluster_config['wandb_project']) == 0:
    wandb = False

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
log(f"Using device: {device}", priority=0)

log("Loading subjects...", priority=0)
# Load the subjects referenced by the train and eval trial lists.
# allow_corrupted=False presumably rejects corrupted recordings — confirm in load_subjects.
all_subjects = load_subjects(
    training_config['train_subject_trials'],
    training_config['eval_subject_trials'],
    training_config['data_dtype'],
    cache=cluster_config['cache_subjects'],
    allow_corrupted=False,
)


[12:48:33 gpu 0.0G ram 0.5G] (0) Directory name: EEG-IEEG_nst1_dm192_nh12_nl5_5_eaM_eeL_fb1_cls_lr0.003_rTEMP
[12:48:33 gpu 0.0G ram 0.5G] (0) Using device: cuda
[12:48:33 gpu 0.0G ram 0.5G] (0) Loading subjects...
[12:48:33 gpu 0.0G ram 0.5G] (1)     loading subject mgh1...


In [2]:
# Build train/test loaders over the training trials only; no output embedding
# map is supplied.
train_dataloader, test_dataloader = load_dataloaders(
    all_subjects,
    training_config['train_subject_trials'],
    training_config['p_test'],                 # presumably the held-out test fraction
    model_config['sample_timebin_size'],
    model_config['max_n_timebins'],
    training_config['data_dtype'],
    training_config['batch_size'],
    num_workers_dataloaders=cluster_config['num_workers_dataloaders'],
    prefetch_factor=cluster_config['prefetch_factor'],
    max_n_electrodes=model_config['max_n_electrodes'],
    output_embeddings_map=None,
)

[12:48:33 gpu 0.0G ram 0.5G] (1)     loading dataset for mgh1_3...
[12:49:11 gpu 0.0G ram 7.1G] (1)     finished loading dataset for mgh1_3




In [None]:
# Peek at the first training batch.
# train_dataloader presumably is a torch.utils.data.DataLoader (built by
# load_dataloaders), which supports iteration but not indexing, so
# `train_dataloader[0]` raises TypeError. next(iter(...)) works for a
# DataLoader, a list, or any object exposing only __getitem__.
batch_i, batch = 0, next(iter(train_dataloader))
batch

## 2. Loading the model

In [3]:
# Build the electrode embedding modules. Only 'learned' and 'zero' are
# supported for now: there are no electrode coordinates for MGH data.
electrode_embedding_type = model_config['electrode_embedding']['type']
assert electrode_embedding_type in ['learned', 'zero'], (
    f"Unsupported electrode_embedding type {electrode_embedding_type!r}; "
    "only 'learned' and 'zero' are supported (no electrode coordinates for MGH data)."
)

# Both embeddings share width and trainability. With 'zero', the embedding is
# not learned and stays fixed at zero.
electrode_embedding_kwargs = dict(
    embedding_dim=model_config['electrode_embedding']['embedding_dim'],
    embedding_requires_grad=electrode_embedding_type != 'zero',
)

# Learned per-electrode embeddings (presumably for the iEEG electrodes).
electrode_embeddings = ElectrodeEmbedding_Learned(
    model_config['transformer']['d_model'],
    **electrode_embedding_kwargs,
)

# Learned embeddings over a fixed vocabulary, used for the EEG channels.
# NOTE(review): EEG_channels (defined in the first cell) is not passed here —
# confirm whether the fixed vocabulary is meant to be built from it.
electrode_embeddings_eeg = ElectrodeEmbedding_Learned_FixedVocabulary(
    model_config['transformer']['d_model'],
    **electrode_embedding_kwargs,
)
