In [2]:
import wandb
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from dotenv import load_dotenv

# Load variables from the .env file one directory above this notebook.
load_dotenv(dotenv_path="../.env")

True

In [3]:
########################################################################################
# query data from wandb

# Client for the Weights & Biases public API; `load_runs` uses it to list runs.
api = wandb.Api()

# Short dataset key -> human-readable description stored in the `data` column.
data_description_map = dict(
    [
        ("vox2", "voxceleb2 (1M files shared between 5994 speakers)"),
        ("tiny-few", "tiny-few (50k files shared between 100 speakers)"),
        ("tiny-low", "tiny-low (8 files from 1 session for 5994 speakers)"),
        ("tiny-high", "tiny-high (8 files from 8 sessions for 5994 speakers)"),
    ]
)


def load_runs(name: str):
    """Collect the test EER of every finished step-experiment run of a W&B project.

    Parameters
    ----------
    name : str
        Project path handed to ``api.runs``; it is also written to the
        ``network`` column of every returned row.

    Returns
    -------
    pandas.DataFrame
        One row per kept run with the columns ``num_steps`` (the run config's
        ``trainer.max_steps``), ``eer`` (the run summary's ``test_eer_hard``),
        ``data`` (a description from ``data_description_map``) and ``network``.
        When no run qualifies the frame is empty but still has these columns.

    Raises
    ------
    ValueError
        If a kept run carries none of the known dataset tags.
    """
    columns = ["num_steps", "eer", "data", "network"]

    # Ordered (run tag, data_description_map key) pairs; the first tag found
    # on a run decides its dataset.
    dataset_tags = (
        ("vox2_full", "vox2"),
        ("tiny_few", "tiny-few"),
        ("tiny_many_high", "tiny-high"),
        ("tiny_many_low", "tiny-low"),
    )

    # Rows are gathered in a plain list and turned into a frame once at the
    # end: concatenating inside the loop copies the whole frame per run and
    # leaves the column dtypes dependent on how pandas treats the empty seed.
    records = []

    for r in api.runs(name):
        tags = r.tags

        # Only finished runs that belong to the step experiment are of interest.
        if r.state != "finished" or "step_exp" not in tags:
            continue

        num_steps = r.config["trainer"]["max_steps"]
        eer = r.summary["test_eer_hard"]

        # Runs reporting an EER of 1 or more are left out of the table.
        if eer >= 1:
            continue

        key = next((k for tag, k in dataset_tags if tag in tags), None)
        if key is None:
            # Runs tagged only with the generic "tiny_many" are skipped on purpose.
            if "tiny_many" in tags:
                continue
            raise ValueError(f"undetermined dataset from {tags=}")

        records.append(
            {
                "num_steps": num_steps,
                "eer": eer,
                "data": data_description_map[key],
                "network": name,
            }
        )

    return pd.DataFrame(records, columns=columns)


# Pull the runs of each network's project into its own frame ...
df_xvector = load_runs(name="xvector")
df_ecapa = load_runs(name="ecapa")
df_w2v2 = load_runs(name="wav2vec2")

# ... and stack them into a single table with a fresh 0..n-1 index.
df = pd.concat(objs=(df_xvector, df_ecapa, df_w2v2), ignore_index=True)

In [4]:
# Show the combined table: one row per kept run, across all three networks.
df

Unnamed: 0,num_steps,eer,data,network
0,25000,0.284384,tiny-low (8 files from 1 session for 5994 spea...,xvector
1,400000,0.241335,tiny-low (8 files from 1 session for 5994 spea...,xvector
2,400000,0.240692,tiny-low (8 files from 1 session for 5994 spea...,xvector
3,400000,0.237838,tiny-low (8 files from 1 session for 5994 spea...,xvector
4,100000,0.266305,tiny-low (8 files from 1 session for 5994 spea...,xvector
...,...,...,...,...
136,25000,0.224425,tiny-few (50k files shared between 100 speakers),wav2vec2
137,25000,0.077285,voxceleb2 (1M files shared between 5994 speakers),wav2vec2
138,25000,0.238626,tiny-low (8 files from 1 session for 5994 spea...,wav2vec2
139,25000,0.109838,tiny-high (8 files from 8 sessions for 5994 sp...,wav2vec2


In [5]:
# Distinct values of the `network` column, in order of first appearance.
df["network"].unique()


array(['xvector', 'ecapa', 'wav2vec2'], dtype=object)

In [6]:
# One group per (network, dataset description, max training steps) combination.
df_grouped = df.groupby(["network", "data", "num_steps"])

In [7]:
# Per-group EER statistics via named aggregation (output column -> source column + reducer).
# `num_steps` is also a group key, so it shows up both as an index level and as a column.
df_agg = df_grouped.agg(
    num_steps=pd.NamedAgg(column="num_steps", aggfunc="first"),
    eer_min=pd.NamedAgg(column="eer", aggfunc="min"),
    eer_max=pd.NamedAgg(column="eer", aggfunc="max"),
    eer_mean=pd.NamedAgg(column="eer", aggfunc="mean"),
    eer_std=pd.NamedAgg(column="eer", aggfunc="std"),
    count=pd.NamedAgg(column="eer", aggfunc="count"),
)

In [8]:
# Show the aggregated EER statistics, indexed by (network, data, num_steps).
df_agg

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,num_steps,eer_min,eer_max,eer_mean,eer_std,count
network,data,num_steps,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ecapa,tiny-few (50k files shared between 100 speakers),25000,25000,0.182483,0.18297,0.182727,0.000344,2
ecapa,tiny-few (50k files shared between 100 speakers),50000,50000,0.172436,0.180316,0.175923,0.004017,3
ecapa,tiny-few (50k files shared between 100 speakers),100000,100000,0.174317,0.175294,0.174806,0.000691,2
ecapa,tiny-few (50k files shared between 100 speakers),400000,400000,0.168415,0.170526,0.169471,0.001493,2
ecapa,tiny-high (8 files from 8 sessions for 5994 speakers),25000,25000,0.112034,0.11946,0.115067,0.003895,3
ecapa,tiny-high (8 files from 8 sessions for 5994 speakers),50000,50000,0.097068,0.0982,0.097801,0.000636,3
ecapa,tiny-high (8 files from 8 sessions for 5994 speakers),100000,100000,0.090172,0.092418,0.091222,0.00113,3
ecapa,tiny-high (8 files from 8 sessions for 5994 speakers),400000,400000,0.093005,0.094475,0.093552,0.000804,3
ecapa,tiny-low (8 files from 1 session for 5994 speakers),25000,25000,0.224367,0.229007,0.226721,0.002321,3
ecapa,tiny-low (8 files from 1 session for 5994 speakers),50000,50000,0.207617,0.216042,0.212518,0.004378,3


In [9]:
# Total number of EER values that went into the aggregation, summed over all groups.
df_agg["count"].sum()

141