In [3]:
import torch
import wandb, os, json
import time

from muon import Muon
from model_model import TransformerModel
from model_electrode_embedding import ElectrodeEmbedding_Learned, ElectrodeEmbedding_NoisyCoordinate, ElectrodeEmbedding_Learned_CoordinateInit, ElectrodeDataEmbeddingFFT, ElectrodeDataEmbedding

from dataset import load_dataloaders, load_subjects
from evaluation_btbench import FrozenModelEvaluation_SS_SM
from train_utils import log, update_dir_name, update_random_seed, convert_dtypes, parse_configs_from_args, get_default_configs, get_shared_memory_info

# --- Configuration -----------------------------------------------------------
# Start from the project defaults, passing an empty wandb project name.
training_config, model_config, cluster_config = get_default_configs(
    random_string="TEMP",
    wandb_project="",
)
cluster_config['cache_subjects'] = False

# update_dir_name returns the run directory name; the next lines also read
# cluster_config['dir_name'], so it is presumably filled in by that call
# (TODO confirm in train_utils).
dir_name = update_dir_name(model_config, training_config, cluster_config)
update_random_seed(training_config)
cluster_config['wandb_name'] = cluster_config['dir_name']
log(f"Directory name: {dir_name}", priority=0)

# With an empty project name the name `wandb` is rebound from the imported
# module to False, so the module is no longer reachable under that name.
if len(cluster_config['wandb_project']) == 0:
    wandb = False

# --- Device ------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(f"Using device: {device}", priority=0)

# --- Subjects ----------------------------------------------------------------
# One call loads the subjects for both the train and the eval trial lists.
log("Loading subjects...", priority=0)
all_subjects = load_subjects(
    training_config['train_subject_trials'],
    training_config['eval_subject_trials'],
    training_config['data_dtype'],
    cache=cluster_config['cache_subjects'],
    allow_corrupted=False,
)

[13:53:16 gpu 0.0G ram 0.6G] (0) Directory name: M_nst8_dm192_nh12_nl5_5_nCS_eaM_eeL_fb1_cls_lr0.003_rTEMP
[13:53:16 gpu 0.0G ram 0.6G] (0) Using device: cpu
[13:53:16 gpu 0.0G ram 0.6G] (0) Loading subjects...
[13:53:16 gpu 0.0G ram 0.6G] (1)     loading subject btbank2...
[13:53:16 gpu 0.0G ram 0.6G] (1)     loading subject btbank7...
[13:53:16 gpu 0.0G ram 0.6G] (1)     loading subject btbank1...
[13:53:16 gpu 0.0G ram 0.6G] (1)     loading subject btbank3...
[13:53:16 gpu 0.0G ram 0.6G] (1)     loading subject btbank10...


In [4]:

# --- Model -------------------------------------------------------------------
log("Loading model...", priority=0)

# Hoist the nested config lookups that are repeated throughout this cell.
transformer_config = model_config['transformer']
embedding_config = model_config['electrode_embedding']
model_dtype = model_config['dtype']

model = TransformerModel(
    transformer_config['d_model'],
    n_layers_electrode=transformer_config['n_layers_electrode'],
    n_layers_time=transformer_config['n_layers_time'],
    n_heads=transformer_config['n_heads'],
    use_cls_token=transformer_config['use_cls_token'],
).to(device, dtype=model_dtype)

# --- Electrode embeddings ----------------------------------------------------
# 'zero' reuses the learned-embedding class with gradient updates disabled.
embedding_type = embedding_config['type']
if embedding_type in ('learned', 'zero'):
    electrode_embeddings = ElectrodeEmbedding_Learned(
        transformer_config['d_model'],
        embedding_dim=embedding_config['embedding_dim'],
        embedding_requires_grad=embedding_type != 'zero',
    )
elif embedding_type == 'coordinate_init':
    electrode_embeddings = ElectrodeEmbedding_Learned_CoordinateInit(
        transformer_config['d_model'],
        embedding_dim=embedding_config['embedding_dim'],
    )
elif embedding_type == 'noisy_coordinate':
    electrode_embeddings = ElectrodeEmbedding_NoisyCoordinate(
        transformer_config['d_model'],
        coordinate_noise_std=embedding_config['coordinate_noise_std'],
        embedding_dim=embedding_config['embedding_dim'],
    )
else:
    raise ValueError(f"Invalid electrode embedding type: {model_config['electrode_embedding']['type']}")
electrode_embeddings = electrode_embeddings.to(device, dtype=model_dtype)

# --- Electrode data embeddings -----------------------------------------------
if embedding_config['spectrogram']:
    electrode_data_embeddings = ElectrodeDataEmbeddingFFT(
        electrode_embeddings, model_config['sample_timebin_size'],
        max_frequency_bin=model_config['max_frequency_bin'],
    ).to(device, dtype=model_dtype)
else:
    electrode_data_embeddings = ElectrodeDataEmbedding(
        electrode_embeddings, model_config['sample_timebin_size'],
        overall_sampling_rate=next(iter(all_subjects.values())).get_sampling_rate(0) # XXX remove this once figured out how to be flexible here regarding the sampling rate
    ).to(device, dtype=model_dtype)

# --- Register subjects and initialize normalization --------------------------
for subject in all_subjects.values():
    # Log before doing the work so the message and its timestamp describe what is about to happen.
    log(f"Adding subject {subject.subject_identifier} to electrode data embeddings...", priority=0)
    this_subject_trials = [trial_id for (sub_id, trial_id) in training_config['train_subject_trials'] if sub_id == subject.subject_identifier]

    # all_subjects is loaded from both the train and the eval trial lists, so a
    # subject may have no training trial at all. Fall back to its eval trials for
    # the sampling-rate lookup instead of failing with a bare IndexError.
    # NOTE(review): assumes eval_subject_trials holds (subject_id, trial_id)
    # pairs like train_subject_trials does - confirm.
    sampling_rate_trials = this_subject_trials or [
        trial_id for (sub_id, trial_id) in training_config['eval_subject_trials'] if sub_id == subject.subject_identifier
    ]
    if not sampling_rate_trials:
        raise ValueError(f"Subject {subject.subject_identifier} has no train or eval trials to read a sampling rate from")
    electrode_data_embeddings.add_subject(subject, subject.get_sampling_rate(sampling_rate_trials[0]))

    # Normalization statistics are only initialized from the training trials.
    if model_config['init_normalization']:
        for trial_id in this_subject_trials:
            log(f"Initializing normalization for subject {subject.subject_identifier} trial {trial_id}...", priority=1, indent=1)
            # The window length is five minutes' worth of samples at the trial's sampling rate.
            electrode_data_embeddings.initialize_normalization(subject, trial_id, init_normalization_window_to=int(subject.get_sampling_rate(trial_id) * 60 * 5))
electrode_data_embeddings = electrode_data_embeddings.to(device, dtype=model_dtype) # moving to device again to ensure the new parameters are on the correct device

# --- Dataloaders -------------------------------------------------------------
log("Loading dataloaders...", priority=0)
# Not used in this cell; kept as a notebook-level name for later cells.
n_samples = model_config['max_n_timebins'] * model_config['sample_timebin_size']
train_dataloader, test_dataloader = load_dataloaders(
    all_subjects, training_config['train_subject_trials'], training_config['p_test'],
    model_config['sample_timebin_size'], model_config['max_n_timebins'], training_config['data_dtype'],
    training_config['batch_size'],
    num_workers_dataloaders=cluster_config['num_workers_dataloaders'],
    prefetch_factor=cluster_config['prefetch_factor'],
    max_n_electrodes=model_config['max_n_electrodes'],
    output_embeddings_map=electrode_embeddings.embeddings_map,
)

[13:53:18 gpu 0.0G ram 0.6G] (0) Loading model...
[13:53:34 gpu 0.0G ram 0.6G] (0) Adding subject btbank2 to electrode data embeddings...
[13:53:34 gpu 0.0G ram 0.6G] (1)     Initializing normalization for subject btbank2 trial 4...
[13:53:39 gpu 0.0G ram 0.6G] (1)     Initializing normalization for subject btbank2 trial 5...
[13:53:44 gpu 0.0G ram 0.6G] (0) Adding subject btbank7 to electrode data embeddings...
[13:53:44 gpu 0.0G ram 0.6G] (1)     Initializing normalization for subject btbank7 trial 1...
[13:53:52 gpu 0.0G ram 0.7G] (0) Adding subject btbank1 to electrode data embeddings...
[13:53:52 gpu 0.0G ram 0.7G] (1)     Initializing normalization for subject btbank1 trial 0...
[13:53:57 gpu 0.0G ram 0.9G] (1)     Initializing normalization for subject btbank1 trial 1...
[13:54:02 gpu 0.0G ram 1.0G] (0) Adding subject btbank3 to electrode data embeddings...
[13:54:02 gpu 0.0G ram 1.0G] (1)     Initializing normalization for subject btbank3 trial 1...
[13:54:06 gpu 0.0G ram 1.1G]



In [5]:
# Pull a single batch from the training dataloader for interactive inspection.
batch = next(iter(train_dataloader))

# Move the tensors to the compute device: the signal data in the model dtype,
# the electrode indices as int64 (torch.long).
batch['data'] = batch['data'].to(device, dtype=model_config['dtype'], non_blocking=True)
batch['electrode_index'] = batch['electrode_index'].to(device, dtype=torch.long, non_blocking=True)
# Subject identifier and trial id of the first sample in the batch.
subject_identifier, trial_id = batch['subject_trial'][0]

In [6]:
# Shape of the batch's data tensor.
# NOTE(review): presumably (batch, electrodes, samples) - confirm against the dataset code.
batch['data'].shape

torch.Size([100, 128, 6144])

In [8]:
# Electrode indices attached to the first sample of the batch.
batch['electrode_index'][0]

tensor([672, 710, 724, 809, 701, 761, 805, 793, 740, 667, 636, 733, 771, 829,
        706, 644, 673, 714, 641, 631, 806, 807, 702, 817, 682, 658, 777, 779,
        633, 790, 830, 758, 808, 781, 690, 785, 734, 699, 798, 721, 797, 812,
        737, 727, 749, 767, 801, 694, 656, 648, 752, 831, 787, 778, 800, 736,
        768, 735, 668, 671, 731, 713, 717, 833, 684, 662, 814, 741, 784, 747,
        663, 834, 738, 687, 666, 757, 715, 646, 744, 630, 760, 770, 795, 745,
        653, 764, 816, 723, 729, 722, 742, 700, 774, 775, 649, 720, 712, 728,
        716, 827, 676, 750, 635, 654, 730, 832, 823, 765, 743, 692, 661, 659,
        705, 677, 810, 647, 825, 645, 788, 681, 796, 769, 776, 640, 664, 783,
        639, 804])

In [14]:
# Second sample of the batch: its electrode indices together with its 'subject_trial' entry.
batch['electrode_index'][1], batch['subject_trial'][1]

(tensor([672, 710, 724, 809, 701, 761, 805, 793, 740, 667, 636, 733, 771, 829,
         706, 644, 673, 714, 641, 631, 806, 807, 702, 817, 682, 658, 777, 779,
         633, 790, 830, 758, 808, 781, 690, 785, 734, 699, 798, 721, 797, 812,
         737, 727, 749, 767, 801, 694, 656, 648, 752, 831, 787, 778, 800, 736,
         768, 735, 668, 671, 731, 713, 717, 833, 684, 662, 814, 741, 784, 747,
         663, 834, 738, 687, 666, 757, 715, 646, 744, 630, 760, 770, 795, 745,
         653, 764, 816, 723, 729, 722, 742, 700, 774, 775, 649, 720, 712, 728,
         716, 827, 676, 750, 635, 654, 730, 832, 823, 765, 743, 692, 661, 659,
         705, 677, 810, 647, 825, 645, 788, 681, 796, 769, 776, 640, 664, 783,
         639, 804]),
 ('btbank10', 1))

In [23]:
# Standard deviation over all stored normalization stds; a result of 0 means every entry is identical.
electrode_data_embeddings.normalization_stds.std()

tensor(0., dtype=torch.bfloat16, grad_fn=<StdBackward0>)

In [25]:
def initialize_normalization(self, subject, session_id, init_normalization_window_to=2048 * 60 * 10):
    """Fill the per-electrode normalization statistics for one subject session.

    Free-function form taking the embedding object explicitly as ``self``.
    The electrodes of ``session_id`` are looked up in
    ``self.electrode_embedding_class.embeddings_map`` under the key
    ``(subject.subject_identifier, electrode_label)``, and the matching entries
    of ``self.normalization_means`` / ``self.normalization_stds`` are
    overwritten with the values returned by
    ``self.calculate_electrode_normalization``.

    Args:
        self: Object providing ``electrode_embedding_class``, ``device``,
            ``dtype``, ``sampling_rates``, ``calculate_electrode_normalization``,
            ``normalization_means`` and ``normalization_stds``.
        subject: Object providing ``subject_identifier``,
            ``get_electrode_indices``, ``get_electrode_labels`` and
            ``get_all_electrode_data``.
        session_id: Session/trial identifier passed through to ``subject``.
        init_normalization_window_to: Forwarded as ``window_to`` to
            ``subject.get_all_electrode_data``. The default is presumably ten
            minutes of samples at 2048 Hz (TODO confirm).

    Returns:
        None. The statistics are written in place.

    Raises:
        KeyError: If an electrode of the session is missing from the
            embeddings map.
    """
    subject_identifier = subject.subject_identifier
    indices = subject.get_electrode_indices(session_id)
    all_labels = subject.get_electrode_labels()
    session_labels = [all_labels[i] for i in indices]
    embeddings_map = self.electrode_embedding_class.embeddings_map
    electrode_indices = [embeddings_map[(subject_identifier, label)] for label in session_labels]

    # Nothing to normalize: return before loading any data and before the
    # `electrode_indices[0]` lookup below, which would raise IndexError.
    if not electrode_indices:
        return

    with torch.no_grad():
        all_electrode_data = subject.get_all_electrode_data(session_id, window_to=init_normalization_window_to).to(self.device, dtype=self.dtype)
        # XXX: Assuming all electrodes have the same sampling rate
        electrode_means, electrode_stds = self.calculate_electrode_normalization(all_electrode_data, self.sampling_rates[electrode_indices[0]])

        # Entry `row` of the returned statistics is stored under the row-th electrode index.
        for row, electrode_index in enumerate(electrode_indices):
            self.normalization_means.data[electrode_index] = electrode_means[row]
            self.normalization_stds.data[electrode_index] = electrode_stds[row]
# Run the notebook-local initialize_normalization on the live embedding object,
# using the subject and trial taken from the first sample of the inspected batch.
initialize_normalization(electrode_data_embeddings, all_subjects[subject_identifier], trial_id)

In [29]:
# Display the stored normalization means.
# NOTE(review): rows appear to be indexed by global electrode index - confirm.
electrode_data_embeddings.normalization_means

Parameter containing:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [9.0625, 7.8125, 7.5938,  ..., 3.4219, 3.4219, 3.4375],
        [8.5000, 7.7188, 7.4688,  ..., 3.3438, 3.3438, 3.3594],
        [8.9375, 7.9062, 7.5938,  ..., 3.4688, 3.4844, 3.4844]],
       dtype=torch.bfloat16, requires_grad=True)

In [26]:
# Run the embedding's preprocessing on the batch data to inspect its raw output.
# NOTE(review): 2048 is presumably the sampling rate in Hz - confirm against preprocess_electrode_data.
electrode_data_embeddings.preprocess_electrode_data(batch['data'], 2048)

tensor([[[[ 9.2500,  8.4375,  7.0000,  ...,  4.5625,  4.3125,  4.5625],
          [ 9.4375,  7.4375,  7.6562,  ...,  4.3750,  4.7500,  4.5625],
          [ 9.3750,  9.4375,  8.0000,  ...,  5.0000,  4.9375,  5.0312],
          ...,
          [ 9.3750,  8.5000,  7.5625,  ...,  4.0938,  4.9062,  4.8125],
          [ 8.7500,  7.8750,  6.8125,  ...,  4.7188,  4.4688,  4.4062],
          [ 9.6250,  8.3750,  7.7188,  ...,  4.0938,  4.0312,  4.0938]],

         [[ 9.6875,  8.4375,  8.1250,  ...,  4.4375,  4.7188,  4.6562],
          [ 7.9375,  8.3125,  7.9688,  ...,  4.3750,  4.6875,  4.5625],
          [ 9.0000,  9.0625,  7.7188,  ...,  4.9062,  4.6875,  5.0938],
          ...,
          [ 7.6562,  7.3125,  8.4375,  ...,  4.1562,  4.6875,  4.8125],
          [ 9.0625,  7.7812,  7.5000,  ...,  4.4062,  4.2500,  4.5938],
          [ 9.9375,  8.1875,  6.8438,  ...,  4.1875,  4.1562,  4.5312]],

         [[ 6.6250,  8.7500,  7.5312,  ...,  4.5000,  4.1250,  4.2500],
          [ 9.3125,  8.4375,  

In [13]:
# Show the mapping from (subject identifier, electrode label) to embedding index.
# NOTE: this displays the entire dictionary, which makes for a very long cell output.
electrode_embeddings.embeddings_map

{('btbank2', 'LT3a1'): 0,
 ('btbank2', 'LT3a2'): 1,
 ('btbank2', 'LT3a3'): 2,
 ('btbank2', 'LT3a4'): 3,
 ('btbank2', 'LT3a5'): 4,
 ('btbank2', 'LT3a6'): 5,
 ('btbank2', 'LT3a7'): 6,
 ('btbank2', 'LT3a8'): 7,
 ('btbank2', 'LT3a9'): 8,
 ('btbank2', 'LT3a10'): 9,
 ('btbank2', 'LT2aA1'): 10,
 ('btbank2', 'LT2aA2'): 11,
 ('btbank2', 'LT2aA4'): 12,
 ('btbank2', 'LT2aA5'): 13,
 ('btbank2', 'LT2aA6'): 14,
 ('btbank2', 'LT2aA7'): 15,
 ('btbank2', 'LT2aA8'): 16,
 ('btbank2', 'LT2aA9'): 17,
 ('btbank2', 'LT2aA10'): 18,
 ('btbank2', 'LT2aA11'): 19,
 ('btbank2', 'LT2aA12'): 20,
 ('btbank2', 'LT2aA13'): 21,
 ('btbank2', 'LT2aA14'): 22,
 ('btbank2', 'LT3bHa4'): 23,
 ('btbank2', 'LT3bHa5'): 24,
 ('btbank2', 'LT3bHa6'): 25,
 ('btbank2', 'LT3bHa7'): 26,
 ('btbank2', 'LT3bHa8'): 27,
 ('btbank2', 'LT3bHa9'): 28,
 ('btbank2', 'LT3bHa10'): 29,
 ('btbank2', 'LT3bHa11'): 30,
 ('btbank2', 'LT3bHa12'): 31,
 ('btbank2', 'LT3bHa14'): 32,
 ('btbank2', 'LT1bIb1'): 33,
 ('btbank2', 'LT1bIb2'): 34,
 ('btbank2', 'LT1b