In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
# Metrics table analysed by the rest of the notebook. Earlier inputs that were
# used in its place: "metrics_per_sample.parquet" and
# "/lustre/scwpod02/client/kyutai/vaclav/data/tts_longform_debug/2/metrics.parquet".
metrics_path = (
    "/lustre/scwpod02/client/kyutai/vaclav/data/tts_longform_debug/2/metrics_0930_15000.parquet"
)
df = pd.read_parquet(metrics_path)

# Reduce `tags` to a two-valued gender label: "f" when the tags contain
# "gender=f", "m" for everything else.
df["gender"] = df["tags"].map(lambda tags: "f" if "gender=f" in tags else "m")

# Columns that are not analysed below. The "nn" and "n_q0.25" ... "n_q1.0"
# columns are deliberately not part of this list and therefore stay in `df`.
unused_columns = ["tags", "file", "language"]
df = df.drop(columns=unused_columns)

# Reset index but also change its order to match the file paths: the old index
# becomes a regular column, and the set_index/reset_index round-trip moves
# dataset / method / sample_id to the front.
df = df.reset_index()
df = df.set_index(["dataset", "method", "sample_id"]).reset_index()

In [None]:
# Display the loaded metrics frame (key columns first after the reordering above).
df

In [None]:
quantiles = [0.25, 0.5, 0.75, 1.0]

# Column names of the per-quantile metrics: w_q* -> wer, s_q* -> sim, d_q* -> drift.
wer_cols, sim_cols, drift_cols = (
    [f"{prefix}_q{q}" for q in quantiles] for prefix in ("w", "s", "d")
)

# Wide -> long: one full copy of `df` per quantile, where the quantile-specific
# metric columns are exposed under the generic names "wer", "sim" and "drift".
dfs = [
    df.assign(quantile=q, wer=df[w_col], sim=df[s_col], drift=df[d_col])
    for q, w_col, s_col, d_col in zip(quantiles, wer_cols, sim_cols, drift_cols)
]

# Stack the copies and discard the now-redundant wide columns.
df_quantile = pd.concat(dfs).drop(columns=wer_cols + sim_cols + drift_cols)
df_quantile

In [None]:
# JSON dump of the first five long-format rows.
df_quantile.iloc[:5].to_json()

In [None]:
metric = "drift"  # NOTE(review): not referenced by any visible cell

# Mean and standard error of every metric for each (quantile, method) pair.
# The standard error uses pandas' built-in "sem" (sample std with ddof=1 divided
# by the square root of the number of non-missing values). A hand-rolled
# std / sqrt(len(x)) would count NaN rows in the denominator while the std and
# the "mean" aggregation skip them, underestimating the error when data is missing.
df_bar = (
    df_quantile.groupby(["quantile", "method"])
    .agg(
        wer_mean=("wer", "mean"),
        wer_ste=("wer", "sem"),
        sim_mean=("sim", "mean"),
        sim_ste=("sim", "sem"),
        drift_mean=("drift", "mean"),
        drift_ste=("drift", "sem"),
    )
    .reset_index()
)

# Mean WER at the first and last quantile, one bar per method, with
# standard-error bars.
fig = px.bar(
    df_bar.query("quantile.isin([0.25, 1.0])"),
    x="quantile",
    y="wer_mean",
    color="method",
    barmode="group",
    error_y="wer_ste",
)
fig.show()

In [None]:
# Mean drift at the first and last quantile, one bar per method, using the
# precomputed drift_ste column for the error bars.
first_and_last = df_bar[df_bar["quantile"].isin([0.25, 1.0])]
fig = px.bar(
    first_and_last,
    x="quantile",
    y="drift_mean",
    error_y="drift_ste",
    color="method",
    barmode="group",
)
fig.show()

In [None]:
# Mean sim at the first and last quantile, one bar per method, using the
# precomputed sim_ste column for the error bars.
first_and_last = df_bar[df_bar["quantile"].isin([0.25, 1.0])]
fig = px.bar(
    first_and_last,
    x="quantile",
    y="sim_mean",
    error_y="sim_ste",
    color="method",
    barmode="group",
)
fig.show()

In [None]:
# Drift per quantile for the 'wikibooks_15000_en' dataset only, coloured by
# method (un-aggregated long-format rows are passed straight to the plot).
wikibooks_rows = df_quantile.loc[df_quantile["dataset"] == "wikibooks_15000_en"]
px.bar(wikibooks_rows, x="quantile", y="drift", color="method")

In [None]:
# Average wer / sim / drift for every (method, quantile, dataset) combination,
# returned as a flat frame with the grouping keys as ordinary columns.
dfc = (
    df_quantile.reset_index()
    .drop(columns=["gender"])
    .groupby(["method", "quantile", "dataset"])[["wer", "sim", "drift"]]
    .mean()
    .reset_index()
)
dfc

In [None]:
# Mean WER as a function of the quantile: one line per method, one panel per dataset.
px.line(
    data_frame=dfc,
    title="WER per quantile",
    x="quantile",
    y="wer",
    facet_col="dataset",
    color="method",
)

In [None]:
# Mean WER at quantile 1 only: datasets along the x axis, methods side by side.
# (Faceting by dataset was tried here and is intentionally left off.)
last_quantile = dfc.loc[dfc["quantile"] == 1]
px.bar(
    last_quantile,
    title="WER of last quantile",
    x="dataset",
    y="wer",
    color="method",
    barmode="group",
)

In [None]:
# Mean sim as a function of the quantile: one line per method, one panel per dataset.
px.line(
    data_frame=dfc,
    title="Speaker sim per quantile",
    x="quantile",
    y="sim",
    facet_col="dataset",
    color="method",
)

In [None]:
# Mean drift as a function of the quantile: one line per method, one panel per dataset.
px.line(
    data_frame=dfc,
    title="Drift per quantile",
    x="quantile",
    y="drift",
    facet_col="dataset",
    color="method",
)

In [None]:
# Mean sim against mean drift at quantile 1.0: one point per method, one panel
# per dataset. No title is set for this figure.
last_quantile = dfc.query("quantile == 1.0")
px.scatter(
    last_quantile,
    x="drift",
    y="sim",
    facet_col="dataset",
    color="method",
)

## Sample-level


In [None]:
# All rows of the wide frame for one sample id.
# NOTE(review): this rebinds `dfc`, which held the aggregated frame used by the
# plots above; re-running those cells after this one will see the sample rows.
dfc = df[df["sample_id"].eq("en_speaker_5_text_21")]
dfc

In [None]:
# Display the wide metrics frame again for reference.
df

In [None]:
# Keep only the two methods compared below. A further restriction that has been
# used here is dfd.query("dataset == 'wikibooks_fr'"); it is currently disabled.
dfd = df.reset_index().loc[
    lambda frame: frame["method"].isin(["16s_context", "opensourced"])
]

In [None]:
# Distribution of w_q1.0 for the selected methods, per dataset; rows with a
# value of 1 or more are left out of the histogram.
below_one = dfd[dfd["w_q1.0"].lt(1)]
px.histogram(
    below_one,
    x="w_q1.0",
    facet_col="dataset",
    color="method",
    barmode="group",
)

In [None]:
# Distribution of s_q1.0 for the selected methods, per dataset.
px.histogram(
    data_frame=dfd,
    x="s_q1.0",
    facet_col="dataset",
    color="method",
    barmode="group",
)

In [None]:
# Display the wide metrics frame once more before sorting it below.
df

In [None]:
# The 30 rows with the smallest s_q1.0 across all methods; append
# .query("method == '16s_context'") to look at a single method instead.
df.sort_values(by="s_q1.0").iloc[:30]