In [26]:
"""
    helm.model_dkps
"""

import os

import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from rich import print as rprint
from tqdm import trange
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor

from utils import make_embedding_dict, onehot_embedding
from dkps.embed import embed_api
from dkps.dkps import DataKernelPerspectiveSpace as DKPS

import nest_asyncio
nest_asyncio.apply()

# --
# Helpers

def model2family(model):
    return model.split('_')[0]


def predict_null(df, mode='model'):
    """ average score of other models / families """
    assert mode in ['model', 'family']
    
    out = {}
    for model in df.model.unique():
        if mode == 'model':
            sel = df.model != model
        elif mode == 'family':
            sel = df.model.apply(model2family) != model2family(model)
        
        out[model] = df.score[sel].mean()
    
    return out


def _rel_err(act, pred):
    return np.abs(pred - act) / act

def _abs_err(act, pred):
    return np.abs(pred - act)

err_fns = {
    "abs" : _abs_err,
    "rel" : _rel_err,
}

def run_one(df_sample, n_samples, mode, seed, instance_ids):
    out = []
    model_names = df_sample.model.unique()
    
    S_all = df_sample.pivot(index='model', columns='instance_id', values='score').values
    
    embedding_dict = make_embedding_dict(df_sample)
    
    for target_model in model_names:
        
        # split data
        assert mode in ['model', 'family']
        if mode == 'model':
            train_models = np.array([m for m in model_names if m != target_model])
        elif mode == 'family':
            target_family = model2family(target_model)
            train_models  = np.array([m for m in model_names if model2family(m) != target_family])
        
        y_test  = y_acts[target_model]

        # average score over the `n_samples` evaluated
        p_sample = df_sample[df_sample.model == target_model].score.mean()

        # lr on DKPS embeddings of varying dimension
        p_lr_dkps = {}
        for n_components_cmds in [8]:
            for n_models in [len(train_models)]:
                _train_models   = np.random.choice(train_models, size=n_models, replace=False)
                _embedding_dict = {k:embedding_dict[k] for k in (set(_train_models) | set([target_model]))}
                
                P = DKPS(n_components_cmds=n_components_cmds)
                P = P.fit_transform(_embedding_dict, return_dict=True)
                
                _X_train = np.vstack([P[m] for m in _train_models])
                _y_train = np.array([y_acts[m] for m in _train_models])
                _X_test  = np.vstack([P[target_model]])

                # linear regression on DKPS embeddings        
                lr = LinearRegression().fit(_X_train, _y_train)
                
                if n_models != len(train_models):
                    p_lr_dkps[f'p_lr_dkps8__n_components_cmds={n_components_cmds}__n_models={n_models}'] = float(lr.predict(_X_test)[0])
                else:
                    p_lr_dkps[f'p_lr_dkps8__n_components_cmds={n_components_cmds}__n_models=ALL'] = float(lr.predict(_X_test)[0])

        out.append({
            "seed"         : seed,
            "n_samples"    : n_samples,
            "mode"         : mode,
            "target_model" : target_model,
            
            "y_act"        : y_test,
            "p_null"       : pred_null[mode][target_model],
            "p_sample"     : p_sample,

            "instance_ids" : instance_ids,
            
            **p_lr_dkps,
        })
    
    return out


dataset_dict = {
    # MATH dataset subjects
    "math:subject=algebra": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=counting_and_probability": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=geometry": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=intermediate_algebra": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=number_theory": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=prealgebra": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "math:subject=precalculus": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },

    # WMT14 language pairs (use meteor score, sample=0.2)
    "wmt_14:language_pair=cs-en": {
        "score_col": "meteor",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 0.2,
        "seed": 1,
    },
    "wmt_14:language_pair=de-en": {
        "score_col": "meteor",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 0.2,
        "seed": 1,
    },
    "wmt_14:language_pair=fr-en": {
        "score_col": "meteor",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 0.2,
        "seed": 1,
    },
    "wmt_14:language_pair=hi-en": {
        "score_col": "meteor",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 0.2,
        "seed": 1,
    },
    "wmt_14:language_pair=ru-en": {
        "score_col": "meteor",
        "embed_provider": "jina",
        "embed_model": None,
        "err_fn": "abs",
        "outdir": "results",
        "sample": 0.2,
        "seed": 1,
    },

    # MEDQA (embed_model=onehot)
    "med_qa": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },

    # LegalBench subsets (embed_model=onehot)
    "legalbench:subset=abercrombie": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "legalbench:subset=international_citizenship_questions": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "legalbench:subset=corporate_lobbying": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "legalbench:subset=function_of_decision_section": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
    "legalbench:subset=proa": {
        "score_col": "score",
        "embed_provider": "jina",
        "embed_model": "onehot",
        "err_fn": "abs",
        "outdir": "results",
        "sample": 1,
        "seed": 1,
    },
}

In [None]:
N_REPLICATES = 5000

for dataset in dataset_dict:
    print(dataset)
    
    score_col = dataset_dict[dataset]['score_col']
    embed_provider = dataset_dict[dataset]['embed_provider']
    embed_provider = dataset_dict[dataset]['embed_provider']
    embed_model = dataset_dict[dataset]['embed_model']
    err_fn = dataset_dict[dataset]['err_fn']
    outdir=dataset_dict[dataset]['outdir']
    sample = dataset_dict[dataset]['sample']
    seed=dataset_dict[dataset]['seed']
    
    inpath = Path('data') / f'{dataset.split(":")[0]}.tsv'
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    rprint('[blue]loading data ...[/blue]')

    df = pd.read_csv(inpath, sep='\t')
    df = df[df.dataset == dataset]

    if sample:
        rng           = np.random.default_rng(seed)
        uinstance_ids = df.instance_id.unique()
        keep          = rng.choice(uinstance_ids, int(len(uinstance_ids) * sample), replace=False)
        df            = df[df.instance_id.isin(keep)]

    df = df.sort_values(['model', 'instance_id']).reset_index(drop=True)
    
    if score_col != 'score':
        print(f'{score_col} -> score')
        df['score'] = df[score_col]

    # --
    # QC
    print(f'{len(df.response.unique())} / {df.shape[0]} responses are unique')
    _instance_ids = df.groupby('model').instance_id.apply(list)
    assert all([_instance_ids.iloc[0] == _instance_ids.iloc[i] for i in range(len(_instance_ids))]), 'instance_ids are not the same for each model'
    
    # --
    # Get embeddings
    
    if embed_model == 'onehot':
        df = onehot_embedding(df, dataset=dataset)
    else:
        df['embedding'] = list(embed_api(
            provider   = embed_provider, 
            input_strs = [str(xx) for xx in df.response.values],
            model      = embed_model
        ))
    
    # --
    # Run
    model_names  = df.model.unique()
    instance_ids = df.instance_id.unique()
    y_acts       = df.groupby('model').score.mean().to_dict()
    
    pred_null = {mode: predict_null(df, mode=mode) for mode in modes}
    
    # --
    # Simple - DKPS w/ more than one example
    
    outpath = outdir / f'{dataset}-{score_col}-res.tsv'
    
    jobs = []
    for iter in trange(N_REPLICATES):
        rng = np.random.default_rng(iter)
        for n_samples in [1, 4, 16, 64]:
            if n_samples > len(instance_ids):
                continue
            
            instance_ids_sample = rng.choice(instance_ids, size=n_samples, replace=False)
            df_sample           = df[df.instance_id.isin(instance_ids_sample)]
            jobs.append(delayed(run_one)(df_sample=df_sample, n_samples=n_samples, mode='family', seed=iter, instance_ids=instance_ids_sample))

    res    = sum(Parallel(n_jobs=-2, verbose=10)(jobs), [])
    df_res = pd.DataFrame(res)
    
    # compute errors - abs(pred - act) / act
    for c in df_res.columns:
        if 'p_' in c:
            df_res[c.replace('p_', 'e_')] = err_fns[err_fn](df_res.y_act, df_res[c])
    
    df_res.to_csv(outpath, sep='\t', index=False)

math:subject=algebra


12045 / 12825 responses are unique


Embedding chunks: 100%|█████████████████████████████████████████████████████████████| 257/257 [00:00<00:00, 2334.65it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:11<00:00, 425.78it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 20000 | elapsed:    3.1s
[Parallel(n_jobs=-2)]: Done    22 out of 20000 | elapsed:    5.2s
[Parallel(n_jobs=-2)]: Done    35 out of 20000 | elapsed:   11.0s
[Parallel(n_jobs=-2)]: Done    48 out of 20000 | elapsed:   18.1s
[Parallel(n_jobs=-2)]: Done   263 out of 20000 | elapsed:  2.1min
[Parallel(n_jobs=-2)]: Done   288 out of 20000 | elapsed:  2.3min
[Parallel(n_jobs=-2)]: Done   315 out of 20000 | elapsed:  2.6min
[Parallel(n_jobs=-2)]: Done   342 out of 20000 | elapsed:  2.8min
[Parallel(n_jobs=-2)]: Done   371 out of 20000 | elapsed:  3.0min
[Parallel(n_jobs=-2)]: Done   400 out of 20000 | elapsed:  3.2min
[Parallel(n_jobs=-2)]

dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai
dkps.embed: unable to load google-genai


3537 / 3705 responses are unique


Embedding chunks: 100%|███████████████████████████████████████████████████████████████| 75/75 [00:00<00:00, 1985.14it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1056.91it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 15000 | elapsed:    2.3s
[Parallel(n_jobs=-2)]: Done    22 out of 15000 | elapsed:    3.0s
[Parallel(n_jobs=-2)]: Done    35 out of 15000 | elapsed:    4.3s
[Parallel(n_jobs=-2)]: Done    48 out of 15000 | elapsed:    5.9s
[Parallel(n_jobs=-2)]: Done    63 out of 15000 | elapsed:    7.2s
[Parallel(n_jobs=-2)]: Done    78 out of 15000 | elapsed:    8.6s
[Parallel(n_jobs=-2)]: Done    95 out of 15000 | elapsed:   10.0s
[Parallel(n_jobs=-2)]: Done   112 out of 15000 | elapsed:   11.3s
[Parallel(n_jobs=-2)]: Done   131 out of 15000 | elapsed:   12.9s
[Parallel(n_jobs=-2)]: Done   150 out of 15000 | elapsed:   14.6s
[Parallel(n_jobs=-2)]

math:subject=geometry


3436 / 3610 responses are unique


Embedding chunks: 100%|███████████████████████████████████████████████████████████████| 73/73 [00:00<00:00, 1969.30it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1099.08it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 15000 | elapsed:    1.2s
[Parallel(n_jobs=-2)]: Done    22 out of 15000 | elapsed:    1.9s
[Parallel(n_jobs=-2)]: Done    35 out of 15000 | elapsed:    3.2s
[Parallel(n_jobs=-2)]: Done    48 out of 15000 | elapsed:    4.8s
[Parallel(n_jobs=-2)]: Done    63 out of 15000 | elapsed:    5.8s
[Parallel(n_jobs=-2)]: Done    78 out of 15000 | elapsed:    7.4s
[Parallel(n_jobs=-2)]: Done    95 out of 15000 | elapsed:    9.2s
[Parallel(n_jobs=-2)]: Done   112 out of 15000 | elapsed:   10.3s
[Parallel(n_jobs=-2)]: Done   131 out of 15000 | elapsed:   12.1s
[Parallel(n_jobs=-2)]: Done   150 out of 15000 | elapsed:   14.1s
[Parallel(n_jobs=-2)]

math:subject=intermediate_algebra


4563 / 4940 responses are unique


Embedding chunks: 100%|███████████████████████████████████████████████████████████████| 99/99 [00:00<00:00, 2037.70it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1025.67it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 15000 | elapsed:    1.2s
[Parallel(n_jobs=-2)]: Done    22 out of 15000 | elapsed:    2.0s
[Parallel(n_jobs=-2)]: Done    35 out of 15000 | elapsed:    3.5s
[Parallel(n_jobs=-2)]: Done    48 out of 15000 | elapsed:    5.1s
[Parallel(n_jobs=-2)]: Done    63 out of 15000 | elapsed:    6.1s
[Parallel(n_jobs=-2)]: Done    78 out of 15000 | elapsed:    7.2s
[Parallel(n_jobs=-2)]: Done    95 out of 15000 | elapsed:    9.0s
[Parallel(n_jobs=-2)]: Done   112 out of 15000 | elapsed:   10.2s
[Parallel(n_jobs=-2)]: Done   131 out of 15000 | elapsed:   11.8s
[Parallel(n_jobs=-2)]: Done   150 out of 15000 | elapsed:   13.4s
[Parallel(n_jobs=-2)]

math:subject=number_theory


2693 / 2850 responses are unique


Embedding chunks: 100%|███████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 2021.23it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1109.61it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 15000 | elapsed:    1.2s
[Parallel(n_jobs=-2)]: Done    22 out of 15000 | elapsed:    1.8s
[Parallel(n_jobs=-2)]: Done    35 out of 15000 | elapsed:    3.0s
[Parallel(n_jobs=-2)]: Done    48 out of 15000 | elapsed:    4.5s
[Parallel(n_jobs=-2)]: Done    63 out of 15000 | elapsed:    5.5s
[Parallel(n_jobs=-2)]: Done    78 out of 15000 | elapsed:    6.4s
[Parallel(n_jobs=-2)]: Done    95 out of 15000 | elapsed:    7.8s
[Parallel(n_jobs=-2)]: Done   112 out of 15000 | elapsed:    9.2s
[Parallel(n_jobs=-2)]: Done   131 out of 15000 | elapsed:   10.5s
[Parallel(n_jobs=-2)]: Done   150 out of 15000 | elapsed:   12.1s
[Parallel(n_jobs=-2)]

math:subject=prealgebra


7690 / 8170 responses are unique


Embedding chunks: 100%|█████████████████████████████████████████████████████████████| 164/164 [00:00<00:00, 2546.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 524.28it/s]
[Parallel(n_jobs=-2)]: Using backend LokyBackend with 25 concurrent workers.
[Parallel(n_jobs=-2)]: Done    11 out of 20000 | elapsed:    1.6s
[Parallel(n_jobs=-2)]: Done    22 out of 20000 | elapsed:    3.4s
[Parallel(n_jobs=-2)]: Done    35 out of 20000 | elapsed:    6.4s
[Parallel(n_jobs=-2)]: Done    48 out of 20000 | elapsed:   10.5s
[Parallel(n_jobs=-2)]: Done    63 out of 20000 | elapsed:   18.1s
[Parallel(n_jobs=-2)]: Done    78 out of 20000 | elapsed:   39.8s
[Parallel(n_jobs=-2)]: Done    95 out of 20000 | elapsed:   44.9s
[Parallel(n_jobs=-2)]: Done   112 out of 20000 | elapsed:   49.0s
[Parallel(n_jobs=-2)]: Done   131 out of 20000 | elapsed:   54.4s
[Parallel(n_jobs=-2)]: Done   150 out of 20000 | elapsed:  1.0min
[Parallel(n_jobs=-2)]

In [51]:
len(df_res)

1860000

In [69]:
sets = [set(v) for v in df_res['instance_ids'].values]

In [63]:
len(np.unique(lists))

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (1860000,) + inhomogeneous part.

In [66]:
lists[2]

['legalbench:subset=proa--id86']