In [1]:
import wandb
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from dotenv import load_dotenv

# load environment variables from the .env file one directory above the notebook
# (presumably the W&B credentials used by wandb.Api — TODO confirm)
load_dotenv("../.env")

True

In [2]:
########################################################################################
# query data from wandb

# human-readable description of each training-data variant
# NOTE(review): these keys use hyphens ("tiny-few", "tiny-high") while load_runs
# labels rows "tiny_few" / "tiny_many_high" — confirm which key is used for lookups
data_description_map = {
    "tiny-few": "tiny-few (50k files shared between 100 speakers)",
    "tiny-high": "tiny-high (8 files from 8 sessions for 5994 speakers)",
}

# client for the W&B public API; load_runs reads it as a module-level global
api = wandb.Api()


def load_runs(name: str):
    """Collect the ``test_eer_hard`` value of every finished run returned by W&B.

    Every finished run must carry one of the dataset tags ``tiny_few`` or
    ``tiny_many_high`` (checked in that order). After that tag is removed, the
    first remaining tag is used as the ablation name.

    Args:
        name: path handed to ``api.runs`` (the module-level W&B API client).

    Returns:
        A DataFrame with columns ``eer``, ``data`` and ``ablation`` and one row
        per finished run, indexed 0..n-1. It is empty (same columns) when no
        run is finished.

    Raises:
        ValueError: if a finished run has neither dataset tag, or has no tag
            left over to serve as the ablation name.
    """
    # order matters: it reproduces the original if/elif priority
    dataset_tags = ("tiny_few", "tiny_many_high")

    # gather plain records and build the frame once at the end, instead of
    # re-allocating a DataFrame with pd.concat on every iteration
    records = []

    for r in api.runs(name):
        if r.state != "finished":
            continue

        eer = r.summary["test_eer_hard"]

        # copy, so removing the dataset tag does not alter the run's own list
        tags = list(r.tags)

        data = next((tag for tag in dataset_tags if tag in tags), None)
        if data is None:
            raise ValueError(f"undetermined dataset from {tags=}")
        tags.remove(data)

        if not tags:
            raise ValueError(
                f"undetermined ablation: no tag besides {data!r} on run"
            )
        ablation = tags[0]

        records.append({"eer": eer, "data": data, "ablation": ablation})

    # explicit columns keep the layout stable even when there are no records
    return pd.DataFrame(records, columns=["eer", "data", "ablation"])


# query the runs stored under "wav2vec2-ablation" (one row per finished run)
df = load_runs("wav2vec2-ablation")

In [3]:
# inspect the collected runs: eer, dataset label and ablation name per row
df

Unnamed: 0,eer,data,ablation
0,0.077306,tiny_many_high,reg_mask
1,0.08045,tiny_many_high,reg_dropout
2,0.06733,tiny_many_high,reg_layerdrop
3,0.462065,tiny_many_high,weights_random_init
4,0.077615,tiny_many_high,reg_none
...,...,...,...
68,0.067629,tiny_many_high,baseline
69,0.168396,tiny_few,weights_xlsr
70,0.066954,tiny_many_high,weights_xlsr
71,0.14854,tiny_few,baseline


In [4]:
# one group per (dataset, ablation) combination
df_grouped = df.groupby(["data", "ablation"])

In [5]:
# summary statistics of the EER within each (dataset, ablation) group;
# `count` is the number of runs with a non-missing EER in the group
df_agg = df_grouped.agg(
    eer_min=pd.NamedAgg(column="eer", aggfunc="min"),
    eer_max=pd.NamedAgg(column="eer", aggfunc="max"),
    eer_mean=pd.NamedAgg(column="eer", aggfunc="mean"),
    eer_std=pd.NamedAgg(column="eer", aggfunc="std"),
    count=pd.NamedAgg(column="eer", aggfunc="count"),
)

In [6]:
# inspect the per-(dataset, ablation) EER statistics
df_agg

Unnamed: 0_level_0,Unnamed: 1_level_0,eer_min,eer_max,eer_mean,eer_std,count
data,ablation,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
tiny_few,baseline,0.14854,0.151224,0.150236,0.001475,3
tiny_few,lr_1_cycle,0.155269,0.159669,0.157871,0.002308,3
tiny_few,lr_constant,0.1658,0.170624,0.167673,0.002586,3
tiny_few,lr_exp_decay,0.165158,0.169016,0.166832,0.001979,3
tiny_few,reg_dropout,0.165495,0.167647,0.166676,0.001091,3
tiny_few,reg_layerdrop,0.148761,0.151634,0.150259,0.00144,3
tiny_few,reg_mask,0.160021,0.163214,0.162132,0.001828,3
tiny_few,reg_none,0.163649,0.168333,0.166722,0.002662,3
tiny_few,weights_freeze_cycle,0.14398,0.145876,0.145224,0.001078,3
tiny_few,weights_freeze_cycle_cnn,0.152798,0.157593,0.154793,0.002497,3


In [7]:
# sanity check: total number of aggregated runs; this should equal len(df)
# as long as no row has a missing eer, data or ablation value
df_agg["count"].sum()

73