In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from misc import model_config
from plotly.subplots import make_subplots

In [None]:
# Prefix every config column with "model_" so the metadata columns are easy to
# tell apart once this frame is merged into the result tables on the model key.
model_column_names = {column: f"model_{column}" for column in model_config.columns}

# Keep only the rows flagged in the boolean "main" column; the flag itself is
# dropped because it is constant after filtering.
main_model_config = (
    model_config.query("main").drop(columns="main").rename(columns=model_column_names)
)

# Display names used in legends and subplot titles, keyed by the short model key.
new_name = dict(
    powermoe="PowerMoE",
    llamamoe="LLaMA-MoE-v1",
    olmoe="OLMoE",
    switch="SwitchTransformers",
    llamamoe2="LLaMA-MoE-v2",
    jetmoe="JetMoE",
    openmoe="OpenMoE",
    minicpm="MiniCPM-MoE",
    qwen="Qwen1.5-MoE",
    deepseek2="DeepSeek-V2-Lite",
    deepseek="DeepSeekMoE",
    xverse="XVERSE-MoE",
    qwen3="Qwen3",
    yuan="Yuan2.0",
    phi="Phi-3.5-MoE",
    grin="GRIN-MoE",
    mixtral="Mixtral-8x7B",
    jamba="Jamba-Mini",
    nllb="NLLB-MoE",
    qwen2="Qwen2",
)

# One Dark24 colour per main model, assigned in the order of the config index.
model_colors = {
    model_key: px.colors.qualitative.Dark24[position]
    for position, model_key in enumerate(main_model_config.index.values)
}

# Segment lengths iterated by the plotting cells, each paired with the Plotly
# qualitative colour at the same position.
seg_lens = (4, 16, 64, 256)
seg_len_colors = dict(zip(seg_lens, px.colors.qualitative.Plotly))
main_model_config

In [None]:
def make_abbr(df):
    """Return the abbreviation for one record of model metadata.

    ``df`` is indexed by key (a row or a mapping) and must provide
    ``model_type`` and ``model_abbr``; ``is_decoder`` is read only when
    ``model_type`` is ``"seq2seq"``.

    Seq2seq records get a ``"d"`` suffix when ``is_decoder`` is truthy and an
    ``"e"`` suffix otherwise. Every other model type returns ``model_abbr``
    unchanged.
    """
    if df["model_type"] != "seq2seq":
        return df["model_abbr"]

    abbr = df["model_abbr"]
    role_suffix = "d" if df["is_decoder"] else "e"
    return f"{abbr}{role_suffix}"

In [None]:
root_dir = Path("../output/sch_mpq")

# Load every parquet file in root_dir, keyed by file stem. Each frame is joined
# to the main-model metadata on its "model" column, and that column is then cast
# to the dtype of the model_config index.
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    merged = pd.merge(
        pd.read_parquet(parquet_path), main_model_config, left_on="model", right_index=True
    )
    merged["model"] = merged["model"].astype(model_config.index.dtype)
    dfs[parquet_path.stem] = merged

dfs["m"]

In [None]:
# Mean recall per (model, is_decoder, seg_len), restricted to cache_m == 2.
recall_at_m2 = (
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "seg_len"], as_index=False, observed=True)[["recall"]]
    .mean()
)

# One column per seg_len, rows ranked by the seg_len == 16 column (best first).
recall_at_m2.pivot(index=["model", "is_decoder"], columns="seg_len", values="recall").sort_values(
    16, ascending=False
)

In [None]:
# Mean ci_lb / ci_ub (presumably confidence-interval bounds) per
# (model, is_decoder, seg_len), restricted to cache_m == 2.
ci_bounds_at_m2 = (
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "seg_len"], as_index=False, observed=True)[["ci_lb", "ci_ub"]]
    .mean()
)

# Put seg_len on the outer column level so both bounds of a segment length sit
# side by side.
ci_by_seg_len = (
    ci_bounds_at_m2.pivot(index=["model", "is_decoder"], columns="seg_len")
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=1)
)

# Rank by the seg_len == 16 lower bound and show the values scaled by 100.
ci_by_seg_len.sort_values((16, "ci_lb"), ascending=False).multiply(100).round(2)

In [None]:
# ci_dist is the larger of the two one-sided gaps between recall and its
# ci_ub / ci_lb bounds, computed row by row.
with_ci_dist = dfs["m"].assign(
    ci_dist=lambda frame: np.maximum(
        frame["ci_ub"] - frame["recall"], frame["recall"] - frame["ci_lb"]
    )
)

# Worst-case ci_dist per (model, is_decoder, dataset, seg_len), laid out with
# one column per (dataset, seg_len) pair.
ci_dist_groups = with_ci_dist.groupby(
    ["model", "is_decoder", "dataset", "seg_len"], as_index=False, observed=True
)
ci_dist_groups[["ci_dist"]].max().pivot(
    index=["model", "is_decoder"], columns=["dataset", "seg_len"], values="ci_dist"
)

In [None]:
# Mean recall for every (model, is_decoder, seg_len, cache_m) combination.
recall_groups = dfs["m"].groupby(
    ["model", "is_decoder", "seg_len", "cache_m"], as_index=False, observed=True
)
mean_recall = recall_groups[["recall"]].mean()

# The aggregate keeps only the group keys and recall, so the model metadata is
# joined back on via the model key.
mdf = mean_recall.merge(main_model_config, left_on="model", right_index=True)

mdf

In [None]:
# Plotly dash styles, indexed by dash level.
dash_style = ["solid", "longdash", "dashdot", "dot"]

# Models grouped by dash level: the position of a group in this tuple is the
# level assigned to every model inside it.
models_by_dash_level = (
    ("llamamoe2", "yuan", "powermoe", "qwen3", "phi", "olmoe", "grin"),
    ("mixtral", "minicpm", "jetmoe", "llamamoe"),
    ("xverse", "jamba", "deepseek2", "deepseek", "qwen2"),
    ("nllb", "qwen", "openmoe", "switch"),
)

dash_level = {
    model_key: level
    for level, group in enumerate(models_by_dash_level)
    for model_key in group
}

# Resolve each model's level to the dash style name used when drawing its line.
dash_type = {model_key: dash_style[level] for model_key, level in dash_level.items()}

In [None]:
fig = make_subplots(rows=1, cols=len(seg_lens), shared_xaxes="all", horizontal_spacing=0.01)
font_size = {"tick": 20, "legend": 18, "title": 24}


def _model_traces(model_df, key):
    """Yield ``(frame, legend name, opacity)`` for each line drawn for ``key``.

    Seq2seq models yield two entries, encoder first (opacity 0.5) and then
    decoder (opacity 1), each restricted to the matching ``is_decoder`` rows.
    Every other model yields a single entry with the whole frame and no
    explicit opacity.
    """
    if model_config.loc[key, "type"] == "seq2seq":
        for is_decoder in (False, True):
            yield (
                model_df.query(f"is_decoder == {is_decoder}"),
                f"{new_name[key]} ({'De' if is_decoder else 'En'}coder)",
                1 if is_decoder else 0.5,
            )
    else:
        yield model_df, new_name[key], None


# One panel per segment length; plotly subplot columns are 1-based.
for col, seg_len in enumerate(seg_lens, start=1):
    for key in main_model_config.index.values:
        model_df = mdf.query(f"model == '{key}' and seg_len == {seg_len}")

        for trace_df, trace_name, trace_opacity in _model_traces(model_df, key):
            fig.add_scatter(
                x=trace_df["cache_m"],
                y=trace_df["recall"],
                hoverinfo="skip",
                legendgroup=key,
                line=go.scatter.Line(color=model_colors[key], dash=dash_type[key], width=2),
                mode="lines",
                name=trace_name,
                opacity=trace_opacity,
                # Only the first panel contributes legend entries, so each model
                # is listed once rather than once per panel.
                showlegend=col == 1,
                row=1,
                col=col,
            )

    # Panel label in the top-left corner, positioned in axis-domain coordinates.
    fig.add_annotation(
        xref="x domain",
        yref="y domain",
        x=0.05,
        y=0.95,
        ax=0,
        ay=0,
        text=f"$m={seg_len}$",
        showarrow=False,
        font=go.layout.annotation.Font(size=font_size["title"], shadow="auto"),
        row=1,
        col=col,
    )

    fig.update_xaxes(
        range=[0, 4],
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size["title"]), text=r"$\rho$"
        ),
        row=1,
        col=col,
    )

    # All panels use the same [0, 1] y-range; tick labels appear on the first only.
    fig.update_yaxes(range=[0, 1], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SCH"
            ),
            row=1,
            col=col,
        )

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size["legend"]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=0, b=30),
    width=1800,
    height=550,
)

# Create the output directory first so the export also works on a fresh checkout.
Path("../plot").mkdir(parents=True, exist_ok=True)
fig.write_image("../plot/msch.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
srp_dir = Path("../output/srp_mpq")

# Load every parquet file in srp_dir, keyed by file stem, joined to the
# main-model metadata on the "model" column.
srp_dfs = {}
for parquet_path in srp_dir.glob("*.parquet"):
    srp_dfs[parquet_path.stem] = pd.merge(
        pd.read_parquet(parquet_path), main_model_config, left_on="model", right_index=True
    )

# Model keys ordered by best_f1 (highest first), taken from the decoder rows at
# seg_len == 16.
decoder_rows_m16 = srp_dfs["mg"].query("is_decoder and seg_len == 16")
sorted_model_keys = decoder_rows_m16.sort_values("best_f1", ascending=False)["model"]

# Join the per-layer recall and best_f1 tables on the columns they share.
shared_columns = ["model", "is_decoder", "layer_idx", "dataset", "seg_len"]
layer_scores = pd.merge(
    dfs["l"][shared_columns + ["cache_m", "recall"]],
    srp_dfs["ld"][shared_columns + ["best_f1"]],
)

# Correlation between recall and best_f1 within each (seg_len, cache_m) group:
# pick the off-diagonal entry of every 2x2 correlation matrix, spread cache_m
# across the columns and show the selected cache_m values scaled by 100.
corr_matrices = layer_scores.groupby(["seg_len", "cache_m"])[["recall", "best_f1"]].corr()
recall_f1_corr = corr_matrices.unstack(2)["recall", "best_f1"]
recall_f1_corr.unstack(1)[[0.5, 1.0, 1.5, 2.0, 2.5, 3.0]].multiply(100).round(2)

In [None]:
# Mean recall for every (model, is_decoder, layer_idx, seg_len, cache_m) combination.
layer_recall_groups = dfs["l"].groupby(
    ["model", "is_decoder", "layer_idx", "seg_len", "cache_m"], as_index=False, observed=True
)
mean_layer_recall = layer_recall_groups[["recall"]].mean()

# The aggregate keeps only the group keys and recall, so the model metadata is
# joined back on via the model key.
ldf = mean_layer_recall.merge(main_model_config, left_on="model", right_index=True)

ldf

In [None]:
num_cols = 10
# Ceiling division: enough rows to give every main model its own subplot.
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.11,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

font_size = {"tick": 16, "legend": 18, "title": 20}
show_legend = True

for i, key in enumerate(sorted_model_keys):
    # Models without layer-level rows keep their (empty) subplot slot.
    if (ldf["model"] == key).sum() == 0:
        continue

    # Fill the grid row by row; plotly subplot indices are 1-based.
    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]
    is_seq2seq = model_config.loc[key, "type"] == "seq2seq"

    for seg_len in seg_lens:
        for cache_m in (1, 2):
            # cache_m == 1 is drawn solid with markers; cache_m == 2 is drawn
            # dotted and half-transparent.
            is_primary = cache_m == 1
            tmpdf = ldf.query(f"model == '{key}' and seg_len == {seg_len} and cache_m == {cache_m}")

            if len(tmpdf) == 0:
                continue

            # Convert to 1-based layer numbers.
            layer_idx = tmpdf["layer_idx"] + 1
            if is_seq2seq:
                # Shift decoder layers by half of num_layers so they are drawn
                # after the encoder layers on the same axis.
                layer_idx += tmpdf["is_decoder"] * num_layers // 2

            line_name = f"$\\mathrm{{SCH}}(m={seg_len},\\,\\rho={cache_m})$"

            fig.add_scatter(
                # Normalise by the layer count so every model spans (0, 1].
                x=layer_idx / num_layers,
                y=tmpdf["recall"],
                hoverinfo="skip",
                legendgroup=line_name,
                line=go.scatter.Line(
                    color=seg_len_colors[seg_len], dash="solid" if is_primary else "dot"
                ),
                marker=go.scatter.Marker(size=4),
                mode="lines+markers" if is_primary else "lines",
                name=line_name,
                opacity=1 if is_primary else 0.5,
                showlegend=show_legend,
                row=row,
                col=col,
            )

    # Only the first model that is actually drawn contributes legend entries.
    show_legend = False

    fig.update_xaxes(
        showticklabels=True,
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
        # Place each tick at the normalised position of the layer it names. The
        # middle tick sits at (num_layers // 2) / num_layers, which equals 0.5
        # only when num_layers is even.
        tickvals=[1 / num_layers, (num_layers // 2) / num_layers, 1],
        ticktext=[1, num_layers // 2, num_layers],
        row=row,
        col=col,
    )

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size["title"]), standoff=1, text="Layer"
            ),
            row=row,
            col=col,
        )

    fig.update_yaxes(showticklabels=col == 1, row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SCH"
            ),
            row=row,
            col=col,
        )

# Subplot titles are stored as annotations; give them the title font size.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size["legend"]),
        orientation="h",
        x=0.47,
        xanchor="center",
    ),
    margin=go.layout.Margin(l=60, r=15, t=30, b=90),
    width=2000,
    height=500,
)

# Create the output directory first so the export also works on a fresh checkout.
Path("../plot").mkdir(parents=True, exist_ok=True)
fig.write_image("../plot/lsch.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()