In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px

from misc import model_config

In [None]:
# Keep only the rows selected by the "main" column, drop that column, and
# prefix every remaining column name with "model_".
main_model_config = (
    model_config.query("main")
    .drop(columns="main")
    .rename(columns=lambda column: f"model_{column}")
)

# Display names keyed by the short model identifiers (e.g. "grin") that
# appear in the "model" column of the result frames.
new_name = dict(
    powermoe="PowerMoE",
    llamamoe="LLaMA-MoE-v1",
    olmoe="OLMoE",
    switch="SwitchTransformers",
    llamamoe2="LLaMA-MoE-v2",
    jetmoe="JetMoE",
    openmoe="OpenMoE",
    minicpm="MiniCPM-MoE",
    qwen="Qwen1.5-MoE",
    deepseek2="DeepSeek-V2-Lite",
    deepseek="DeepSeekMoE",
    xverse="XVERSE-MoE",
    qwen3="Qwen3",
    yuan="Yuan2.0",
    phi="Phi-3.5-MoE",
    grin="GRIN-MoE",
    mixtral="Mixtral-8x7B",
    jamba="Jamba-Mini",
    nllb="NLLB-MoE",
    qwen2="Qwen2",
)

# One colour per entry of main_model_config's index: the i-th entry is paired
# with the i-th colour of the Dark24 palette.
model_colors = dict(
    (model_key, px.colors.qualitative.Dark24[position])
    for position, model_key in enumerate(main_model_config.index.values)
)

# Method labels and their colours, taken in order from the start of the
# default Plotly palette.
methods = ("LRU", "LFU", "Beladi")
method_colors = dict(zip(methods, px.colors.qualitative.Plotly[: len(methods)]))
main_model_config

In [None]:
def make_abbr(df):
    """Build the abbreviation for one record.

    Args:
        df: Any object indexable by the keys "model_type", "model_abbr" and
            "is_decoder" (the last is only read for seq2seq records).

    Returns:
        For records whose "model_type" equals "seq2seq": the "model_abbr"
        value followed by "d" when "is_decoder" is truthy and "e" otherwise.
        For every other model type: the "model_abbr" value unchanged.
    """
    if df["model_type"] == "seq2seq":
        base = df["model_abbr"]
        suffix = "d" if df["is_decoder"] else "e"
        return f"{base}{suffix}"
    return df["model_abbr"]

In [None]:
root_dir = Path("../output/chr_mpq")

# One frame per parquet file directly under root_dir, keyed by the file stem.
# Each frame is joined to main_model_config through its "model" column.
dfs = {
    parquet_path.stem: pd.read_parquet(parquet_path).merge(
        main_model_config, left_on="model", right_index=True
    )
    for parquet_path in root_dir.glob("*.parquet")
}

# Cast every frame's "model" column to the dtype of model_config's index.
for df in dfs.values():
    df["model"] = df["model"].astype(model_config.index.dtype)

# Display the frame loaded from "m.parquet".
dfs["m"]

In [None]:
# Mean recall at cache_m == 2 for each (model, is_decoder) pair, one column
# per method, rows ordered by the "LRU" column from highest to lowest.
(
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "method"], observed=True, as_index=False)[["recall"]]
    .mean()
    .pivot(columns="method", index=["model", "is_decoder"], values="recall")
    .sort_values(by="LRU", ascending=False)
)

In [None]:
# Mean "ci_lb" / "ci_ub" (presumably confidence-interval bounds - confirm)
# at cache_m == 2 for each (model, is_decoder) pair. Columns are regrouped as
# (method, bound) and rows ordered by LRU's "ci_lb" from highest to lowest.
(
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "method"], observed=True, as_index=False)[
        ["ci_lb", "ci_ub"]
    ]
    .mean()
    .pivot(columns="method", index=["model", "is_decoder"])
    # pivot yields (bound, method) column pairs; put the method level first
    .swaplevel(0, 1, axis="columns")
    .sort_index(axis="columns")
    .sort_values(by=("LRU", "ci_lb"), ascending=False)
)

In [None]:
# Mean recall for each (model, is_decoder, method, cache_m) combination,
# joined back to the per-model columns of main_model_config.
mdf = (
    dfs["m"]
    .groupby(["model", "is_decoder", "method", "cache_m"], observed=True, as_index=False)[
        ["recall"]
    ]
    .mean()
    .merge(main_model_config, left_on="model", right_index=True)
)

mdf

In [None]:
sch_dir = Path("../output/sch_mpq")

# Load sch_dir / "m.parquet", join it to main_model_config through "model",
# then cast "model" to the dtype of model_config's index.
rdf = (
    pd.read_parquet(sch_dir / "m.parquet")
    .merge(main_model_config, left_on="model", right_index=True)
    .astype({"model": model_config.index.dtype})
)
rdf

In [None]:
# Put the two recall measurements side by side: "chr" is the recall column of
# dfs["m"] (loaded from ../output/chr_mpq) and "sch" is the recall column of
# rdf (loaded from ../output/sch_mpq). The join keys are the columns the two
# selections share.
bdf = (
    dfs["m"][["model", "is_decoder", "dataset", "cache_m", "method", "recall"]]
    .rename(columns={"recall": "chr"})
    .merge(
        rdf[["model", "is_decoder", "dataset", "cache_m", "seg_len", "recall"]].rename(
            columns={"recall": "sch"}
        ),
        on=["model", "is_decoder", "dataset", "cache_m"],
    )
)

bdf

In [None]:
# Correlation between "sch" and "chr" within each (seg_len, method) group,
# shown with seg_len as rows and method as columns.
(
    bdf.groupby(["seg_len", "method"], observed=True)[["sch", "chr"]]
    .corr()
    # move the inner variable level to the columns, then keep the sch-vs-chr entry
    .unstack(-1)[("sch", "chr")]
    .unstack(1)
)

In [None]:
# Mean "sch" and "chr" per (cache_m, seg_len) for the rows selected by the
# query below (model 'grin', is_decoder truthy, method 'Beladi').
badf = (
    bdf.query("model == 'grin' and is_decoder and method == 'Beladi'")
    .groupby(["cache_m", "seg_len"], as_index=False)[["sch", "chr"]]
    .mean()
)

# sch / chr ratio limited to [0, 1], shown as seg_len rows by cache_m columns
# for the listed cache_m values.
(
    badf.assign(ratio=lambda frame: (frame["sch"] / frame["chr"]).clip(lower=0, upper=1))
    .pivot(index="seg_len", columns="cache_m", values="ratio")[
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    ]
)

In [None]:
# Mean recall for each (model, is_decoder, method, cache_m) combination.
bbdf = (
    dfs["m"]
    .groupby(["model", "is_decoder", "method", "cache_m"], as_index=False, observed=True)[
        ["recall"]
    ]
    .mean()
)

# Attach the 'Beladi' recall as a "baseline" column to every non-'Beladi'
# row that shares the same model, is_decoder and cache_m.
bmdf = bbdf.query("method != 'Beladi'").merge(
    bbdf.query("method == 'Beladi'")
    .rename(columns={"recall": "baseline"})
    .drop(columns="method"),
    on=["model", "is_decoder", "cache_m"],
)

# recall / baseline for model 'grin' with is_decoder truthy, shown as method
# rows by cache_m columns for the listed cache_m values.
(
    bmdf.query("model == 'grin' and is_decoder")
    .assign(ratio=lambda frame: frame["recall"] / frame["baseline"])
    .pivot(index="method", columns="cache_m", values="ratio")[
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    ]
)