In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [None]:
# Keep only the models flagged by the boolean "main" column and prefix every
# remaining config column with "model_" (giving e.g. "model_abbr" and
# "model_type", which make_abbr reads after the frames are merged).
main_model_config = (
    model_config.query("main")
    .drop(columns="main")
    .rename(columns=lambda column: f"model_{column}")
)

# Model key -> display name used in legends and subplot titles.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# One palette entry per main model, assigned in index order. Indexing raises
# IndexError if there are more models than colors in the palette.
model_colors = {
    model_key: px.colors.qualitative.Dark24[rank]
    for rank, model_key in enumerate(main_model_config.index.values)
}

# Segment lengths and their palette colors, assigned the same way.
seg_lens = (4, 16, 64, 256)
seg_len_colors = {
    seg_len: px.colors.qualitative.Plotly[rank] for rank, seg_len in enumerate(seg_lens)
}
main_model_config

In [None]:
def make_abbr(df):
    """Return the label abbreviation for a single record.

    ``df`` is one row-like mapping (for example the row handed over by
    ``DataFrame.apply(..., axis=1)``) providing ``model_abbr`` and
    ``model_type``, plus ``is_decoder`` for seq2seq models.

    For ``model_type == "seq2seq"`` the abbreviation gets a ``"d"`` suffix when
    ``is_decoder`` is truthy and an ``"e"`` suffix otherwise; every other model
    type returns ``model_abbr`` unchanged.
    """
    abbr = df["model_abbr"]
    if df["model_type"] != "seq2seq":
        return abbr

    suffix = "d" if df["is_decoder"] else "e"
    return f"{abbr}{suffix}"

In [None]:
root_dir = Path("../output/srp_mpq")

# One frame per parquet file in root_dir, keyed by file stem. Each frame is
# joined with the per-model config columns through its "model" column, and that
# column is then cast to the dtype of the model_config index.
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    merged = pd.merge(
        pd.read_parquet(parquet_path), main_model_config, left_on="model", right_index=True
    )
    merged["model"] = merged["model"].astype(model_config.index.dtype)
    dfs[parquet_path.stem] = merged

# Model keys ordered by descending best_f1, taken from the decoder rows of the
# "mg" table at seg_len == 16; this order lays out the per-model subplot grids.
sorted_model_keys = (
    dfs["mg"]
    .query("is_decoder and seg_len == 16")
    .sort_values("best_f1", ascending=False)
    .loc[:, "model"]
)

dfs["mg"].pivot(index=["model", "is_decoder"], columns="seg_len", values=["best_f1", "best_m"])

In [None]:
# Long-format vocabulary table; later cells pivot it on "token_type" using its
# "freq" values.
vocab_dir = Path("../output/vocab_pq")
vdf = pd.read_parquet(vocab_dir.joinpath("gen.parquet"))
vdf

In [None]:
# Segment length at which the tables are sliced for the comparisons below.
sample_seg_len = 16

# Rows of the "eg" table at that segment length; the now-constant seg_len
# column is dropped.
sedf = (
    dfs["eg"]
    .query(f"seg_len == {sample_seg_len}")
    .drop("seg_len", axis="columns")
)
sedf

In [None]:
def _act_r_cv(group):
    """Return the coefficient of variation (std / mean) of ``act_r`` for one group."""
    act_r = group["act_r"]
    return pd.Series({"act_r_cv": act_r.std() / act_r.mean()})


# Per-expert spread of act_r: keep rows with act_r > 1e-5 at the sampled segment
# length, reduce each (model, is_decoder, layer_idx, expert_idx) group to its
# coefficient of variation, and join the result onto sedf on the shared columns.
dedf = sedf.merge(
    dfs["ed"]
    .query(f"seg_len == {sample_seg_len} and act_r > 1e-5")
    .drop(columns="seg_len")
    .groupby(["model", "is_decoder", "layer_idx", "expert_idx"], as_index=False, observed=True)
    .apply(_act_r_cv, include_groups=False)
)

dedf

In [None]:
# Reshape the vocabulary table to one row per (model, layer_idx, expert_idx)
# with one "<token_type>_freq" column for each of in / pred / out, then join it
# onto sedf on the shared columns.
freq_columns = {token_type: f"{token_type}_freq" for token_type in ("in", "pred", "out")}

vedf = sedf.merge(
    vdf.pivot(index=["model", "layer_idx", "expert_idx"], columns="token_type", values="freq")
    .rename(columns=freq_columns)
    .reset_index()
)

vedf

In [None]:
# Rows of the "mg" table at the sampled segment length, shown with the highest
# best_f1 first.
smdf = (
    dfs["mg"]
    .query(f"seg_len == {sample_seg_len}")
    .drop("seg_len", axis="columns")
)
smdf.sort_values(by="best_f1", ascending=False)

In [None]:
def _domain_stats(group):
    """Summarize one (model, is_decoder) group of ``dedf``.

    Returns ``ds``, the mean of ``act_r_cv``, and ``corr``, the correlation
    between ``best_f1`` and ``act_r_cv`` across the group's rows.
    """
    act_r_cv = group["act_r_cv"]
    return pd.Series({"ds": act_r_cv.mean(), "corr": group["best_f1"].corr(act_r_cv)})


# Attach the per-(model, is_decoder) summaries to smdf on the shared columns.
dmcdf = smdf.merge(
    dedf.groupby(["model", "is_decoder"], as_index=False, observed=True).apply(
        _domain_stats, include_groups=False
    )
)

dmcdf.sort_values(by="best_f1", ascending=False)

In [None]:
def _vocab_stats(group):
    """Summarize one (model, is_decoder) group of ``vedf``.

    For each token type (in, pred, out) this returns ``<type>_vs``, the mean of
    ``<type>_freq``, and ``<type>_corr``, the correlation between ``best_f1``
    and ``<type>_freq`` across the group's rows.
    """
    stats = {}
    for token_type in ("in", "pred", "out"):
        freq = group[f"{token_type}_freq"]
        stats[f"{token_type}_vs"] = freq.mean()
        stats[f"{token_type}_corr"] = group["best_f1"].corr(freq)
    return pd.Series(stats)


# Attach the per-(model, is_decoder) summaries to smdf on the shared columns.
vmcdf = smdf.merge(
    vedf.groupby(["model", "is_decoder"], as_index=False, observed=True).apply(
        _vocab_stats, include_groups=False
    )
)

vmcdf.sort_values(by="best_f1", ascending=False)

In [None]:
fig = make_subplots(
    rows=2,
    cols=2,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.01,
    vertical_spacing=0.02,
)

# Label offsets per panel and model: text_pos[0] is used for the domain panel,
# text_pos[1..3] for the in / pred / out vocabulary panels. Each model maps to a
# list holding one (ax, ay) annotation offset per plotted point of that model.
# A model that has rows to plot but no entry here raises KeyError.
text_pos = [
    {
        "powermoe": [(20, -20)],
        "llamamoe": [(-20, -20)],
        "olmoe": [(0, 30)],
        "switch": [(35, 0), (35, 0)],
        "llamamoe2": [(-25, -25)],
        "jetmoe": [(-20, 20)],
        "openmoe": [(20, 20)],
        "minicpm": [(-20, -20)],
        "qwen": [(-20, -20)],
        "deepseek2": [(0, 25)],
        "deepseek": [(-30, 0)],
        "xverse": [(30, 0)],
        "qwen3": [(25, 25)],
        "yuan": [(25, -25)],
        "phi": [(25, -25)],
        "grin": [(-25, 25)],
        "mixtral": [(-20, -20)],
        "jamba": [(-20, -20)],
        "nllb": [(25, 25), (-25, -25)],
        "qwen2": [(-20, -20)],
    },
    {
        "powermoe": [(20, -20)],
        "llamamoe": [(-20, -20)],
        "olmoe": [(20, -20)],
        "llamamoe2": [(-20, -20)],
        "jetmoe": [(20, -20)],
        "openmoe": [(-25, -25)],
        "minicpm": [(-20, -20)],
        "qwen": [(-20, -20)],
        "deepseek2": [(20, -20)],
        "deepseek": [(30, 0)],
        "xverse": [(-20, 20)],
        "qwen3": [(20, 20)],
        "yuan": [(20, -20)],
        "phi": [(20, 20)],
        "grin": [(-20, 20)],
        "mixtral": [(20, -20)],
        "jamba": [(-20, -20)],
        "qwen2": [(20, 20)],
    },
    {
        "powermoe": [(20, 20)],
        "llamamoe": [(20, 20)],
        "olmoe": [(20, -20)],
        "llamamoe2": [(-20, -20)],
        "jetmoe": [(-20, -20)],
        "openmoe": [(-20, -20)],
        "minicpm": [(20, 20)],
        "qwen": [(-20, -20)],
        "deepseek2": [(20, 20)],
        "deepseek": [(-20, 20)],
        "xverse": [(20, 20)],
        "qwen3": [(20, -20)],
        "yuan": [(20, -20)],
        "phi": [(20, 20)],
        "grin": [(-20, -20)],
        "mixtral": [(20, 20)],
        "jamba": [(20, -20)],
        "qwen2": [(-20, 20)],
    },
    {
        "powermoe": [(20, 20)],
        "llamamoe": [(20, 20)],
        "olmoe": [(20, -20)],
        "llamamoe2": [(-20, -20)],
        "jetmoe": [(-20, -20)],
        "openmoe": [(-20, -20)],
        "minicpm": [(20, 20)],
        "qwen": [(-20, -20)],
        "deepseek2": [(20, 20)],
        "deepseek": [(-20, 20)],
        "xverse": [(20, 20)],
        "qwen3": [(20, -20)],
        "yuan": [(20, -20)],
        "phi": [(20, 20)],
        "grin": [(-20, -20)],
        "mixtral": [(20, 20)],
        "jamba": [(20, -20)],
        "qwen2": [(-20, 20)],
    },
]

font_size = {"label": 18, "tick": 20, "legend": 20, "title": 24}


def _add_model_points(fig, frame, y_col, sizes, key, order, offsets, row, col, showlegend):
    """Draw one model's markers and arrowed labels in panel (row, col) of ``fig``.

    Parameters
    ----------
    fig : figure created by ``make_subplots``; modified in place.
    frame : rows of a single model; ``best_f1`` is plotted on x, ``y_col`` on y.
    y_col : name of the column plotted on the y axis.
    sizes : Series of marker sizes aligned with ``frame``.
    key : model key, used for the color, legend group and display name.
    order : position of the model in the plotting order; earlier models get a
        higher ``zorder`` and are drawn on top.
    offsets : one ``(ax, ay)`` label offset per row of ``frame``.
    row, col : 1-based subplot position.
    showlegend : whether the trace gets a legend entry.
    """
    # NOTE: local names deliberately avoid ``px`` so that the module-level
    # ``plotly.express as px`` import is not shadowed by this cell.
    xs = frame["best_f1"]
    ys = frame[y_col]
    color = model_colors[key]
    labels = frame.apply(make_abbr, axis=1)

    fig.add_scatter(
        x=xs,
        y=ys,
        hoverinfo="skip",
        marker=go.scatter.Marker(
            color=color,
            line=go.scatter.marker.Line(color="white", width=1),
            opacity=0.7,
            size=sizes,
        ),
        legendgroup=key,
        mode="markers",
        name=new_name[key],
        showlegend=showlegend,
        zorder=100 - order,
        row=row,
        col=col,
    )

    for k in range(len(frame)):
        fig.add_annotation(
            x=xs.iloc[k],
            y=ys.iloc[k],
            arrowcolor=color,
            ax=offsets[k][0],
            ay=offsets[k][1],
            text=labels.iloc[k],
            # Stop the arrow just inside the marker edge (half the marker size).
            standoff=max(sizes.iloc[k] / 2 - 1, 0),
            showarrow=True,
            font=go.layout.annotation.Font(size=font_size["label"], shadow="auto"),
            row=row,
            col=col,
        )


# Panel (1, 1): correlation vs. best_f1 from dmcdf, marker size from sqrt(ds).
# These traces carry the legend entries.
for j, key in enumerate(main_model_config.index.values):
    tmpdf = dmcdf.query(f"model == '{key}'")
    if len(tmpdf) == 0:
        continue

    _add_model_points(
        fig,
        tmpdf,
        "corr",
        tmpdf["ds"] ** 0.5 * 30,
        key,
        j,
        text_pos[0][key],
        row=1,
        col=1,
        showlegend=True,
    )

# Remaining panels: in -> (1, 2), pred -> (2, 1), out -> (2, 2), from vmcdf,
# marker size proportional to the "<prefix>_vs" column.
for j, key in enumerate(main_model_config.index.values):
    tmpdf = vmcdf.query(f"model == '{key}'")
    if len(tmpdf) == 0:
        continue

    for i, prefix in enumerate(("in", "pred", "out")):
        row = (i + 1) // 2 + 1
        col = (i + 1) % 2 + 1

        _add_model_points(
            fig,
            tmpdf,
            f"{prefix}_corr",
            tmpdf[f"{prefix}_vs"] * 30,
            key,
            j,
            text_pos[i + 1][key],
            row=row,
            col=col,
            showlegend=False,
        )

# Panel titles and axis styling are applied once per panel, independent of the
# model loops, so they no longer disappear when the first model has no rows.
panel_titles = {
    (1, 1): "Domain Spec.",
    (1, 2): "Input Vocab. Spec.",
    (2, 1): "Pred. Vocab. Spec.",
    (2, 2): "G.T. Vocab. Spec.",
}

for (row, col), panel_title in panel_titles.items():
    fig.add_annotation(
        xref="x domain",
        yref="y domain",
        x=0.95,
        y=0.95,
        ax=0,
        ay=0,
        text=panel_title,
        showarrow=False,
        font=go.layout.annotation.Font(size=font_size["title"], shadow="auto"),
        row=row,
        col=col,
    )

    # Tick labels only on the left column; every panel shares the [-1, 1] range.
    fig.update_yaxes(range=[-1, 1], showticklabels=col == 1, row=row, col=col)

    if row == 2:
        fig.update_xaxes(
            tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size["title"]), text="SRP"
            ),
            row=row,
            col=col,
        )

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]),
                text="Correlation",
            ),
            row=row,
            col=col,
        )

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size["legend"]),
        itemsizing="constant",
        orientation="h",
        y=-0.06,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=1300,
    height=1300,
)

fig.write_image("../plot/msrpsc.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Grid of per-model scatter plots: act_r_cv (x) against best_f1 (y), one panel
# per model in sorted_model_keys order, points colored by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    # Models without rows in dedf are skipped by the loop below, so their panel
    # title is left blank as well (same convention as the other per-model grids).
    subplot_titles=[
        "" if (dedf["model"] == key).sum() == 0 else new_name[key] for key in sorted_model_keys
    ],
)

font_size = {"bar_tick": 14, "bar_title": 16, "tick": 16, "title": 18}

for i, key in enumerate(sorted_model_keys):
    if (dedf["model"] == key).sum() == 0:
        continue

    # Row-major placement in the grid (1-based).
    row = i // num_cols + 1
    col = i % num_cols + 1
    tmpdf = dedf.query(f"model == '{key}'")
    num_layers = model_config.loc[key, "num_layers"]
    layer_idx = tmpdf["layer_idx"] + 1

    # For seq2seq models, decoder rows are shifted by num_layers // 2 so that
    # they follow the encoder rows on the shared color scale.
    if model_config.loc[key, "type"] == "seq2seq":
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        x=tmpdf["act_r_cv"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # NOTE(review): layer_idx was already shifted by +1 above, so the
            # color value is the raw layer_idx + 2 while cmin/cmax span
            # 1..num_layers. Possible off-by-one; confirm whether the raw
            # layer_idx column is 0-based.
            color=layer_idx + 1,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Colorbar position computed from the panel's grid cell; the
                # constants appear hand-tuned for this grid size (assumption).
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size["bar_tick"]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size["bar_title"]),
                    text="Ly.",
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # x tick labels and titles only on the bottom row.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size["title"]),
                standoff=5,
                text="Domain Spec.",
            ),
            row=row,
            col=col,
        )

    # y tick labels and titles only on the left column.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SRP"
            ),
            row=row,
            col=col,
        )

# Subplot titles are layout annotations; resize them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=450)
fig.write_image("../plot/esrpds.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Grid of per-model scatter plots: in_freq (x) against best_f1 (y), one panel
# per model in sorted_model_keys order, points colored by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

# Panels of models without rows in vedf keep an empty title.
panel_titles = []
for model_key in sorted_model_keys:
    if (vedf["model"] == model_key).sum() == 0:
        panel_titles.append("")
    else:
        panel_titles.append(new_name[model_key])

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=panel_titles,
)

font_size = {"bar_tick": 14, "bar_title": 16, "tick": 16, "title": 18}

for slot, model_key in enumerate(sorted_model_keys):
    if not (vedf["model"] == model_key).any():
        continue

    # Row-major placement in the grid (1-based).
    grid_row, grid_col = (pos + 1 for pos in divmod(slot, num_cols))
    model_rows = vedf.query(f"model == '{model_key}'")
    num_layers = model_config.loc[model_key, "num_layers"]

    # For seq2seq models, decoder rows are shifted by num_layers // 2 so that
    # they follow the encoder rows on the shared color scale.
    layer_no = model_rows["layer_idx"] + 1
    if model_config.loc[model_key, "type"] == "seq2seq":
        layer_no = layer_no + model_rows["is_decoder"] * num_layers // 2

    layer_bar = go.scatter.marker.ColorBar(
        len=0.49,
        thickness=10,
        # Colorbar position computed from the panel's grid cell; the constants
        # appear hand-tuned for this grid size (assumption).
        x=(grid_col - 0.23) * 0.102,
        y=(num_rows - 1 - grid_row + 1.45) * 0.54,
        tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size["bar_tick"]),
        tickvals=[1, num_layers // 2, num_layers],
        title=go.scatter.marker.colorbar.Title(
            font=go.scatter.marker.colorbar.title.Font(size=font_size["bar_title"]),
            text="Ly.",
        ),
    )

    fig.add_scatter(
        x=model_rows["in_freq"],
        y=model_rows["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # NOTE(review): layer_no already includes a +1 shift, so the color
            # value is the raw layer_idx + 2 while cmin/cmax span 1..num_layers.
            # Possible off-by-one; confirm whether raw layer_idx is 0-based.
            color=layer_no + 1,
            colorbar=layer_bar,
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=grid_row,
        col=grid_col,
    )

    # x tick labels and titles only on the bottom row.
    on_bottom_row = grid_row == num_rows
    fig.update_xaxes(showticklabels=on_bottom_row, row=grid_row, col=grid_col)

    if on_bottom_row:
        x_title = go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size["title"]),
            standoff=5,
            text="In. Voc. Spec.",
        )
        fig.update_xaxes(
            tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
            title=x_title,
            row=grid_row,
            col=grid_col,
        )

    # y tick labels and titles only on the left column.
    if grid_col == 1:
        y_title = go.layout.yaxis.Title(
            font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SRP"
        )
        fig.update_yaxes(
            showticklabels=True,
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=y_title,
            row=grid_row,
            col=grid_col,
        )

# Subplot titles are layout annotations; resize them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=450)
fig.write_image("../plot/esrpivs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Grid of per-model scatter plots: out_freq (x) against best_f1 (y), one panel
# per model in sorted_model_keys order, points colored by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    # Panels of models without rows in vedf keep an empty title.
    subplot_titles=[
        "" if (vedf["model"] == key).sum() == 0 else new_name[key] for key in sorted_model_keys
    ],
)

font_size = {"bar_tick": 14, "bar_title": 16, "tick": 16, "title": 18}

for i, key in enumerate(sorted_model_keys):
    if (vedf["model"] == key).sum() == 0:
        continue

    # Row-major placement in the grid (1-based).
    row = i // num_cols + 1
    col = i % num_cols + 1
    tmpdf = vedf.query(f"model == '{key}'")
    num_layers = model_config.loc[key, "num_layers"]
    layer_idx = tmpdf["layer_idx"] + 1

    # For seq2seq models, decoder rows are shifted by num_layers // 2 so that
    # they follow the encoder rows on the shared color scale.
    if model_config.loc[key, "type"] == "seq2seq":
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        # This is the "G. T. O." (out) figure written to esrpovs.pdf, so it
        # plots out_freq. The previous pred_freq made it a duplicate of the
        # "P. O." figure in the next cell.
        x=tmpdf["out_freq"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # NOTE(review): layer_idx was already shifted by +1 above, so the
            # color value is the raw layer_idx + 2 while cmin/cmax span
            # 1..num_layers. Possible off-by-one; confirm whether the raw
            # layer_idx column is 0-based.
            color=layer_idx + 1,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Colorbar position computed from the panel's grid cell; the
                # constants appear hand-tuned for this grid size (assumption).
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size["bar_tick"]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size["bar_title"]),
                    text="Ly.",
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # x tick labels and titles only on the bottom row.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size["title"]),
                standoff=5,
                text="G. T. O. Voc. Spec.",
            ),
            row=row,
            col=col,
        )

    # y tick labels and titles only on the left column.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SRP"
            ),
            row=row,
            col=col,
        )

# Subplot titles are layout annotations; resize them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=450)
fig.write_image("../plot/esrpovs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Grid of per-model scatter plots: pred_freq (x) against best_f1 (y), one panel
# per model in sorted_model_keys order, points colored by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

# Panels of models without rows in vedf keep an empty title.
panel_titles = []
for model_key in sorted_model_keys:
    if (vedf["model"] == model_key).sum() == 0:
        panel_titles.append("")
    else:
        panel_titles.append(new_name[model_key])

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=panel_titles,
)

font_size = {"bar_tick": 14, "bar_title": 16, "tick": 16, "title": 18}

for slot, model_key in enumerate(sorted_model_keys):
    if not (vedf["model"] == model_key).any():
        continue

    # Row-major placement in the grid (1-based).
    grid_row, grid_col = (pos + 1 for pos in divmod(slot, num_cols))
    model_rows = vedf.query(f"model == '{model_key}'")
    num_layers = model_config.loc[model_key, "num_layers"]

    # For seq2seq models, decoder rows are shifted by num_layers // 2 so that
    # they follow the encoder rows on the shared color scale.
    layer_no = model_rows["layer_idx"] + 1
    if model_config.loc[model_key, "type"] == "seq2seq":
        layer_no = layer_no + model_rows["is_decoder"] * num_layers // 2

    layer_bar = go.scatter.marker.ColorBar(
        len=0.49,
        thickness=10,
        # Colorbar position computed from the panel's grid cell; the constants
        # appear hand-tuned for this grid size (assumption).
        x=(grid_col - 0.23) * 0.102,
        y=(num_rows - 1 - grid_row + 1.45) * 0.54,
        tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size["bar_tick"]),
        tickvals=[1, num_layers // 2, num_layers],
        title=go.scatter.marker.colorbar.Title(
            font=go.scatter.marker.colorbar.title.Font(size=font_size["bar_title"]),
            text="Ly.",
        ),
    )

    fig.add_scatter(
        x=model_rows["pred_freq"],
        y=model_rows["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # NOTE(review): layer_no already includes a +1 shift, so the color
            # value is the raw layer_idx + 2 while cmin/cmax span 1..num_layers.
            # Possible off-by-one; confirm whether raw layer_idx is 0-based.
            color=layer_no + 1,
            colorbar=layer_bar,
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=grid_row,
        col=grid_col,
    )

    # x tick labels and titles only on the bottom row.
    on_bottom_row = grid_row == num_rows
    fig.update_xaxes(showticklabels=on_bottom_row, row=grid_row, col=grid_col)

    if on_bottom_row:
        x_title = go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size["title"]),
            standoff=5,
            text="P. O. Voc. Spec.",
        )
        fig.update_xaxes(
            tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
            title=x_title,
            row=grid_row,
            col=grid_col,
        )

    # y tick labels and titles only on the left column.
    if grid_col == 1:
        y_title = go.layout.yaxis.Title(
            font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SRP"
        )
        fig.update_yaxes(
            showticklabels=True,
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=y_title,
            row=grid_row,
            col=grid_col,
        )

# Subplot titles are layout annotations; resize them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=450)
fig.write_image("../plot/esrppvs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()