In [13]:
#!/usr/bin/env python
"""
    helm.plot_dkps
"""

import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from rich import print as rprint

# from utils import dkps_df, onehot_embedding
# from dkps.embed import embed_api

import numpy as np
from graspologic.embed import ClassicalMDS
# from graspologic.embed import OmnibusEmbed

from scipy.spatial.distance import pdist, squareform

# from .utils import knn_graph


class DataKernelPerspectiveSpace:
    def __init__(
            self,
            response_distribution_fn=np.mean,
            response_distribution_axis=1,
            metric_cmds='euclidean',
            n_components_cmds=None,
            n_elbows_cmds=2,
            dissimilarity="precomputed",
        ):
        
        self.response_distribution_fn   = response_distribution_fn
        self.response_distribution_axis = response_distribution_axis
        self.metric_cmds                = metric_cmds
        self.n_components_cmds          = n_components_cmds
        self.n_elbows_cmds              = n_elbows_cmds
        self.dissimilarity              = dissimilarity

    def fit_transform(self, data, return_dict=True):
        """
        data: dict {model_name: np.array(n_queries, n_replicates, embedding_dim)}
        """
        
        # qc checks
        assert isinstance(data, dict),                                  'data must be a dict'
        assert all([isinstance(x, np.ndarray) for x in data.values()]), 'all values must be numpy arrays'
        assert all([x.ndim == 3 for x in data.values()]),               'all arrays must be 3D - np.array(n_queries, n_replicates, embedding_dim)'
        assert len(set([x.shape for x in data.values()])) == 1,         'all arrays must have the same shape'

        # aggregate over replicates -> (n_models, n_queries, embedding_dim)
        X = np.stack([self.response_distribution_fn(v, axis=self.response_distribution_axis) for k,v in data.items()])
        n_models, n_queries, embedding_dim = X.shape
        
        # flatten -> (n_models, n_queries * embedding_dim)
        X_flat = X.reshape(len(X), -1)

        dist_matrix = squareform(pdist(X_flat, metric=self.metric_cmds)) / np.sqrt(n_queries)
        cmds_embds  = ClassicalMDS(n_components=self.n_components_cmds, n_elbows=self.n_elbows_cmds, dissimilarity=self.dissimilarity).fit_transform(dist_matrix)
        
        if return_dict:
            return {key: cmds_embds[i] for i, key in enumerate(data.keys())}
        else:
            return cmds_embds


def make_embedding_dict(df):
    model_names  = df.model.unique()
    instance_ids = df.instance_id.unique()
    
    embedding_dict = {}
    for model_name in model_names:
        sub = df[df.model == model_name]
        assert (sub.instance_id.values == instance_ids).all(), f'instance_ids are not the same for model {model_name}'
        embedding_dict[model_name] = np.vstack(sub.embedding.values)
    
    embedding_dict = {k:v[:,None] for k,v in embedding_dict.items()}
    
    return embedding_dict


def dkps_df(df, **kwargs):
    embedding_dict = make_embedding_dict(df)
    return DataKernelPerspectiveSpace(**kwargs).fit_transform(embedding_dict, return_dict=True)


# --

def onehot_embedding(df, dataset):
    if dataset == 'med_qa':
        lookup = {'A' : 0, 'B' : 1, 'C' : 2, 'D' : 3}
        
        embeddings = np.zeros((len(df), 4))
        for i, xx in enumerate(df.response.values):
            if xx in lookup:
                embeddings[i, lookup[xx]] = 1
        
        df['embedding'] = embeddings.tolist()
    
    elif 'legalbench' in dataset:
        # slightly different - bad values get mapped to 0
        n_levels   = len(df.response.unique())
        embeddings = np.zeros((len(df), n_levels))
        for i, xx in enumerate(df.response.values):
            embeddings[i, xx] = 1

        df['embedding'] = embeddings.tolist()
    else:
        raise ValueError(f'{dataset} is not supported for onehot embeddings')
    
    return df


np.random.seed(0)
dataset = 'math:subject=algebra'
# dataset = file.split('-score-res')[0]
# dataset = dataset.split('-meteor-res')[0]
score_col = 'score-res'
outdir = '/home/user/helivan-project-generation/ep-dkps-results/results'
tsv_path = Path(outdir) / f'{dataset}-{score_col}.tsv'
# plot_dir = Path('plots') / dataset.replace(':', '-')


rprint('[blue]loading data ...[/blue]')

df = pd.read_csv(tsv_path, sep='\t')
df = df[df.dataset == dataset]

if sample:
    rng           = np.random.default_rng(seed)
    uinstance_ids = df.instance_id.unique()
    keep          = rng.choice(uinstance_ids, int(len(uinstance_ids) * sample), replace=False)
    df            = df[df.instance_id.isin(keep)]

df = df.sort_values(['model', 'instance_id']).reset_index(drop=True)

if score_col != 'score':
    print(f'{score_col} -> score')
    df['score'] = df[score_col]

# BAD_MODELS  = open('bad_models.txt').read().splitlines()

# --
# QC

print(f'{len(df.response.unique())} / {df.shape[0]} responses are unique')
instance_ids = df.groupby('model').instance_id.apply(list)
assert all([instance_ids.iloc[0] == instance_ids.iloc[i] for i in range(len(instance_ids))]), 'instance_ids are not the same for each model'

# --
# Get embeddings

if embed_model == 'onehot':
    df = onehot_embedding(df, dataset=dataset)
else:
    df['embedding'] = list(embed_api(
        provider   = embed_provider, 
        input_strs = [str(xx) for xx in df.response.values],
        model      = embed_model
    ))

model2score = df.groupby('model').score.mean().to_dict()

# --
# Plot 1.a - whole DKPS

P = dkps_df(df, n_components_cmds=2)
P = np.vstack([P[m] for m in model2score.keys()])

_ = plt.scatter(P[:, 0], P[:,1], c=model2score.values(), cmap='viridis')
_ = plt.xticks([])
_ = plt.yticks([])
_ = plt.xlabel('DKPS-0')
_ = plt.ylabel('DKPS-1')
_ = plt.grid('both', alpha=0.25, c='gray')
_ = plt.title(f'DKPS - {dataset}')
_ = plt.colorbar(label='Score')
_ = plt.savefig(plot_dir / 'dkps.png')
_ = plt.close()

# # --
# # Plot 1.b - whole DKPS, bad models removed

# thresh     = np.percentile(list(model2score.values()), 10)
# BAD_MODELS = [m for m in model2score.keys() if model2score[m] < thresh]

# _model2score = {m:model2score[m] for m in model2score.keys() if m not in BAD_MODELS}
# P = dkps_df(df[~df.model.isin(BAD_MODELS)], n_components_cmds=2)
# P = np.vstack([P[m] for m in _model2score.keys()])

# _ = plt.scatter(P[:, 0], P[:,1], c=_model2score.values(), cmap='viridis')
# _ = plt.xticks([])
# _ = plt.yticks([])
# _ = plt.xlabel('DKPS-0')
# _ = plt.ylabel('DKPS-1')
# _ = plt.grid('both', alpha=0.25, c='gray')
# _ = plt.title(f'DKPS - {dataset}')
# _ = plt.colorbar(label='Score')
# _ = plt.savefig(plot_dir / 'dkps-excl.png')
# _ = plt.close()

# --
# Plot 2 - grid, varying number of instances and models

uinstance_ids = np.random.choice(df.instance_id.unique(), size=min(50, len(df.instance_id.unique())), replace=False)
umodels       = np.random.permutation(df.model.unique())

fig, axes = plt.subplots(2, 3, figsize=(12, 10))

Ps = {}
for c, n_instances in enumerate([2, 10, len(uinstance_ids)]):
    _instance_ids = uinstance_ids[:n_instances]
    for r, n_models in enumerate([20, len(umodels)]):
        _models      = umodels[:n_models]
        _model2score = {m:model2score[m] for m in _models}
        
        df_sub = df[df.instance_id.isin(_instance_ids)]
        df_sub = df_sub[df_sub.model.isin(_models)]
        P_sub  = dkps_df(df_sub, n_components_cmds=2)
        P_sub  = np.vstack([P_sub[m] for m in _model2score.keys()])

        ax = axes[r, c]

        scatter = ax.scatter(P_sub[:, 0], P_sub[:, 1], c=list(_model2score.values()), cmap='viridis')
        _ = ax.set_xticks([])
        _ = ax.set_yticks([])
        _ = ax.set_xlabel('DKPS-0')
        _ = ax.set_ylabel('DKPS-1')
        _ = ax.grid('both', alpha=0.25, c='gray')
        _ = ax.set_title(f'n_models={n_models} | n_instances={n_instances}')

_ = plt.suptitle(f'DKPS - {dataset}')
_ = plt.tight_layout()

# Add colorbar to the figure
cbar = plt.colorbar(scatter, ax=axes, shrink=0.8, aspect=20)
cbar.set_label('Score')

_ = plt.savefig(plot_dir / 'dkps-grid.png')
_ = plt.close()


AttributeError: 'DataFrame' object has no attribute 'dataset'

In [10]:
dataset

'math:subject=algebra'