In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import data_config, model_config

In [None]:
# Main-study models only; every config column is prefixed with "model_".
main_model_config = (
    model_config.query("main")
    .drop(columns="main")
    .rename(columns={k: f"model_{k}" for k in model_config.columns})
)

# Display names used for subplot titles, keyed by model id.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}


def _cycle_colors(keys, palette):
    """Map each key (in order) to a palette colour, wrapping around when keys outnumber colours.

    Plain indexing would raise IndexError as soon as there are more keys than colours.
    """
    return {key: palette[i % len(palette)] for i, key in enumerate(keys)}


model_colors = _cycle_colors(main_model_config.index.values, px.colors.qualitative.Dark24)

seg_lens = (4, 16, 64, 256)
seg_len_colors = _cycle_colors(seg_lens, px.colors.qualitative.Plotly)
main_model_config

In [None]:
# Prefix every dataset-config column with "data_" so it can be told apart after merging.
main_data_config = data_config.rename(columns=lambda col: f"data_{col}")
main_data_config

In [None]:
def make_abbr(df):
    """Return the short label for one result row.

    seq2seq models get a "d" (decoder) or "e" (encoder) suffix appended to their
    ``model_abbr`` according to ``is_decoder``; all other models use ``model_abbr`` as-is.
    """
    abbr = df["model_abbr"]
    if df["model_type"] == "seq2seq":
        side = "d" if df["is_decoder"] else "e"
        return f"{abbr}{side}"
    return abbr

In [None]:
root_dir = Path("../output/srp_mpq")
# Segment length at which models are ranked by decoder-side best_f1 in dfs["mg"]
# (this ranking fixes the subplot order in the figures below).
rank_seg_len = 16


def _load_result(path):
    """Load one result parquet and attach the main-model (and, if present, dataset) config.

    Both joins are inner merges, so rows whose model/dataset is absent from the config
    tables are dropped. The key columns are then cast to the config index dtypes.
    """
    df = pd.merge(pd.read_parquet(path), main_model_config, left_on="model", right_index=True)
    if "dataset" in df.columns:
        df = pd.merge(df, main_data_config, left_on="dataset", right_index=True)
        df["dataset"] = df["dataset"].astype(data_config.index.dtype)
    df["model"] = df["model"].astype(model_config.index.dtype)
    return df


dfs = {p.stem: _load_result(p) for p in root_dir.glob("*.parquet")}
# Fail here with a clear message rather than with a bare KeyError in a later cell.
missing = {"mg", "md", "ed"} - set(dfs)
assert not missing, f"missing result tables in {root_dir}: {sorted(missing)}"

sorted_model_keys = (
    dfs["mg"]
    .query(f"is_decoder and seg_len == {rank_seg_len}")
    .sort_values("best_f1", ascending=False)["model"]
)

dfs["md"].query(f"seg_len == {rank_seg_len}").pivot(
    index=["model", "is_decoder"], columns="dataset", values="best_f1"
)

In [None]:
sample_seg_len = 16

# Restrict both tables to the sampled segment length before joining (smaller merge).
mg_sample = (
    dfs["mg"]
    .query(f"seg_len == {sample_seg_len}")
    .drop(columns=["best_m", "ci_lb", "ci_ub"])
    .rename(columns={"best_f1": "gen_best_f1"})
)
md_sample = dfs["md"].query(f"seg_len == {sample_seg_len}").drop(
    columns=["act_r", "best_m", "ci_lb", "ci_ub"]
)

# pd.merge without `on=` joins on every column the two frames share (e.g. model,
# is_decoder, seg_len and the merged model_* config columns), pairing each per-dataset
# row with its gen_best_f1.
mdf = (
    pd.merge(mg_sample, md_sample)
    .drop(columns="seg_len")
    # Relative change of per-dataset best_f1 vs. gen_best_f1. A zero baseline yields NaN
    # (no bar) instead of +/-inf, which would break the shared y-axis autoscale.
    .assign(
        f1_diff=lambda d: (d["best_f1"] - d["gen_best_f1"])
        / d["gen_best_f1"].where(d["gen_best_f1"] != 0)
    )
)
mdf.pivot(index=["model", "is_decoder"], columns="dataset", values="f1_diff")

In [None]:
num_cols = 10
# Subplot slots follow sorted_model_keys, so size the grid from it (at least one row).
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.15,
    # Fall back to the raw key so a model missing from new_name doesn't get a NaN title.
    subplot_titles=[new_name.get(key, key) for key in sorted_model_keys],
)

font_size = {"tick_x": 12, "tick_y": 16, "title": 20}

for i, key in enumerate(sorted_model_keys):
    # Boolean masks instead of a query() f-string, which breaks on keys containing quotes.
    model_df = mdf[mdf["model"] == key]
    if model_df.empty:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1

    # is_decoder=False bars are drawn at half opacity, is_decoder=True bars fully opaque.
    for is_decoder in (False, True):
        part_df = model_df[model_df["is_decoder"] == is_decoder]
        if part_df.empty:
            continue

        fig.add_bar(
            x=part_df["data_abbr"],
            y=part_df["f1_diff"],
            hoverinfo="skip",
            marker=go.bar.Marker(color=model_colors[key]),
            opacity=1 if is_decoder else 0.5,
            showlegend=False,
            row=row,
            col=col,
        )

    fig.update_xaxes(
        showticklabels=True,
        tickangle=90,
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick_x"]),
        row=row,
        col=col,
    )

    # y ticks at -10% / 0 / +10% relative change; labels only on the first column.
    fig.update_yaxes(showticklabels=col == 1, tickvals=[-0.1, 0, 0.1], row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick_y"]),
            ticktext=["-10%", "SRP", "+10%"],
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=60, r=30, t=30, b=15), width=1800, height=400)
out_path = "../plot/msrpdd.pdf"
# write_image does not create missing directories.
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(out_path, width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Rows of dfs["ed"] at the sampled segment length; seg_len is then constant and dropped.
sedf = dfs["ed"].loc[lambda d: d["seg_len"] == sample_seg_len].drop(columns="seg_len")
sedf

In [None]:
# Wide table: one row per (model, is_decoder, layer_idx, expert_idx), one column per dataset.
act_r_wide = sedf.pivot(
    index=["model", "is_decoder", "layer_idx", "expert_idx"],
    columns="dataset",
    values="act_r",
)
# Per-model dataset-by-dataset correlation (pandas default: Pearson) of act_r.
dsedf = act_r_wide.groupby("model", observed=True).corr()

dsedf

In [None]:
num_cols = 10
# Subplot slots follow sorted_model_keys, so size the grid from it (at least one row).
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.12,
    # Fall back to the raw key so a model missing from new_name doesn't get a NaN title.
    subplot_titles=[new_name.get(key, key) for key in sorted_model_keys],
)

font_size = {"tick": 12, "bar_tick": 16, "title": 20}
# MultiIndex.levels may list values with no rows (e.g. unused categories), after which
# .loc[key] would raise; use the level values actually present instead.
corr_models = set(dsedf.index.get_level_values(0))
# All heatmaps share the fixed -1..1 scale, so only the first one draws a colour bar.
show_scale = True

for i, key in enumerate(sorted_model_keys):
    if key not in corr_models:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    corr_df = dsedf.loc[key]  # dataset x dataset correlation matrix for this model

    fig.add_heatmap(
        z=corr_df.values,
        x=corr_df.columns.map(main_data_config["data_abbr"]),
        y=corr_df.index.map(main_data_config["data_abbr"]),
        colorbar=go.heatmap.ColorBar(
            tickfont=go.heatmap.colorbar.Tickfont(size=font_size["bar_tick"])
        ),
        colorscale="RdBu_r",
        hoverinfo="skip",
        zmin=-1,
        zmax=1,
        showscale=show_scale,
        row=row,
        col=col,
    )

    fig.update_xaxes(
        showticklabels=True,
        tickangle=90,
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
        row=row,
        col=col,
    )

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]), row=row, col=col
        )

    show_scale = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=30, r=15, t=30, b=15), width=1800, height=450)
out_path = "../plot/mactdr.pdf"
# write_image does not create missing directories.
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(out_path, width=fig.layout.width, height=fig.layout.height)
fig.show()

In [None]:
# Wide table: one row per (model, is_decoder, layer_idx, expert_idx), one column per dataset.
best_f1_wide = sedf.pivot(
    index=["model", "is_decoder", "layer_idx", "expert_idx"],
    columns="dataset",
    values="best_f1",
)
# Per-model dataset-by-dataset correlation (pandas default: Pearson) of best_f1.
csedf = best_f1_wide.groupby("model", observed=True).corr()

csedf

In [None]:
num_cols = 10
# Subplot slots follow sorted_model_keys, so size the grid from it (at least one row).
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.12,
    # Fall back to the raw key so a model missing from new_name doesn't get a NaN title.
    subplot_titles=[new_name.get(key, key) for key in sorted_model_keys],
)

font_size = {"tick": 12, "bar_tick": 16, "title": 20}
# MultiIndex.levels may list values with no rows (e.g. unused categories), after which
# .loc[key] would raise; use the level values actually present instead.
corr_models = set(csedf.index.get_level_values(0))
# All heatmaps share the fixed -1..1 scale, so only the first one draws a colour bar.
show_scale = True

for i, key in enumerate(sorted_model_keys):
    if key not in corr_models:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    corr_df = csedf.loc[key]  # dataset x dataset correlation matrix for this model

    fig.add_heatmap(
        z=corr_df.values,
        x=corr_df.columns.map(main_data_config["data_abbr"]),
        y=corr_df.index.map(main_data_config["data_abbr"]),
        colorbar=go.heatmap.ColorBar(
            tickfont=go.heatmap.colorbar.Tickfont(size=font_size["bar_tick"])
        ),
        colorscale="RdBu_r",
        hoverinfo="skip",
        zmin=-1,
        zmax=1,
        showscale=show_scale,
        row=row,
        col=col,
    )

    fig.update_xaxes(
        showticklabels=True,
        tickangle=90,
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
        row=row,
        col=col,
    )

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]), row=row, col=col
        )

    show_scale = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))
fig.update_layout(margin=go.layout.Margin(l=30, r=15, t=30, b=15), width=1800, height=450)
out_path = "../plot/msrpdr.pdf"
# write_image does not create missing directories.
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(out_path, width=fig.layout.width, height=fig.layout.height)
fig.show()