In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from IPython.display import display
import numpy as np

# %matplotlib inline

from pathlib import Path


In [73]:
def collect_one(run, csv_folder, test_config='general'):
    """Load the final metrics row of every matching inference CSV of one run.

    The run's hyper-parameters are parsed from the ``key=value`` components of
    ``run``'s path and broadcast to every row as extra columns.

    Args:
        run: ``pathlib.Path`` of the run directory; CSVs are searched
            recursively under ``run / 'inference' / csv_folder``.
        csv_folder: name of the checkpoint sub-folder
            (e.g. ``'best-contrastive-sum'``).
        test_config: substring the CSV's parent-folder name must contain for
            the CSV to be kept.

    Returns:
        A DataFrame with one row per matching CSV (its last line), the
        ``prompt_ensemble`` / ``original_names`` / ``tok_position_inference``
        flags decoded from the parent-folder name, and one column per path
        parameter. When no CSV matches, a message is printed and an empty
        DataFrame (parameter columns only) is returned instead of raising.
    """
    # "key=value" path components -> {"key": "value"}; split on the first "="
    # only, so values that themselves contain "=" are kept intact.
    parameters = dict(p.split('=', 1) for p in run.as_posix().split('/') if '=' in p)

    frames = []
    # Sorted so the row order does not depend on the file system.
    for csv_path in sorted((run / 'inference' / csv_folder).rglob('*.csv')):
        test_config_name = csv_path.parent.stem
        if test_config not in test_config_name:
            continue
        # Older runs appended to the log instead of overwriting it, so only the
        # last line holds the final metrics. ``.copy()`` detaches the slice so
        # the column assignments below don't act on a view.
        df = pd.read_csv(csv_path).tail(1).copy()
        df['prompt_ensemble'] = int('prompt_ensemble' in test_config_name)
        df['original_names'] = int('original_names' in test_config_name)
        if 'tok_beginning' in test_config_name:
            df['tok_position_inference'] = 'beginning'
        elif 'tok_in_place' in test_config_name:
            df['tok_position_inference'] = 'in_place'
        else:
            df['tok_position_inference'] = None
        if run.name == 'baseline':
            df['model'] = 'clip_original'
        frames.append(df)

    # pd.concat([]) raises ValueError, so fall back to an empty frame.
    data = pd.concat(frames) if frames else pd.DataFrame()
    # Logger bookkeeping columns, not metrics; tolerate logs without them.
    data = data.drop(columns=['epoch', 'step'], errors='ignore')

    if data.empty:
        print(f'Pred folder is empty: {csv_folder}')

    for k, v in parameters.items():
        data[k] = v

    return data

def collect_all(root, csv_folder, test_config='general'):
    """Collect metrics from every run below ``root`` that has a ``csv_folder``.

    Each directory named ``csv_folder`` anywhere under ``root`` is assumed to
    sit at ``<run>/inference/<csv_folder>``. Its grand-parent is therefore
    treated as the run directory and passed to ``collect_one``.

    Args:
        root: root directory (str or Path) to search recursively.
        csv_folder: checkpoint folder name to look for.
        test_config: forwarded to ``collect_one`` to select test configs.

    Returns:
        One DataFrame with a fresh ``RangeIndex``. It is empty when no matching
        folder exists, rather than raising from ``pd.concat([])``.
    """
    root = Path(root)
    # Only directories can hold per-config CSV folders; a stray file with the
    # same name would make collect_one search a non-existent tree. Sorting
    # gives a deterministic row order across file systems.
    checkpoint_dirs = sorted(p for p in root.rglob(csv_folder) if p.is_dir())
    if not checkpoint_dirs:
        return pd.DataFrame()
    metrics = [
        collect_one(ckpt_dir.parents[1], ckpt_dir.name, test_config=test_config)
        for ckpt_dir in checkpoint_dirs
    ]
    return pd.concat(metrics, ignore_index=True)

default_fields_dict = {
    'r1': lambda x: u"{:.1f}".format(x),
    'r5': lambda x: u"{:.1f}".format(x),
    'r10': lambda x: u"{:.1f}".format(x),
    'meanr': lambda x: u"{:.1f}".format(x),
    'medr': lambda x: int(x),
    'spice': lambda x: u"{:.3f}".format(x),
    'spacy': lambda x: u"{:.3f}".format(x),
}
def render_to_latex(metrics, rename_func=default_fields_dict, **latex_kwargs):
    m = metrics.copy()
    # renaming
    for col, lambda_fn in rename_func.items():
        m[col] = m[col].apply(lambda_fn)
    # m = m.applymap(lambda x: u"{:.2f}".format(x))
    ltex = m.style.to_latex(
        **latex_kwargs
    )
    return ltex

In [71]:
# Filter the collected metrics to the configurations of interest and index
# them by their remaining hyper-parameters.

def summarize_metrics(
        metrics,
        dataset=None,
        model=None,
        translator=None,
        tok_position=None,
        training_setup=None,
        loss=None,
        learning_rate=None,
        finetuning=None,
        drop_i2t=True,
        decimal_places=3):
    """Filter collected metrics and index them by hyper-parameter columns.

    Every filter left as ``None`` is ignored.

    - ``dataset`` is a single value compared with the ``data`` column. That
      column is then constant, so it is dropped.
    - The other filters are list-likes of allowed values, matched with
      ``Series.isin``. ``np.nan`` in a list keeps rows whose column is missing
      (as for baseline runs).
    - A list filter with exactly one value makes its column constant, so that
      column is dropped and removed from the index.

    Args:
        metrics: DataFrame with the hyper-parameter columns ``data``,
            ``model``, ``translator``, ``tok_position``, ``training-setup``,
            ``loss``, ``lr``, ``finetuning`` plus metric columns.
        dataset: value of ``data`` to keep.
        model, translator, tok_position, training_setup, loss, learning_rate,
            finetuning: allowed values for the respective column.
        drop_i2t: drop every column whose name contains ``'i2t'``.
        decimal_places: rounding applied to numeric columns.

    Returns:
        A new DataFrame indexed by the hyper-parameter columns that were not
        dropped. The input frame is not modified.
    """
    # TODO: as of now, there is only one split seed.
    # In the future, we would have to average among different splits.

    id_vars = ['data', 'model', 'translator', 'tok_position', 'training-setup', 'loss', 'lr', 'finetuning']

    if dataset is not None:
        metrics = metrics[metrics['data'] == dataset].drop(columns='data')
        # Also remove it from the index, otherwise set_index raises KeyError.
        id_vars.remove('data')

    # (allowed values, column) for every list filter. Each filter keeps only the
    # rows whose value for that column is among the allowed ones.
    list_filters = [
        (translator, 'translator'),
        (model, 'model'),
        (learning_rate, 'lr'),
        (finetuning, 'finetuning'),
        (tok_position, 'tok_position'),
        (training_setup, 'training-setup'),
        (loss, 'loss'),
    ]
    for allowed, column in list_filters:
        if allowed is None:
            continue
        metrics = metrics[metrics[column].isin(allowed)]
        if len(allowed) == 1:
            metrics = metrics.drop(columns=column)
            id_vars.remove(column)

    if drop_i2t:
        # Columns are flat string labels; drop those mentioning "i2t".
        metrics = metrics.loc[:, ~metrics.columns.str.contains('i2t')]

    return metrics.round(decimal_places).set_index(id_vars)

In [4]:
# Translate raw table values into display labels for the result tables.
def rename_fn(v):
    """Return the display label for ``v``.

    Only the two entries listed below are translated; any other value is
    returned unchanged.
    """
    return {
        'ContrastiveFixed': 'Triplet',
        'InfoNCELoss': 'InfoNCE',
    }.get(v, v)

# Results - General retrieval (best contrastive sum checkpoint)

In [74]:
# Collect every run's metrics for the general-retrieval test configs.
ROOT = "runs"

# Trained runs use their best contrastive-sum checkpoint; the baseline model
# is evaluated from its original checkpoint.
metrics = collect_all(ROOT, csv_folder='best-contrastive-sum', test_config='general')
metrics_baselines = collect_all(ROOT, csv_folder='original_checkpoint', test_config='general')

# Stack both; a column missing on one side is filled with NaN there.
metrics_concat = pd.concat([metrics, metrics_baselines], axis=0, join='outer')

# np.nan in each list keeps the baseline rows, whose hyper-parameter columns
# are empty.
metrics = summarize_metrics(
    metrics_concat,
    tok_position=["tok_in_place_multi_prompts", np.nan],
    training_setup=["with_entities", np.nan],
    finetuning=["disabled", "shallow-vpt-5", np.nan],
)

metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,t2i-r@1,t2i-r@50,contrastive_sum,t2i-r@5,contrastive_t2i_sum,t2i-r@10,prompt_ensemble,original_names,tok_position_inference
data,model,translator,tok_position,training-setup,loss,lr,finetuning,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.399,0.965,5.659,0.701,2.875,0.811,0,0,in_place
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.386,0.95,5.437,0.675,2.791,0.78,1,1,in_place
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.386,0.951,5.458,0.68,2.801,0.784,1,1,beginning
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.398,0.967,5.698,0.708,2.888,0.815,0,0,beginning
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.402,0.967,5.678,0.704,2.886,0.814,1,0,in_place
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.398,0.967,5.686,0.707,2.886,0.813,1,0,beginning
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.429,0.957,5.591,0.7,2.88,0.795,0,0,in_place
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.388,0.937,5.301,0.652,2.735,0.758,1,1,in_place
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.391,0.937,5.321,0.655,2.743,0.76,1,1,beginning
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.431,0.959,5.63,0.704,2.894,0.8,0,0,beginning


# Results - Entities retrieval (best contrastive sum checkpoint)

In [76]:
# Collect every run's metrics, keeping only test configs whose folder name
# contains "entities".
ROOT = "runs"

# Trained runs use their best contrastive-sum checkpoint; the baseline model
# is evaluated from its original checkpoint.
metrics = collect_all(ROOT, csv_folder='best-contrastive-sum', test_config='entities')
metrics_baselines = collect_all(ROOT, csv_folder='original_checkpoint', test_config='entities')

# Stack both; a column missing on one side is filled with NaN there.
metrics_concat = pd.concat([metrics, metrics_baselines], axis=0, join='outer')

# np.nan in each list keeps the baseline rows, whose hyper-parameter columns
# are empty.
metrics = summarize_metrics(
    metrics_concat,
    tok_position=["tok_in_place_multi_prompts", np.nan],
    training_setup=["with_entities", np.nan],
    finetuning=["disabled", "shallow-vpt-5", np.nan],
)

metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,entity-r@10,entity-r@1,entity-r@50,entity-kmin-r@10,mAP,entities_kmin_sum,entity-kmin-r@5,entity-kmin-r@1,entity-r@5,entities_sum,entity-kmin-r@50,prompt_ensemble,original_names,tok_position_inference
data,model,translator,tok_position,training-setup,loss,lr,finetuning,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.123,0.032,0.276,0.123,0.028,0.639,0.089,0.151,0.076,1.174,0.276,0,0,
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.124,0.036,0.28,0.124,0.029,0.654,0.091,0.159,0.079,1.202,0.28,1,0,
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,disabled,0.133,0.051,0.295,0.133,0.038,0.776,0.119,0.229,0.103,1.398,0.295,1,1,
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.147,0.042,0.307,0.147,0.038,0.774,0.117,0.204,0.102,1.409,0.307,0,0,
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.146,0.041,0.317,0.146,0.038,0.773,0.118,0.193,0.103,1.417,0.317,1,0,
coco_faceswap_5_entities,idclip,mlp-1-layer,tok_in_place_multi_prompts,with_entities,info-nce,5e-05,shallow-vpt-5,0.164,0.046,0.317,0.164,0.043,0.839,0.135,0.223,0.118,1.526,0.317,1,1,
coco_faceswap_5_entities,baseline,,,,,,,0.084,0.024,0.176,0.084,0.021,0.448,0.06,0.128,0.052,0.805,0.176,0,1,
