In [1]:
# Core numerical / ML stack used throughout the notebook.
import numpy as np
import pandas as pd
import sklearn

# Record the library versions this run was produced with.
print("numpy:", np.__version__)
print("pandas:", pd.__version__)
print("sklearn:", sklearn.__version__)

numpy: 2.0.2
pandas: 2.2.3
sklearn: 1.6.1


In [2]:
import time
# Start the notebook-wide wall-clock timer. The last cell takes a second
# perf_counter() reading and prints the difference from this value.
start_time = time.perf_counter()

#### dummy data

In [3]:
from sklearn.datasets import make_classification
import random
import string

# Fixed seed so the dummy dataset is identical on every Restart & Run All.
DUMMY_DATA_SEED = 0

# Synthetic 5-class problem: 10000 samples x 128 features.
X, y = make_classification(n_samples=10000, n_features=128, n_informative=10, n_redundant=60, n_repeated=0, n_classes=5,
                           n_clusters_per_class=3, weights=None, flip_y=0.01, class_sep=0.7,
                           random_state=DUMMY_DATA_SEED)

# One random 10-character alphanumeric group label per sample, drawn from a
# private seeded generator so the global `random` state is left untouched.
_group_rng = random.Random(DUMMY_DATA_SEED)
groups = np.array([''.join(_group_rng.choices(string.ascii_letters + string.digits, k=10)) for _ in range(len(y))])
print(X.shape, y.shape, groups.shape)

(10000, 128) (10000,) (10000,)


### create X, y, groups (by subject id)

In [4]:
from sklearn.preprocessing import LabelEncoder

def get_X_y_groups(X_fname, df_fname, y_column_name, group_column_name):
    """Load features, targets and integer-encoded group ids for one probe.

    Parameters
    ----------
    X_fname : str or path-like
        ``.npy`` file read with ``np.load`` (pretrained model embeddings).
    df_fname : str or path-like
        Metadata CSV read with ``pd.read_csv``.
    y_column_name : str
        Metadata column whose values are returned as ``y``.
    group_column_name : str
        Metadata column whose distinct values are encoded as integers.

    Returns
    -------
    X : object returned by ``np.load(X_fname)``
    y : numpy.ndarray
        Values of ``y_column_name``, one per metadata row.
    subject_groups : numpy.ndarray
        1-D integer array, one code per metadata row; each code is the
        position of that row's group value among the sorted distinct values.

    Raises
    ------
    KeyError
        If either column name is missing from the metadata sheet.
    AssertionError
        If ``X``, ``y`` and ``subject_groups`` do not have the same length.
    """
    # read X (pretrained model embeddings)
    X = np.load(X_fname)

    # read metadata sheet
    df = pd.read_csv(df_fname)
    y = df[y_column_name].values

    # compile X, y, subject_groups
    # Encode every group id in one vectorised pass so the result is a flat
    # integer array (one code per row) rather than per-row 1-element arrays.
    _, subject_groups = np.unique(df[group_column_name].to_numpy(), return_inverse=True)

    assert len(X) == len(y) == len(subject_groups), (
        f"length mismatch: X={len(X)}, y={len(y)}, groups={len(subject_groups)}"
    )

    return X, y, subject_groups

### for a given probe_type/feature_type and model_patch_size, train new linear probes for all available model checkpoints 

In [7]:
from utils.pipeline import multiclass_clf_pipeline_runner

# Common/global config shared across all linear probes.

# Target being probed and the patch-size variant of the model.
feature_type, model_patch_size = "channel_region", "1sec"

# Probe estimator and cross-validation scheme; both are handed to the
# pipeline runner together with `cv_params`.
model_type, cv_type = "LogReg_L2_MultiClass", "Simple_KFold"

cv_params = {
    'cv_folds': {"outer": 10},
    'n_jobs': {"outer": 1, "model_fit": 10},
    'random_state': 2509843,
}

# Checkpoint names epoch_10, epoch_20, ..., epoch_100.
model_checkpoints = [f"epoch_{epoch}" for epoch in range(10, 101, 10)]
model_checkpoints

['epoch_10',
 'epoch_20',
 'epoch_30',
 'epoch_40',
 'epoch_50',
 'epoch_60',
 'epoch_70',
 'epoch_80',
 'epoch_90',
 'epoch_100']

### serial execution

In [8]:
from collections import OrderedDict
# Maps checkpoint name -> the runner's result entry for `model_type`,
# in the order the checkpoints were processed.
all_probes = OrderedDict()

# NOTE(review): the saved output of this cell ends in
# "ValueError: Target is multiclass but average='binary'". Nothing here passes
# an `average` argument, so the failing call is presumably inside
# `multiclass_clf_pipeline_runner` -- confirm and fix it there.
for checkpoint_idx, checkpoint in enumerate(model_checkpoints):

    # X, y, groups = get_X_y_groups(
    #     X_fname=f"{model_patch_size}_{checkpoint}.npy",
    #     df_fname="metadata.csv",
    #     y_column_name=feature_type,
    #     group_column_name="subject_id",
    # )

    # Dummy stand-in for the real embeddings. The seed is derived from the CV
    # seed plus the checkpoint position, so every checkpoint gets different
    # data but a re-run reproduces exactly the same data.
    dummy_seed = cv_params['random_state'] + checkpoint_idx
    X, y = make_classification(n_samples=10000, n_features=128, n_informative=10, n_redundant=60, n_repeated=0, n_classes=5,
                               n_clusters_per_class=3, weights=None, flip_y=0.01, class_sep=0.7,
                               random_state=dummy_seed)
    group_rng = random.Random(dummy_seed)
    groups = np.array([''.join(group_rng.choices(string.ascii_letters + string.digits, k=10)) for _ in range(len(y))])

    results = multiclass_clf_pipeline_runner(X, y, groups, model_type, cv_type, cv_params)
    all_probes[checkpoint] = results[model_type]


** ALL DATA USED FOR OUTER CV: (10000, 128) [0 1 2 3 4] [2000 1994 2000 2003 2003]
Fold 0:




Fold 1:




Fold 2:




Fold 3:




Fold 4:




Fold 5:




Fold 6:




Fold 7:




Fold 8:




Fold 9:




ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].

### TODO: parallel execution (not implemented yet)

In [None]:
# ....................

In [None]:
import pickle

# Date tag embedded in the output filename below. It is hard-coded, so it has
# to be edited by hand for each run (presumably month_day_year -- confirm).
date = "5_13_25"
# Saving is currently disabled. Uncomment to write `all_probes` to a pickle
# named after the patch size, feature type and date tag.
# with open(f"{model_patch_size}_{feature_type}_probes_for_all_checkpoints_{date}.pkl", 'wb') as f:
#     pickle.dump(all_probes, f)

### Plot held-out test probe scores (y-axis) with standard error across k-fold CV folds for all model checkpoints (x-axis)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("whitegrid")

# Figure-wide styling for this plot.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 11,
    "font.size": 14,
})

fig, ax = plt.subplots()

# Build one row per (checkpoint, fold) from the per-fold metric lists.
# NOTE(review): these keys are regression metrics (MSE / MAE / R^2) while
# `model_type` names a multiclass classifier -- confirm that the runner really
# returns these keys under "sequence_level".
_df_rows = []
for checkpoint, probe_result in all_probes.items():
    fold_metrics = probe_result["model_metrics_across_folds"]["sequence_level"]
    _df_rows.extend(
        {'x': checkpoint, 'mse': mse, 'mae': mae, 'r2': r2}
        for mse, mae, r2 in zip(
            fold_metrics["mean_squared_errors"],
            fold_metrics["mean_absolute_errors"],
            fold_metrics["r2_scores"],
        )
    )

plot_df = pd.DataFrame(_df_rows)

# Mean across folds with a standard-error band, as the section heading states
# (the band was previously drawn with the standard deviation).
sns.lineplot(
    data=plot_df, x='x', y='mse', errorbar='se', err_style="band",
    marker="o",
    label=f"Patch size: {model_patch_size}",
    ax=ax,
)

ax.tick_params(axis="x", labelrotation=20)
ax.set_xlabel("Checkpoints", fontsize=20)
ax.set_ylabel(f"{feature_type} Probe MSE (heldout set)", fontsize=20)
plt.show()

In [None]:
# Report the total wall-clock time since `start_time` was captured in the
# timer cell near the top of the notebook.
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Wall-clock time: {elapsed_time:.4f} seconds")