In [None]:
import os
os.environ['ROOT_DIR_BRAINTREEBANK'] = ''
import h5py
import tqdm
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

from temporaldata import Data
import neuroprobe.config as neuroprobe_config
from neuroprobe_utils.eval_utils import preprocess_data


# Directory containing the per-recording HDF5 files named
# `sub_<subject_id>_trial<trial_id:03d>.h5` that run_neuroprobe_baseline opens.
# Replace the placeholder with a real path before running the notebook.
PROCESSED_DATA_PATH = '<ENTER PATH>'

In [2]:
# STFT settings shared by the spectrogram-based baselines. They are forwarded to
# preprocess_data under the "stft" key of preprocess_parameters.
stft_parameters = dict(
    nperseg=512,        # segment length in samples
    poverlap=0.75,      # fractional overlap between consecutive segments
    window="hann",      # window name (interpreted by preprocess_data)
    max_frequency=150,  # upper bound of the retained frequency band
    min_frequency=0,    # lower bound of the retained frequency band
)

def _n_timebins_(n_samples: int, nperseg: int, poverlap: float, center: bool = True) -> int:
    hop_length = int(nperseg * (1 - poverlap))
    if center:
        n_samples_eff = n_samples + 2 * (nperseg // 2)
    else:
        n_samples_eff = n_samples

    n_timebins = np.floor((n_samples_eff - nperseg) / hop_length) + 1
    return int(n_timebins)


def _n_freqs_(sampling_rate: float, nperseg: int, fmin: float, fmax: float) -> int:
    freqs = np.fft.rfftfreq(nperseg, d=1.0 / sampling_rate)
    mask = (freqs >= fmin) & (freqs <= fmax)
    return int(mask.sum())

In [None]:
from itertools import product

def _extract_split_features(data, split, channel_mask, preprocess_type, preprocess_parameters):
    """Build the feature matrix and label vector for one train or test split.

    Each window ``[split.start[i], split.end[i])`` is sliced out of ``data``,
    restricted to the channels selected by ``channel_mask``, passed through
    ``preprocess_data`` and flattened into one feature row.

    The feature width is taken from the first preprocessed window rather than
    predicted from the sampling rate or STFT settings, so it does not depend on
    any assumed window length.

    Args:
        data: Recording object exposing ``slice`` and ``channels.name``.
        split: Split object exposing ``start``, ``end``, ``label`` and ``len()``.
        channel_mask: Selector applied to ``data.channels.name`` and to the
            channel axis of ``seeg_data.data``.
        preprocess_type: Preprocessing identifier forwarded to ``preprocess_data``.
        preprocess_parameters: Parameter dict forwarded to ``preprocess_data``.

    Returns:
        ``(X, y)``: ``X`` is float32 of shape ``(len(split), n_features)`` and
        ``y`` is int32 of shape ``(len(split),)``.

    Raises:
        ValueError: If the split is empty, or if two windows of the same split
            produce a different number of features.
    """
    n_windows = len(split)
    if n_windows == 0:
        raise ValueError("Cannot extract features from an empty split")

    electrode_labels = data.channels.name[channel_mask].tolist()
    labels = np.zeros(n_windows, dtype=np.int32)
    features = None
    for i in range(n_windows):
        window = data.slice(split.start[i], split.end[i])
        flat = preprocess_data(
            window.seeg_data.data[:, channel_mask].T,
            electrode_labels,
            preprocess_type,
            preprocess_parameters,
        ).flatten()
        if features is None:
            # Allocate once the real feature width is known.
            features = np.zeros((n_windows, flat.shape[0]), dtype=np.float32)
        elif flat.shape[0] != features.shape[1]:
            raise ValueError(
                f"Window {i} produced {flat.shape[0]} features, "
                f"expected {features.shape[1]} (windows of unequal length?)"
            )
        features[i, :] = flat
        labels[i] = split.label[i]
    return features, labels


def run_neuroprobe_baseline(model_name, preprocess_type, preprocess_parameters):
    """Fit and score a logistic-regression probe on every task/recording/fold.

    Iterates over the product of ``NEUROPROBE_TASKS``,
    ``NEUROPROBE_LITE_SUBJECT_TRIALS`` and ``range(NEUROPROBE_LITE_N_FOLDS)``.
    For each combination it loads ``sub_<subject>_trial<NNN>.h5`` from
    ``PROCESSED_DATA_PATH``, extracts train and test features, standardises
    them with statistics fitted on the training set, and fits
    ``LogisticRegression(random_state=42, max_iter=1000, tol=1e-3)``.

    Args:
        model_name: Label stored in the ``model_name`` column of the results.
        preprocess_type: Preprocessing identifier forwarded to ``preprocess_data``.
        preprocess_parameters: Parameter dict forwarded to ``preprocess_data``.

    Returns:
        ``pandas.DataFrame`` with one row per (task, subject, trial, fold) and
        columns ``model_name``, ``eval_name``, ``subject_id``, ``trial_id``,
        ``fold_idx``, ``train_acc``, ``test_acc`` and ``test_auc``.
        ``test_auc`` is NaN when ROC AUC could not be computed.
    """
    iterations = [
        (eval_name, subject_id, trial_id, fold_idx)
        for eval_name, (subject_id, trial_id), fold_idx in product(
            neuroprobe_config.NEUROPROBE_TASKS,
            neuroprobe_config.NEUROPROBE_LITE_SUBJECT_TRIALS,
            range(neuroprobe_config.NEUROPROBE_LITE_N_FOLDS),
        )
    ]

    results = []
    pbar = tqdm.tqdm(iterations, desc="", leave=True)
    for eval_name, subject_id, trial_id, fold_idx in pbar:
        pbar.set_description(f"Running {eval_name} (sub {subject_id}, trial {trial_id}, fold {fold_idx})")
        h5_path = os.path.join(PROCESSED_DATA_PATH, f'sub_{subject_id}_trial{trial_id:03d}.h5')
        # Feature extraction stays inside the `with` block, as in the original,
        # so the HDF5 file is still open while windows are sliced.
        with h5py.File(h5_path, 'r') as f:
            data = Data.from_hdf5(f)
            split_prefix = f'{eval_name}_fold{fold_idx}'

            X_train, y_train = _extract_split_features(
                data,
                getattr(data, f'{split_prefix}_train'),
                getattr(data.channels, f'included_{split_prefix}_train'),
                preprocess_type,
                preprocess_parameters,
            )
            X_test, y_test = _extract_split_features(
                data,
                getattr(data, f'{split_prefix}_test'),
                getattr(data.channels, f'included_{split_prefix}_test'),
                preprocess_type,
                preprocess_parameters,
            )

        # Standardise with training-set statistics only, to avoid test leakage.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        clf = LogisticRegression(random_state=42, max_iter=1000, tol=1e-3)
        clf.fit(X_train, y_train)

        train_score = clf.score(X_train, y_train)
        test_score = clf.score(X_test, y_test)

        # Best-effort AUC: a failure here (e.g. a single class in y_test) is
        # reported and recorded as NaN instead of aborting a long run.
        try:
            test_proba = clf.predict_proba(X_test)
            if len(np.unique(y_test)) > 2:
                # Multiclass targets need the one-vs-rest formulation.
                test_auc = roc_auc_score(y_test, test_proba, multi_class='ovr')
            else:
                # Binary targets: score with the probability of class 1.
                test_auc = roc_auc_score(y_test, test_proba[:, 1])
        except Exception as e:
            print(f"Could not compute ROC AUC: {e}")
            test_auc = float('nan')

        results.append({
            "model_name": model_name,
            "eval_name": eval_name,
            "subject_id": subject_id,
            "trial_id": trial_id,
            "fold_idx": fold_idx,
            "train_acc": train_score,
            "test_acc": test_score,
            "test_auc": test_auc,
        })

    return pd.DataFrame(results)


def print_model_results(results_df):
    """Print per-task and overall metrics for a single model's results.

    Args:
        results_df: DataFrame with at least the columns ``eval_name``,
            ``train_acc``, ``test_acc`` and ``test_auc``. All rows are assumed
            to belong to one model.

    Prints a table of per-task means (plus the standard error of ``test_auc``
    within each task), then the mean of the per-task means. The error on the
    overall AUC combines the per-task standard errors in quadrature and divides
    by the number of tasks.
    """
    separator = '-' * 100
    metric_columns = ['train_acc', 'test_acc', 'test_auc']

    per_task = results_df.groupby('eval_name', as_index=False)
    task_means = per_task[metric_columns].mean(numeric_only=True)
    task_means['test_auc_sem'] = per_task['test_auc'].sem()['test_auc']

    print(separator)
    print(task_means)
    print(separator)

    mean_train_acc = task_means['train_acc'].mean()
    mean_test_acc = task_means['test_acc'].mean()
    mean_test_auc = task_means['test_auc'].mean()
    combined_sem = (task_means['test_auc_sem']**2).sum()**0.5 / len(task_means)

    print(f"Mean train accuracy: {mean_train_acc:.4f}")
    print(f"Mean test accuracy: {mean_test_acc:.4f}")
    print(f"Mean test AUC: {mean_test_auc:.4f} ± {combined_sem:.4f}")

In [4]:
# Baseline labelled 'Linear (raw)': an empty preprocess_type is passed through
# to preprocess_data, with no extra preprocessing parameters.
preprocess_type = ''
preprocess_parameters = dict(type=preprocess_type)
linear_raw_results = run_neuroprobe_baseline(
    model_name='Linear (raw)',
    preprocess_type=preprocess_type,
    preprocess_parameters=preprocess_parameters,
)
print_model_results(linear_raw_results)

Running face_num (sub 10, trial 1, fold 1): 100%|██████████| 360/360 [1:55:00<00:00, 19.17s/it]        

----------------------------------------------------------------------------------------------------
           eval_name  train_acc  test_acc  test_auc  test_auc_sem
0       delta_volume        1.0  0.690266  0.753039      0.013646
1           face_num        1.0  0.499726  0.499039      0.005284
2   frame_brightness        1.0  0.505580  0.507079      0.010881
3        global_flow        1.0  0.524088  0.535152      0.007117
4     gpt2_surprisal        1.0  0.559915  0.584387      0.006969
5         local_flow        1.0  0.535423  0.544285      0.004338
6              onset        1.0  0.727451  0.794724      0.015286
7              pitch        1.0  0.526979  0.535886      0.004243
8             speech        1.0  0.614515  0.655844      0.016394
9             volume        1.0  0.569266  0.594880      0.011494
10          word_gap        1.0  0.567905  0.595251      0.011042
11     word_head_pos        1.0  0.549048  0.570358      0.006023
12        word_index        1.0  0.682002




In [5]:
# Baseline labelled 'Linear (STFT)': preprocess_type 'stft_abs' with the shared
# STFT settings supplied under the "stft" key.
preprocess_type = 'stft_abs'
preprocess_parameters = dict(type=preprocess_type, stft=stft_parameters)
stft_results = run_neuroprobe_baseline(
    model_name='Linear (STFT)',
    preprocess_type=preprocess_type,
    preprocess_parameters=preprocess_parameters,
)
print_model_results(stft_results)

Running face_num (sub 10, trial 1, fold 1): 100%|██████████| 360/360 [41:07<00:00,  6.85s/it]        

----------------------------------------------------------------------------------------------------
           eval_name  train_acc  test_acc  test_auc  test_auc_sem
0       delta_volume        1.0  0.662303  0.718109      0.017969
1           face_num        1.0  0.518869  0.524925      0.006293
2   frame_brightness        1.0  0.521796  0.533023      0.012438
3        global_flow        1.0  0.567829  0.603799      0.013345
4     gpt2_surprisal        1.0  0.548802  0.570385      0.012315
5         local_flow        1.0  0.564284  0.592587      0.014368
6              onset        1.0  0.776174  0.850897      0.018646
7              pitch        1.0  0.547087  0.569822      0.008769
8             speech        1.0  0.756011  0.825206      0.020790
9             volume        1.0  0.663641  0.726024      0.027559
10          word_gap        1.0  0.551678  0.579498      0.013650
11     word_head_pos        1.0  0.544857  0.564990      0.008992
12        word_index        1.0  0.616789




In [6]:
# Baseline labelled 'Linear (Laplacian STFT)': preprocess_type
# 'laplacian-stft_abs' with the shared STFT settings under the "stft" key.
preprocess_type = 'laplacian-stft_abs'
preprocess_parameters = dict(type=preprocess_type, stft=stft_parameters)
laplacian_stft_results = run_neuroprobe_baseline(
    model_name='Linear (Laplacian STFT)',
    preprocess_type=preprocess_type,
    preprocess_parameters=preprocess_parameters,
)
print_model_results(laplacian_stft_results)

Running face_num (sub 10, trial 1, fold 1): 100%|██████████| 360/360 [1:00:42<00:00, 10.12s/it]      

----------------------------------------------------------------------------------------------------
           eval_name  train_acc  test_acc  test_auc  test_auc_sem
0       delta_volume        1.0  0.700043  0.762268      0.018781
1           face_num        1.0  0.519314  0.529991      0.010499
2   frame_brightness        1.0  0.504834  0.520560      0.021550
3        global_flow        1.0  0.583130  0.627021      0.011151
4     gpt2_surprisal        1.0  0.579699  0.613087      0.012484
5         local_flow        1.0  0.570840  0.607144      0.012947
6              onset        1.0  0.820063  0.890582      0.016106
7              pitch        1.0  0.553957  0.578374      0.012393
8             speech        1.0  0.813079  0.882720      0.014732
9             volume        1.0  0.644094  0.716555      0.023712
10          word_gap        1.0  0.573870  0.611651      0.010623
11     word_head_pos        1.0  0.571452  0.601639      0.009155
12        word_index        1.0  0.683563




In [None]:
# Stack the per-fold results of all three baselines into one long table.
combined_results = pd.concat([laplacian_stft_results, stft_results, linear_raw_results], ignore_index=True)

# Per (model, task) summary. The built-in "sem" aggregation ignores NaN AUCs in
# both the spread and the count, matching the `.sem()` used in
# print_model_results. The previous hand-rolled lambda divided by a count that
# still included NaNs.
agg = (
    combined_results
    .groupby(["model_name", "eval_name"])
    .agg(
        train_acc_mean=("train_acc", "mean"),
        test_acc_mean=("test_acc", "mean"),
        test_auc_mean=("test_auc", "mean"),
        test_auc_sem=("test_auc", "sem"),
    )
    .reset_index()
)

# Per-model summary across tasks: mean of the task means, and the standard
# error of those task means.
overall = (
    agg
    .groupby("model_name")
    .agg(
        train_acc_mean=("train_acc_mean", "mean"),
        test_acc_mean=("test_acc_mean", "mean"),
        test_auc_mean=("test_auc_mean", "mean"),
        test_auc_sem=("test_auc_mean", "sem"),  # sem of means
    )
    .reset_index()
)
overall["eval_name"] = "Total"

# Wide table: one row per task (plus the "Total" row), one column per model.
summary = pd.concat([agg, overall], ignore_index=True)
auc_wide = summary.pivot(index="eval_name", columns="model_name", values="test_auc_mean")
auc_wide[['Linear (raw)', 'Linear (STFT)', 'Linear (Laplacian STFT)']]

model_name,Linear (raw),Linear (STFT),Linear (Laplacian STFT)
eval_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ALL,0.605852,0.62958,0.660318
delta_volume,0.753039,0.718109,0.762268
face_num,0.499039,0.524925,0.529991
frame_brightness,0.507079,0.533023,0.52056
global_flow,0.535152,0.603799,0.627021
gpt2_surprisal,0.584387,0.570385,0.613087
local_flow,0.544285,0.592587,0.607144
onset,0.794724,0.850897,0.890582
pitch,0.535886,0.569822,0.578374
speech,0.655844,0.825206,0.88272
