In [1]:
import gc
import json
import os
import time

import numpy as np
import torch
import wandb
from torch.amp import autocast

from evaluation.neuroprobe import config as neuroprobe_config
from evaluation.neuroprobe.datasets import BrainTreebankSubjectTrialBenchmarkDataset
from evaluation.neuroprobe_tasks import FrozenModelEvaluation_SS_SM
from subject.dataset import load_subjects
from torch.optim.lr_scheduler import ChainedScheduler
from torch.utils.data import DataLoader
from training_setup.training_config import (
    convert_dtypes,
    get_default_config,
    log,
    parse_config_from_args,
    parse_subject_trials_from_config,
    unconvert_dtypes,
    update_dir_name,
    update_random_seed,
)
from utils.muon_optimizer import Muon

# Allow PyTorch to pick faster, reduced-precision kernels for float32 matmuls
# ("high" setting; see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high')

# Root directory that holds one sub-directory per training run.
RUNS_DIR = 'runs/data'

### PARSE MODEL DIR ###
# Run parameters are hard-coded here rather than parsed with argparse.

# Run directory (relative to RUNS_DIR) containing the saved model checkpoints.
model_dir = "andrii0_lr0.003_wd0.0_dr0.1_rTEMP_t20250812_172908"
# Checkpoint epoch to load; a negative value is mapped to the "final" checkpoint below.
model_epoch = 0
# The single subject / trial whose frozen features are extracted.
subject_id, trial_id = 10, 0
eval_tasks = ["speech"]  # evaluation tasks to compute features for
overwrite = False        # whether to overwrite existing frozen features
batch_size = 100         # batch size used during feature computation

# Time window around each word onset, in seconds.
# NOTE(review): neither value is referenced by the code in this cell — confirm
# whether they are still needed or consumed elsewhere.
bins_start_before_word_onset_seconds = 0
bins_end_after_word_onset_seconds = 1.0

### LOAD CONFIG ###

# Resolve which checkpoint file to read; a negative epoch selects the
# "model_epoch_final.pth" checkpoint.
if model_epoch < 0:
    model_epoch = "final"
checkpoint_path = os.path.join(RUNS_DIR, model_dir, f"model_epoch_{model_epoch}.pth")
# Only checkpoint['config'] is read from this object in this cell (the weights
# are loaded later through training_setup.load_model). Deserialising onto the
# CPU avoids placing any stored tensors back on the GPU they were saved from,
# and keeps GPU-saved checkpoints loadable on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
config = unconvert_dtypes(checkpoint['config'])
log(f"Directory name: {model_dir}", priority=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config['device'] = device
log(f"Using device: {device}", priority=0)

# Evaluate only: no training trials, and a single eval subject/trial.
config['training']['train_subject_trials'] = ""
config['training']['eval_subject_trials'] = f"btbank{subject_id}_{trial_id}"
parse_subject_trials_from_config(config)

# Older checkpoints predate the 'setup_name' key.
# XXX: this is only here for backwards compatibility, can remove soon
config['training'].setdefault('setup_name', "andrii0")

### LOAD SUBJECTS ###

log("Loading subjects...", priority=0)
# load_subjects returns a dict keyed by subject identifier (e.g. "btbank10")
# whose values are the subject objects.
training_cfg = config['training']
all_subjects = load_subjects(
    training_cfg['train_subject_trials'],
    training_cfg['eval_subject_trials'],
    training_cfg['data_dtype'],
    cache=config['cluster']['cache_subjects'],
    allow_corrupted=False,
)

# Only one subject is evaluated, so look it up directly by its identifier.
subject_key = f"btbank{subject_id}"
subject = all_subjects[subject_key]

# Neuroprobe-Lite electrode list for this subject.
electrode_subset = neuroprobe_config.NEUROPROBE_LITE_ELECTRODES[subject_key]

### LOAD MODEL ###

import importlib

# Resolve the training-setup class dynamically from the config: the module
# training_setup.<setup_name lowercased> must define a class named <setup_name>.
setup_name = config["training"]["setup_name"]
try:
    setup_module = importlib.import_module(f"training_setup.{setup_name.lower()}")
    setup_class = getattr(setup_module, setup_name)
except (ImportError, AttributeError) as e:
    print(f"Could not load training setup '{config['training']['setup_name']}'. Are you sure the filename and the class name are the same and correspond to the parameter? Error: {str(e)}")
    # Stop with a non-zero status (a bare exit() reported success) and keep
    # the original error chained for debugging.
    raise SystemExit(1) from e

# Constructed outside the try so that errors raised inside the constructor
# surface with their own traceback instead of being misreported as a failed import.
training_setup = setup_class(all_subjects, config, verbose=True)

log("Loading model...", priority=0)
training_setup.initialize_model()

log("Loading model weights...", priority=0)
training_setup.load_model(model_epoch)

log("Computing frozen features...", priority=0)
from evaluation.neuroprobe_frozen import NeuroprobeFrozenFeaturesExtractor
feature_extractor = NeuroprobeFrozenFeaturesExtractor(
    training_setup=training_setup,
    all_subjects=all_subjects,
    # NOTE(review): 'cpu' even when `device` is CUDA — presumably where the
    # extracted features are stored rather than where the model runs; confirm.
    device='cpu',
    eval_names=eval_tasks,
    subject_trials=[(subject_id, trial_id)],
    dtype=torch.float32,
    feature_aggregation_method=None,
    log_priority=0,
    batch_size=batch_size,
)
save_file_path = os.path.join(RUNS_DIR, model_dir, "frozen_features_neuroprobe", f"model_epoch{model_epoch}", f"frozen_population_btbank{subject_id}_{trial_id}.pth")

# Respect the `overwrite` setting: keep an existing feature file unless
# regeneration was explicitly requested.
if os.path.exists(save_file_path) and not overwrite:
    log(f"Frozen features already exist at {save_file_path}; skipping (set overwrite = True to regenerate)", priority=0)
else:
    os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
    feature_extractor.generate_frozen_features(save_path=save_file_path)
    log(f"Saved results to {save_file_path}")


[17:57:14 gpu 0.0G ram 0.8G] Directory name: andrii0_lr0.003_wd0.0_dr0.1_rTEMP_t20250812_172908
[17:57:14 gpu 0.0G ram 0.8G] Using device: cuda
[17:57:14 gpu 0.0G ram 0.8G] Loading subjects...
[17:57:14 gpu 0.0G ram 0.8G]     loading subject btbank10...
[17:57:14 gpu 0.0G ram 0.8G] Loading model...
[17:57:29 gpu 0.0G ram 1.0G] Adding subject btbank10 to electrode embeddings...
[17:57:29 gpu 0.0G ram 1.0G] Loading model weights...
[17:57:29 gpu 0.1G ram 1.0G] Computing frozen features...
[17:57:29 gpu 0.1G ram 1.0G] Generating Neuroprobe Eval indices for all subjects and trials
[17:57:30 gpu 0.1G ram 1.0G]     Generated 3500 indices for btbank10_0
[17:57:42 gpu 0.1G ram 1.1G]     Generating features for batch 0 of 3500
[17:57:57 gpu 2.9G ram 6.5G]     Generating features for batch 100 of 3500
[17:58:09 gpu 3.1G ram 6.4G]     Generating features for batch 200 of 3500
[17:58:20 gpu 3.1G ram 6.4G]     Generating features for batch 300 of 3500
[17:58:31 gpu 3.1G ram 6.4G]     Generating fea