In [5]:
import torch
import wandb, os, json
import time
import numpy as np
from torch.amp import autocast
import gc

from utils.muon_optimizer import Muon
from subject.dataset import load_subjects
from evaluation.neuroprobe_tasks import FrozenModelEvaluation_SS_SM
from training_setup.training_config import log, update_dir_name, update_random_seed, parse_config_from_args, get_default_config, parse_subject_trials_from_config, convert_dtypes
from torch.optim.lr_scheduler import ChainedScheduler
from training_setup.training_config import convert_dtypes, unconvert_dtypes, parse_subject_trials_from_config
from torch.utils.data import DataLoader

from evaluation.neuroprobe.datasets import BrainTreebankSubjectTrialBenchmarkDataset
import evaluation.neuroprobe.config as neuroprobe_config

### PARSE MODEL DIR ###

model_dir = "andrii0_lr0.003_wd0.0_dr0.1_rR1_t20250714_121055"
batch_size = 100
model_epoch = 0

### LOAD CONFIG ###

# Load the checkpoint
if model_epoch < 0: model_epoch = "final"
checkpoint_path = os.path.join("runs/data", model_dir, f"model_epoch_{model_epoch}.pth")
checkpoint = torch.load(checkpoint_path)
config = unconvert_dtypes(checkpoint['config'])
log(f"Directory name: {model_dir}", priority=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config['device'] = device
log(f"Using device: {device}", priority=0)

config['training']['train_subject_trials'] = []

### LOAD SUBJECTS ###

log(f"Loading subjects...", priority=0)
# all_subjects is a dictionary of subjects, with the subject identifier as the key and the subject object as the value
all_subjects = load_subjects(config['training']['train_subject_trials'], 
                             config['training']['eval_subject_trials'], config['training']['data_dtype'], 
                             cache=config['cluster']['cache_subjects'], allow_corrupted=False)
for subject_identifier, trial_id in config['training']['eval_subject_trials']:
    log(f"Loading subject {subject_identifier} trial {trial_id}", priority=0)
    subject = all_subjects[subject_identifier]
    subject.load_neural_data(trial_id)

### LOAD MODEL ###

# Import the training setup class dynamically based on config
try:
    setup_module = __import__(f'training_setup.{config["training"]["setup_name"].lower()}', fromlist=[config["training"]["setup_name"]])
    setup_class = getattr(setup_module, config["training"]["setup_name"])
    training_setup = setup_class(all_subjects, config, verbose=True)
except (ImportError, AttributeError) as e:
    print(f"Could not load training setup '{config['training']['setup_name']}'. Are you sure the filename and the class name are the same and correspond to the parameter? Error: {str(e)}")
    exit()

[22:26:27 gpu 0.1G ram 0.7G] Directory name: andrii0_lr0.003_wd0.0_dr0.1_rR1_t20250714_121055
[22:26:27 gpu 0.1G ram 0.7G] Using device: cuda
[22:26:27 gpu 0.1G ram 0.7G] Loading subjects...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank4...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank2...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank7...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank3...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank1...
[22:26:27 gpu 0.1G ram 0.7G]     loading subject btbank10...
[22:26:27 gpu 0.1G ram 0.7G] Loading subject btbank1 trial 1
[22:27:25 gpu 0.1G ram 5.9G] Loading subject btbank1 trial 2
[22:28:09 gpu 0.1G ram 9.6G] Loading subject btbank2 trial 0
[22:29:06 gpu 0.1G ram 14.5G] Loading subject btbank2 trial 4
[22:30:28 gpu 0.1G ram 21.4G] Loading subject btbank3 trial 0
[22:31:08 gpu 0.1G ram 24.6G] Loading subject btbank3 trial 1
[22:32:09 gpu 0.1G ram 29.6G] Loading subject btbank4 trial 0
[22:33:10 gpu 0.1

In [10]:
log(f"Loading model...", priority=0)
training_setup.initialize_model()

log(f"Loading the evaluation class...", priority=0)
evaluation = FrozenModelEvaluation_SS_SM(
    # model evaluation function
    model_preprocess_functions=training_setup.get_preprocess_functions(pretraining=False),
    model_evaluation_function=training_setup.generate_frozen_features,
    # benchmark parameters 
    eval_names=neuroprobe_config.NEUROPROBE_TASKS, lite=True,
    subject_trials=[(all_subjects[subject_identifier], trial_id) for subject_identifier, trial_id in config['training']['eval_subject_trials']],
    # dataloader parameters
    device=device,
    dtype=config['training']['data_dtype'],
    batch_size=config['training']['batch_size'],
    num_workers_eval=config['cluster']['num_workers_eval'],
    prefetch_factor=config['cluster']['prefetch_factor'],
)

[22:44:13 gpu 0.1G ram 74.5G] Loading model...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank4 to electrode embeddings...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank2 to electrode embeddings...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank7 to electrode embeddings...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank3 to electrode embeddings...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank1 to electrode embeddings...
[22:44:25 gpu 0.1G ram 74.5G] Adding subject btbank10 to electrode embeddings...


1. Evaluating the epoch 0 model, while taking only the cls token for every timebin

In [30]:
from evaluation.neuroprobe.datasets import BrainTreebankSubjectTrialBenchmarkDataset

def generate_frozen_features(batch):
    # INPUT:
    #   batch['data'] shape: (batch_size, n_electrodes, n_timesamples)
    #   batch['electrode_labels'] shape: list of length 1 (since it's the same across the batch), each element is a list of electrode labels
    #   batch['metadata']: dictionary containing metadata like the subject identifier and trial id, sampling rate, etc.
    # OUTPUT:
    #   features shape: (batch_size, *) where * can be arbitrary (and will be concatenated for regression)
    self = training_setup

    batch['data'] = batch['data'].to(self.model.device, dtype=self.model.dtype, non_blocking=True)
    batch['electrode_index'] = batch['electrode_index'].to(self.model.device, non_blocking=True)

    embeddings = self.electrode_embeddings(batch)
    features = self.model(batch, embeddings, electrode_transformer_only=True) # shape: (batch_size, n_electrodes + 1, n_timebins, d_model)
    features = features[:, 0:1, :, :] # shape: (batch_size, 1, n_timebins, d_model) -- take just the cls token

    if self.config['cluster']['eval_aggregation_method'] == 'mean':
        features = features.mean(dim=[1, 2])
    elif self.config['cluster']['eval_aggregation_method'] == 'concat':
        features = features.reshape(batch['data'].shape[0], -1)
    return features

precomputed_features = {}

window_size = 2048 # sampling rate
batch_size = 400
for subject_identifier, trial_id in config['training']['eval_subject_trials']:
    subject = all_subjects[subject_identifier]
    electrode_subset_indices = [subject.electrode_labels.index(e) for e in neuroprobe_config.NEUROPROBE_LITE_ELECTRODES[subject_identifier]]
    
    indices = []
    for eval_name in neuroprobe_config.NEUROPROBE_TASKS:
        dataset = BrainTreebankSubjectTrialBenchmarkDataset(subject, trial_id, torch.float32, eval_name, output_indices=True)

        for i, ((index_start, index_end), label) in enumerate(dataset):
            indices.append(index_start)
    indices = torch.tensor(sorted(list(set(indices))))

    log(f"For subject {subject_identifier} trial {trial_id}, generating features for {len(indices)} indices", priority=0)

    frozen_features = None
    for i_start in range(0, len(indices), batch_size):
        log(f"Generating features for batch {i_start} of {len(indices)}", priority=0, indent=1)
        i_end = min(i_start + batch_size, len(indices))
        indices_batch = indices[i_start:i_end]

        model_input = torch.zeros((len(indices_batch), len(electrode_subset_indices), window_size))
        for i_index, index in enumerate(indices_batch):
            model_input[i_index, :, :] = subject.get_all_electrode_data(trial_id, window_from=index, window_to=index+window_size)[electrode_subset_indices, :]
        model_input = model_input.to(device)

        batch = {
            'data': model_input, # shape (batch_size, n_electrodes, n_samples),
            'electrode_labels': [neuroprobe_config.NEUROPROBE_LITE_ELECTRODES[subject_identifier]],
            'metadata': {
                'subject_identifier': subject_identifier,
                'trial_id': trial_id,
                'sampling_rate': subject.get_sampling_rate(trial_id),
            },
        }

        with torch.no_grad():
            for preprocess_function in training_setup.get_preprocess_functions(pretraining=False):
                batch = preprocess_function(batch)
            features = generate_frozen_features(batch).reshape(batch['data'].shape[0], -1)
        if frozen_features is None:
            frozen_features = torch.zeros((len(indices), features.shape[1]))
        frozen_features[i_start:i_end, :] = features

    precomputed_features[subject_identifier, trial_id] = {
        'frozen_features': frozen_features,
        'indices': indices,
    }

[23:28:23 gpu 64.2G ram 76.2G] For subject btbank1 trial 1, generating features for 12402 indices
[23:28:23 gpu 64.2G ram 76.1G]     Generating features for batch 0 of 12402
[23:28:24 gpu 64.2G ram 76.2G]     Generating features for batch 400 of 12402
[23:28:25 gpu 64.2G ram 76.2G]     Generating features for batch 800 of 12402
[23:28:25 gpu 64.2G ram 76.2G]     Generating features for batch 1200 of 12402
[23:28:26 gpu 64.2G ram 76.2G]     Generating features for batch 1600 of 12402
[23:28:26 gpu 64.2G ram 76.2G]     Generating features for batch 2000 of 12402
[23:28:27 gpu 64.2G ram 76.2G]     Generating features for batch 2400 of 12402
[23:28:28 gpu 64.2G ram 76.2G]     Generating features for batch 2800 of 12402
[23:28:28 gpu 64.2G ram 76.2G]     Generating features for batch 3200 of 12402
[23:28:29 gpu 64.2G ram 76.2G]     Generating features for batch 3600 of 12402
[23:28:29 gpu 64.2G ram 76.2G]     Generating features for batch 4000 of 12402
[23:28:30 gpu 64.2G ram 76.2G]     Gen

In [20]:
from evaluation.neuroprobe.datasets import BrainTreebankSubjectTrialBenchmarkDataset

subject_identifier = "btbank1"
trial_id = 1
subject = all_subjects[subject_identifier]

dataset = BrainTreebankSubjectTrialBenchmarkDataset(subject, trial_id, torch.float32, "onset", output_indices=True)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=1, prefetch_factor=1)

for batch in dataloader:
    print(batch[0])
    break

[tensor([ 415356,  571056,  419781,  617368,  425242,  717870,  428396,  721966,
         454743,  728110,  455782,  780557,  458240,  782605,  460618,  786701,
         464098,  796941,  467916,  800013,  472051,  801037,  480328,  805133,
         487390,  810253,  502320,  812301,  503375,  814349,  513045, 1141436,
         520173, 1158434,  524042, 1161506,  528330, 1165602,  533802, 1168674,
         536450, 1170722,  542348, 1219105,  543830, 1222177,  544837, 1223201,
         550429, 1224225,  555844, 1231393,  558581, 1236513,  563006, 1238561,
         580396, 1239585,  585911, 1241633,  591654, 1242657,  593740, 1244705,
         596232, 1250849,  599408, 1254945,  604011, 1263137,  609837, 1274401,
         626304, 1276449,  632737, 1399075,  639207, 1407267,  647410, 1414435,
         649992, 1419555,  660042, 1422627,  663058, 1427747,  668447, 1430819,
         673047, 1431843,  676225, 1435939,  679812, 1450275,  683175, 1478850,
         688834, 1480898,  695106, 1483

In [31]:
def retrieve_frozen_features(batch):
    # INPUT:
    #   batch['data'] shape: (batch_size, n_electrodes, n_timesamples)
    #   batch['electrode_labels'] shape: list of length 1 (since it's the same across the batch), each element is a list of electrode labels
    #   batch['metadata']: dictionary containing metadata like the subject identifier and trial id, sampling rate, etc.
    # OUTPUT:
    #   features shape: (batch_size, *) where * can be arbitrary (and will be concatenated for regression)
    indices = precomputed_features[batch['metadata']['subject_identifier'], batch['metadata']['trial_id']]['indices']
    frozen_features = precomputed_features[batch['metadata']['subject_identifier'], batch['metadata']['trial_id']]['frozen_features']

    need_indices = batch['data'][0]
    print(need_indices.shape)
    print(indices.index(need_indices))

    return frozen_features[indices.index(need_indices)]

evaluation.model_evaluation_function = retrieve_frozen_features
evaluation.model_preprocess_functions = []
evaluation.batch_size = 400

epoch0_eval_results = evaluation.evaluate_on_all_metrics(log_priority=3, quick_eval=False, raw_data=False)



torch.Size([119, 2048])


AttributeError: 'Tensor' object has no attribute 'index'