In [1]:
import wandb
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import itertools

In [2]:
# Which wandb project to analyse: one of "lr-pi", "gmm-pi", "ar-pi", "og-pi".
project = "og-pi"

# True: reuse the pickled export if it exists on disk; False: always re-export.
use_cache = True

# Aggregation handed to DataFrame.agg when runs are grouped: 'max', 'min' or 'mean'.
agg_by = "mean"

# One row per project:
#   (config key of the data setting the runs are compared across,
#    display prefix for that setting,
#    name of the logged metric,
#    display name of that metric)
_project_specs = {
    "lr-pi": ("data.curriculum.dims.end", "d=", "loss", "mean squared error"),
    "gmm-pi": ("data.num_classes", "k=", "eval-eval.acc", "accuracy"),
    "ar-pi": ("data.vocab_size", "|V|=", "eval-eval.acc", "accuracy"),
    "og-pi": ("data.seq_config.p_bursty", "P(bursty)=", "eval-eval.acc", "accuracy"),
}

# Per-field lookup tables derived from the spec rows above.
proj2diff = {name: spec[0] for name, spec in _project_specs.items()}
proj2nicediff = {name: spec[1] for name, spec in _project_specs.items()}
proj2metric = {name: spec[2] for name, spec in _project_specs.items()}
proj2nicemetric = {name: spec[3] for name, spec in _project_specs.items()}

# Export settings passed to export_wandb_project.
max_runs = 10000
states = ("finished",)
entity = "iceberg"

# History (per-step) columns to download for each run.
hist_cols = [proj2metric[project]]

# Run-config entries attached to every history row.
config_cols = [
    "model",
    proj2diff[project],
    "train.merge_type",
    "model.twrap_kwargs.use_abs_pos_emb",
]

In [3]:
def get_wandb_runs(entity, project):
    """Return the runs of ``entity/project`` as provided by ``wandb.Api().runs``."""
    return wandb.Api().runs(f"{entity}/{project}")

def build_fp(**kwargs):
    """Build a pickle file name from keyword arguments.

    Each argument becomes ``key=value``; the pieces are joined with ``_`` in
    the order given and ``.pkl`` is appended.
    """
    pairs = (f"{k}={v}" for k, v in kwargs.items())
    stem = "_".join(pairs)
    return stem + ".pkl"

def export_wandb_project(
    entity,
    project,
    config_cols,
    hist_cols,
    max_runs=None,
    states=("finished",),
    use_cache=False,
):
    """Export the logged history of a wandb project's runs as one DataFrame.

    For every accepted run, the history returned by ``run.history(keys=hist_cols)``
    is extended with one constant column per entry of ``config_cols`` (looked up
    with ``run.config.get``) plus a ``run_id`` column. The per-run frames are
    concatenated, pickled to the cache file and returned.

    Args:
        entity: wandb entity that owns the project.
        project: wandb project name.
        config_cols: run-config keys to attach to every history row.
        hist_cols: history keys to request for each run.
        max_runs: stop once this many runs have been exported (None = no limit).
        states: run states to keep. A single state may be given as a bare string.
        use_cache: if True, return the pickled export when the cache file exists.

    Returns:
        A DataFrame with the history rows of all exported runs.

    Raises:
        ValueError: if no run matched, so there is nothing to concatenate.

    Note:
        The cache file name is built from ``entity`` and ``project`` only, so a
        cached export is returned as-is even when ``config_cols``, ``hist_cols``,
        ``max_runs`` or ``states`` differ from the call that created it.
    """
    # A bare string would make `run.state not in states` a substring test
    # (e.g. "finish" in "finished"), so wrap it into a one-element tuple.
    if isinstance(states, str):
        states = (states,)

    fp = build_fp(entity=entity, project=project)
    if use_cache:
        try:
            # NOTE: unpickling can execute arbitrary code; only load cache files
            # that were written by this function.
            return pd.read_pickle(fp)
        except FileNotFoundError:
            print(f"Cache file not found: {fp}. Exporting from wandb.")

    runs = get_wandb_runs(entity, project)
    run_data = []

    for run in tqdm(runs, desc="Exporting run data"):
        if (max_runs is not None) and (len(run_data) >= max_runs):
            break

        if run.state not in states:
            continue

        # Constant per-run columns: the requested config entries and the run id.
        config = {k: run.config.get(k) for k in config_cols}
        config["run_id"] = run.id
        hist = run.history(keys=hist_cols)

        # Broadcast the per-run values onto every history row.
        run_data.append(hist.assign(**config))

    if not run_data:
        raise ValueError(
            f"No runs in {entity}/{project} matched states={tuple(states)!r}; "
            "nothing to export."
        )

    # Combine all run data into a single DataFrame and refresh the cache.
    all_run_data = pd.concat(run_data, ignore_index=True)
    all_run_data.to_pickle(fp)
    return all_run_data


# Load the per-step metric history of every selected run (from the local cache
# when use_cache is set and the cache file exists) and preview the first rows.
df = export_wandb_project(
    entity,
    project,
    config_cols,
    hist_cols,
    max_runs=max_runs,
    states=states,
    use_cache=use_cache,
)
df.head()

Unnamed: 0,_step,eval-eval.acc,model,data.seq_config.p_bursty,train.merge_type,model.twrap_kwargs.use_abs_pos_emb,run_id
0,0,0.509,x-decoder,0.9,sum,False,u1fqj5hs
1,5000,0.937,x-decoder,0.9,sum,False,u1fqj5hs
2,10000,0.904,x-decoder,0.9,sum,False,u1fqj5hs
3,15000,0.776,x-decoder,0.9,sum,False,u1fqj5hs
4,20000,0.618,x-decoder,0.9,sum,False,u1fqj5hs


In [4]:
# Reduce each run to a single row: order by run and step, then keep the
# highest-_step row of every run_id.
df = (
    df.sort_values(["run_id", "_step"])
    .groupby("run_id")
    .tail(1)
)
df

Unnamed: 0,_step,eval-eval.acc,model,data.seq_config.p_bursty,train.merge_type,model.twrap_kwargs.use_abs_pos_emb,run_id
1022,50000,0.947,x-encoder,1.0,concat,True,02icv4jk
494,50000,0.792,x-decoder,0.9,sum,True,0i8nbr77
1099,50000,0.930,x-decoder,1.0,sum,False,0ianx7i2
835,50000,0.928,x-encoder,1.0,sum,True,0icr7wex
516,50000,0.581,x-decoder,0.9,sum,True,0sncpsjk
...,...,...,...,...,...,...,...
747,50000,0.932,x-encoder,1.0,sum,False,yrf4rutf
538,50000,0.599,x-decoder,0.9,sum,True,yx7gmqn7
791,50000,0.903,x-encoder,1.0,sum,False,yxg6clxp
461,50000,0.942,x-decoder,0.9,concat,True,zkw34mv7


In [5]:
# # summarize with a pivot table
# df_pivot = df.pivot_table(
#     index=config_cols,
#     values=proj2metric[project],
#     aggfunc=agg_by,
# )
# df_pivot

In [6]:
# Shorten the model labels: 'x-decoder' -> 'decoder' and 'x-encoder' -> 'encoder'.
# `assign` builds a new frame instead of writing a column into `df`, which at this
# point is a row subset produced by groupby().tail(); column assignment on such a
# subset can raise SettingWithCopyWarning. regex=False keeps the match literal
# whatever the installed pandas version uses as its default.
df = df.assign(
    model=df["model"]
    .str.replace("x-decoder", "decoder", regex=False)
    .str.replace("x-encoder", "encoder", regex=False)
)

In [7]:
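# # add subtotals to the pivot table
# # NOTE: disabled. Needs `df_pivot` from the (also commented-out) pivot-table cell, and
# # `DataFrame.append` was removed in pandas 2.0 -- use `pd.concat` if this is revived.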
# df_pivot = df_pivot.groupby(level=[0, 1, 2]).apply(
#     lambda x: x.append(
#         x.sum(numeric_only=True).rename("Total")
#     )
# )
# df_pivot

In [8]:
# Collapse the per-run rows into one row per configuration: group by the
# config columns, aggregate the project's metric with `agg_by`, and round the
# result to 3 decimal places.
group_keys = [
    "train.merge_type",
    "model.twrap_kwargs.use_abs_pos_emb",
    "model",
    proj2diff[project],
]
metric_col = proj2metric[project]
df = df.groupby(group_keys).agg({metric_col: agg_by}).round(3)
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,eval-eval.acc
train.merge_type,model.twrap_kwargs.use_abs_pos_emb,model,data.seq_config.p_bursty,Unnamed: 4_level_1
concat,False,decoder,0.9,0.842
concat,False,decoder,1.0,0.928
concat,False,encoder,0.9,0.853
concat,False,encoder,1.0,0.928
concat,True,decoder,0.9,0.914
concat,True,decoder,1.0,0.933
concat,True,encoder,0.9,0.915
concat,True,encoder,1.0,0.931
sum,False,decoder,0.9,0.68
sum,False,decoder,1.0,0.879


In [9]:
# Mapping from raw wandb column names to display names.
# NOTE: nothing is renamed in this cell -- the rename below is commented out,
# so `df` is displayed unchanged.
old2new = {
    "train.merge_type": "token scheme",
    "model.twrap_kwargs.use_abs_pos_emb": "use pos embed",
    "model": "model",
}
old2new[proj2diff[project]] = proj2nicediff[project]
old2new[proj2metric[project]] = proj2nicemetric[project]
# old2new.pop(proj2diff[project], None)
# df = df.reset_index().rename(columns=old2new).set_index(list(old2new.values()))
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,eval-eval.acc
train.merge_type,model.twrap_kwargs.use_abs_pos_emb,model,data.seq_config.p_bursty,Unnamed: 4_level_1
concat,False,decoder,0.9,0.842
concat,False,decoder,1.0,0.928
concat,False,encoder,0.9,0.853
concat,False,encoder,1.0,0.928
concat,True,decoder,0.9,0.914
concat,True,decoder,1.0,0.933
concat,True,encoder,0.9,0.915
concat,True,encoder,1.0,0.931
sum,False,decoder,0.9,0.68
sum,False,decoder,1.0,0.879


In [10]:
# Raw wandb column name -> display name.
old2new = dict(
    [
        ("train.merge_type", "token scheme"),
        ("model.twrap_kwargs.use_abs_pos_emb", "use pos embed"),
        ("model", "model"),
        (proj2diff[project], proj2nicediff[project]),
        (proj2metric[project], proj2nicemetric[project]),
    ]
)

# The group keys live in the index after the aggregation; move them back into
# ordinary columns so they can be renamed together with the metric column.
df = df.reset_index().rename(columns=old2new)
df

Unnamed: 0,token scheme,use pos embed,model,P(bursty)=,accuracy
0,concat,False,decoder,0.9,0.842
1,concat,False,decoder,1.0,0.928
2,concat,False,encoder,0.9,0.853
3,concat,False,encoder,1.0,0.928
4,concat,True,decoder,0.9,0.914
5,concat,True,decoder,1.0,0.933
6,concat,True,encoder,0.9,0.915
7,concat,True,encoder,1.0,0.931
8,sum,False,decoder,0.9,0.68
9,sum,False,decoder,1.0,0.879


In [11]:
# Group the renamed table by its configuration columns again (now under their
# display names) and re-apply `agg_by` to the display-named metric, which puts
# the configuration columns back into the index.
b = proj2nicemetric[project]
a = ["token scheme", "use pos embed", "model", proj2nicediff[project]]
df = df.groupby(by=a).agg({b: agg_by})
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,accuracy
token scheme,use pos embed,model,P(bursty)=,Unnamed: 4_level_1
concat,False,decoder,0.9,0.842
concat,False,decoder,1.0,0.928
concat,False,encoder,0.9,0.853
concat,False,encoder,1.0,0.928
concat,True,decoder,0.9,0.914
concat,True,decoder,1.0,0.933
concat,True,encoder,0.9,0.915
concat,True,encoder,1.0,0.931
sum,False,decoder,0.9,0.68
sum,False,decoder,1.0,0.879


In [12]:
# Pivot the innermost index level (the proj2nicediff[project] column, here
# "P(bursty)=") into columns, so each of its values gets its own metric column.
df = df.unstack(level=-1)
df


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,accuracy
Unnamed: 0_level_1,Unnamed: 1_level_1,P(bursty)=,0.9,1.0
token scheme,use pos embed,model,Unnamed: 3_level_2,Unnamed: 4_level_2
concat,False,decoder,0.842,0.928
concat,False,encoder,0.853,0.928
concat,True,decoder,0.914,0.933
concat,True,encoder,0.915,0.931
sum,False,decoder,0.68,0.879
sum,False,encoder,0.666,0.855
sum,True,decoder,0.678,0.927
sum,True,encoder,0.67,0.886
