In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from IPython.display import display
import numpy as np

# %matplotlib inline

from pathlib import Path


In [3]:
def collect_one(run, csv_folder, test_config='general'):
    """Collect the last logged metrics row of every matching CSV of a single run.

    Hyper-parameters are recovered from the ``key=value`` components of the run
    path and attached to every returned row as constant columns.

    Args:
        run: ``pathlib.Path`` of the run directory. CSV files are searched
            recursively under ``run / 'inference' / csv_folder``.
        csv_folder: Name of the folder inside ``inference`` to scan.
        test_config: Substring that the name of the CSV's parent folder must
            contain for the file to be considered.

    Returns:
        A DataFrame with one row per matching CSV, without the ``epoch`` and
        ``step`` columns, plus ``prompt_ensemble``, ``original_names``,
        ``tok_position_inference`` and one column per path parameter. When no
        CSV matches, a message is printed and an empty DataFrame holding only
        the parameter columns is returned.
    """
    # Split on the first '=' only, so values that themselves contain '=' are kept whole.
    parameters = dict(
        part.split('=', 1) for part in run.as_posix().split('/') if '=' in part
    )

    frames = []
    for csv_path in (run / 'inference' / csv_folder).rglob('*.csv'):
        test_config_name = csv_path.parent.stem
        if test_config not in test_config_name:
            continue
        # Keep only the last line. This is because until this commit, we were appending
        # the logs and not overriding them, so the last line is the one we want.
        # .copy() detaches the slice so the column assignments below are unambiguous.
        df = pd.read_csv(csv_path).tail(1).copy()
        df['prompt_ensemble'] = 1 if 'prompt_ensemble' in test_config_name else 0
        df['original_names'] = 1 if 'original_names' in test_config_name else 0
        if 'tok_beginning' in test_config_name:
            df['tok_position_inference'] = 'beginning'
        elif 'tok_in_place' in test_config_name:
            df['tok_position_inference'] = 'in_place'
        else:
            df['tok_position_inference'] = None
        if run.name == 'baseline':
            df['model'] = 'clip_original'
        frames.append(df)

    if not frames:
        # pd.concat([]) raises, so the empty case has to be handled before concatenating.
        print(f'Pred folder is empty: {csv_folder}')
        return pd.DataFrame(columns=list(parameters))

    # errors='ignore': CSVs that were logged without epoch/step are still usable.
    data = pd.concat(frames).drop(columns=['epoch', 'step'], errors='ignore')

    # Path parameters are written last, so they override same-named CSV columns.
    for k, v in parameters.items():
        data[k] = v

    return data

def collect_all(root, csv_folder, test_config='general'):
    """Collect the metrics of every run under ``root`` that has a ``csv_folder``.

    Every path named ``csv_folder`` found recursively under ``root`` is handed to
    ``collect_one``; the run directory is taken to be two levels above it.

    Args:
        root: Directory (str or Path) that is searched recursively.
        csv_folder: Name of the folder to look for.
        test_config: Forwarded unchanged to ``collect_one``.

    Returns:
        A single DataFrame with a fresh integer index holding the rows of all
        runs. When nothing is found, a message is printed and an empty
        DataFrame is returned instead of failing inside ``pd.concat``.
    """
    root = Path(root)
    frames = [
        collect_one(folder.parents[1], folder.name, test_config=test_config)
        for folder in root.rglob(csv_folder)
    ]
    # Empty per-run frames carry no rows; leave them out of the concatenation.
    frames = [frame for frame in frames if not frame.empty]
    if not frames:
        print(f'No metrics found under {root} for folder: {csv_folder}')
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)

default_fields_dict = {
    'r1': lambda x: u"{:.1f}".format(x),
    'r5': lambda x: u"{:.1f}".format(x),
    'r10': lambda x: u"{:.1f}".format(x),
    'meanr': lambda x: u"{:.1f}".format(x),
    'medr': lambda x: int(x),
    'spice': lambda x: u"{:.3f}".format(x),
    'spacy': lambda x: u"{:.3f}".format(x),
}
def render_to_latex(metrics, rename_func=default_fields_dict, **latex_kwargs):
    m = metrics.copy()
    # renaming
    for col, lambda_fn in rename_func.items():
        m[col] = m[col].apply(lambda_fn)
    # m = m.applymap(lambda x: u"{:.2f}".format(x))
    ltex = m.style.to_latex(
        **latex_kwargs
    )
    return ltex

In [9]:
# Compute metrics for each detected run

def summarize_metrics(
        metrics,
        dataset=None,
        model=None,
        translator=None,
        tok_position=None,
        tok_position_inference=None,
        training_setup=None,
        loss=None,
        learning_rate=None,
        finetuning=None,
        drop_i2t=True,
        decimal_places=3):
    """Filter a metrics table and index it by its identifying columns.

    Args:
        metrics: DataFrame holding the identifying columns listed in ``id_vars``
            below plus the metric columns. It is not modified.
        dataset: If given, keep only rows whose ``data`` column equals this
            value; the ``data`` column is then dropped.
        model, translator, tok_position, tok_position_inference, training_setup,
        loss, learning_rate, finetuning: Optional lists of allowed values for
            the columns ``model``, ``translator``, ``tok_position``,
            ``tok_position_inference``, ``training-setup``, ``loss``, ``lr`` and
            ``finetuning``. ``None`` disables the filter. When a list holds
            exactly one value, the (now constant) column is dropped as well.
        drop_i2t: If true, drop every column whose name contains ``i2t``.
        decimal_places: Number of decimals the numeric columns are rounded to.

    Returns:
        A new DataFrame indexed by the identifying columns that were not dropped.
    """
    id_vars = ['data', 'model', 'translator', 'tok_position', 'training-setup', 'loss', 'lr',
               'finetuning', 'prompt_ensemble', 'original_names', 'tok_position_inference']

    if dataset is not None:
        metrics = metrics[metrics['data'] == dataset].drop(columns='data')
        # The column is gone, so it must not be requested as an index level later on.
        id_vars.remove('data')

    # TODO: as of now, there is only one split seed.
    # In the future, we would have to average among different splits
    # metrics.drop(columns="split_seed", inplace=True)

    # (column, allowed values) pairs; every filter uses its OWN list of values.
    filters = [
        ('translator', translator),
        ('model', model),
        ('lr', learning_rate),
        ('finetuning', finetuning),
        ('tok_position', tok_position),
        ('training-setup', training_setup),
        ('loss', loss),
        ('tok_position_inference', tok_position_inference),
    ]
    for column, allowed in filters:
        if allowed is None:
            continue
        metrics = metrics[metrics[column].isin(allowed)]
        if len(allowed) == 1:
            # A single allowed value makes the column constant: drop it from table and index.
            metrics = metrics.drop(columns=column)
            id_vars.remove(column)

    if drop_i2t:
        # remove the columns containing i2t in their name
        metrics = metrics.loc[:, ~metrics.columns.str.contains('i2t')]

    # round to given decimal places
    metrics = metrics.round(decimal_places)

    return metrics.set_index(id_vars)

In [31]:
# rename content of the table
def rename_fn(v):
    """Translate selected table values to their display label; other values pass through."""
    display_names = {
        'ContrastiveFixed': 'Triplet',
        'InfoNCELoss': 'InfoNCE',
    }
    return display_names.get(v, v)

def render_to_latex(metrics, rename_func=default_fields_dict, **latex_kwargs):
    """Render ``metrics`` as LaTeX with the largest value of every column in bold.

    Args:
        metrics: DataFrame to render. It is copied, never modified.
        rename_func: Not used by this implementation; kept so that existing
            calls keep working.
        **latex_kwargs: Forwarded to ``Styler.to_latex``. ``convert_css``
            defaults to ``True`` here because the highlight is expressed as CSS,
            which ``to_latex`` only translates when that flag is set; pass
            ``convert_css=False`` explicitly to opt out.

    Returns:
        The LaTeX source of the table as a string.
    """
    m = metrics.copy()

    def highlight_best(data):
        """Return a frame of CSS strings marking the maximum of each column."""
        attr = 'font-weight: bold'
        result = pd.DataFrame('', index=data.index, columns=data.columns)
        for col in data.columns:
            values = data[col]
            # idxmax is undefined for non-numeric or all-NaN columns: leave them unstyled.
            if not pd.api.types.is_numeric_dtype(values) or values.isna().all():
                continue
            result.loc[values.idxmax(), col] = attr
        return result

    latex_kwargs.setdefault('convert_css', True)
    styled_df = m.style.apply(highlight_best, axis=None)
    return styled_df.format(precision=2).to_latex(**latex_kwargs)

# Results - General retrieval (best contrastive sum checkpoint)

In [38]:
# Gather the metrics of every run found under ROOT
ROOT = "runs"

metrics = collect_all(ROOT, 'best-contrastive-sum')
metrics_baselines = collect_all(ROOT, 'original_checkpoint')    # baseline model

metrics_concat = pd.concat([metrics, metrics_baselines], axis=0, join='outer')

# Select the runs of interest, then shape the table in one chain
metrics = (
    summarize_metrics(
        metrics_concat,
        training_setup=["with_entities", np.nan],
        finetuning=["disabled", "shallow-vpt-5", np.nan],
        tok_position=["tok_in_place_multi_prompts", np.nan],
        tok_position_inference=["in_place", None],
    )
    # the contrastive_sum column is not shown
    .drop(columns="contrastive_sum")
    # take data, translator, tok_position, training-setup, loss, lr and
    # tok_position_inference out of the row multi index
    .droplevel(
        ['data', 'translator', 'tok_position', 'training-setup', 'loss', 'lr', 'tok_position_inference'],
        axis=0,
    )
    # column order: t2i-r@1 t2i-r@5 t2i-r@10 t2i-r@50 contrastive_t2i_sum
    .loc[:, ["t2i-r@1", "t2i-r@5", "t2i-r@10", "t2i-r@50", "contrastive_t2i_sum"]]
    # express the values as percentages
    .mul(100)
)

# latex = render_to_latex(
#     metrics, 
#     caption="General Retrieval",
#     clines="skip-last;data",
#     hrules=True,
#     column_format="llllccccc",
#     convert_css=True
# )

# print(latex)

metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,t2i-r@1,t2i-r@5,t2i-r@10,t2i-r@50,contrastive_t2i_sum
model,finetuning,prompt_ensemble,original_names,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
idclip,disabled,0,0,39.9,70.1,81.1,96.5,287.5
idclip,disabled,1,1,38.6,67.5,78.0,95.0,279.1
idclip,disabled,1,0,40.2,70.4,81.4,96.7,288.6
idclip,shallow-vpt-5,0,0,42.9,70.0,79.5,95.7,288.0
idclip,shallow-vpt-5,1,1,38.8,65.2,75.8,93.7,273.5
idclip,shallow-vpt-5,1,0,42.8,69.8,79.7,95.7,288.0
baseline,,0,1,23.1,50.1,59.5,79.6,212.4
baseline,,0,0,10.5,48.4,59.2,80.2,198.4


# Results - Entities retrieval (best contrastive sum checkpoint)

In [37]:
# Gather the entity-retrieval metrics of every run found under ROOT
ROOT = "runs"

metrics = collect_all(ROOT, 'best-contrastive-sum', test_config='entities')
metrics_baselines = collect_all(ROOT, 'original_checkpoint', test_config='entities')    # baseline model

metrics_concat = pd.concat([metrics, metrics_baselines], axis=0, join='outer')

# Select the runs of interest, then shape the table in one chain
metrics = (
    summarize_metrics(
        metrics_concat,
        training_setup=["with_entities", np.nan],
        finetuning=["disabled", "shallow-vpt-5", np.nan],
        tok_position=["tok_in_place_multi_prompts", np.nan],
        tok_position_inference=["in_place", None],
    )
    # discard every column whose name contains "entity-r", plus entities_sum
    .loc[:, lambda frame: ~frame.columns.str.contains('entity-r')]
    .drop(columns="entities_sum")
    # take data, translator, tok_position, training-setup, loss, lr and
    # tok_position_inference out of the row multi index
    .droplevel(
        ['data', 'translator', 'tok_position', 'training-setup', 'loss', 'lr', 'tok_position_inference'],
        axis=0,
    )
    # column order: entity-kmin-r@1 entity-kmin-r@5 entity-kmin-r@10 entity-kmin-r@50 entities_kmin_sum mAP
    .loc[:, ["entity-kmin-r@1", "entity-kmin-r@5", "entity-kmin-r@10", "entity-kmin-r@50", "entities_kmin_sum", "mAP"]]
    # express the values as percentages
    .mul(100)
)

metrics

# latex = render_to_latex(
#     metrics, 
#     caption="Entities Retrieval",
#     clines="skip-last;data",
#     hrules=True,
#     column_format="llllcccccc",
#     convert_css=True
# )

# print(latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,entity-kmin-r@1,entity-kmin-r@5,entity-kmin-r@10,entity-kmin-r@50,entities_kmin_sum,mAP
model,finetuning,prompt_ensemble,original_names,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
idclip,disabled,0,0,15.1,8.9,12.3,27.6,63.9,2.8
idclip,disabled,1,0,15.9,9.1,12.4,28.0,65.4,2.9
idclip,disabled,1,1,22.9,11.9,13.3,29.5,77.6,3.8
idclip,shallow-vpt-5,0,0,20.4,11.7,14.7,30.7,77.4,3.8
idclip,shallow-vpt-5,1,0,19.3,11.8,14.6,31.7,77.3,3.8
idclip,shallow-vpt-5,1,1,22.3,13.5,16.4,31.7,83.9,4.3
baseline,,0,1,12.8,6.0,8.4,17.6,44.8,2.1
