In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [None]:
# Configuration table restricted to the rows selected by the boolean-style
# "main" column. The flag column itself is dropped and every remaining column
# is prefixed with "model_", which is how later cells address them
# (e.g. "model_type", "model_abbr").
main_model_config = model_config.query("main").drop(columns="main").add_prefix("model_")

# Display names used in figure titles, keyed by the short model key.
# Insertion order is kept identical to the original literal.
new_name = dict(
    powermoe="PowerMoE",
    llamamoe="LLaMA-MoE-v1",
    olmoe="OLMoE",
    switch="SwitchTransformers",
    llamamoe2="LLaMA-MoE-v2",
    jetmoe="JetMoE",
    openmoe="OpenMoE",
    minicpm="MiniCPM-MoE",
    qwen="Qwen1.5-MoE",
    deepseek2="DeepSeek-V2-Lite",
    deepseek="DeepSeekMoE",
    xverse="XVERSE-MoE",
    qwen3="Qwen3",
    yuan="Yuan2.0",
    phi="Phi-3.5-MoE",
    grin="GRIN-MoE",
    mixtral="Mixtral-8x7B",
    jamba="Jamba-Mini",
    nllb="NLLB-MoE",
    qwen2="Qwen2",
)

# One colour per main model, taken from plotly's Dark24 qualitative palette.
# The palette index wraps around so that having more models than palette
# entries reuses colours instead of raising IndexError.
_model_palette = px.colors.qualitative.Dark24
model_colors = {
    key: _model_palette[i % len(_model_palette)]
    for i, key in enumerate(main_model_config.index.values)
}

# Segment lengths that are plotted, each with a fixed colour from the default
# Plotly qualitative palette (wrapped for the same reason as above).
seg_lens = (4, 16, 64, 256)
_seg_len_palette = px.colors.qualitative.Plotly
seg_len_colors = {
    key: _seg_len_palette[i % len(_seg_len_palette)] for i, key in enumerate(seg_lens)
}
main_model_config

In [None]:
def make_abbr(df):
    """Return the abbreviation used to label one model entry.

    Args:
        df: A mapping-like record indexed by column name (despite the
            parameter name, it is subscripted like a single row — presumably
            a row handed over by ``DataFrame.apply(axis=1)``; confirm with
            callers). It must provide "model_type" and "model_abbr"; the
            "is_decoder" entry is only read for seq2seq records.

    Returns:
        ``df["model_abbr"]`` unchanged for non-seq2seq records. For seq2seq
        records, the abbreviation with a "d" suffix when ``is_decoder`` is
        truthy and an "e" suffix otherwise.
    """
    # Only seq2seq entries get a suffix; everything else passes through
    # without touching "is_decoder".
    if not (df["model_type"] == "seq2seq"):
        return df["model_abbr"]

    suffix = "d" if df["is_decoder"] else "e"
    return f"{df['model_abbr']}{suffix}"

In [None]:
root_dir = Path("../output/srp_pos_mpq")

# Every "m?.parquet" file in root_dir becomes one entry of `dfs`, keyed by the
# file stem (later cells read the "mg" and "md" entries). Each table is joined
# with the main-model configuration on its "model" column, and that column is
# then cast to the dtype of the model_config index.
model_dtype = model_config.index.dtype

dfs = {}
for parquet_path in root_dir.glob("m?.parquet"):
    merged = pd.read_parquet(parquet_path).merge(
        main_model_config, left_on="model", right_index=True
    )
    dfs[parquet_path.stem] = merged.assign(model=merged["model"].astype(model_dtype))

dfs["mg"]

In [None]:
# Standard deviation of best_f1 within every (model, is_decoder, seg_len)
# group of the "mg" table, then reshaped so each segment length is a column.
mg_group_keys = ["model", "is_decoder", "seg_len"]
smgdf = dfs["mg"].groupby(mg_group_keys, as_index=False, observed=True)[["best_f1"]].std()

smgdf.pivot(index=mg_group_keys[:2], columns=mg_group_keys[2], values="best_f1")

In [None]:
sample_seg_len = 16

# Restrict the "md" table to the sampled segment length, then take the
# standard deviation of best_f1 within every (model, is_decoder, dataset)
# group and show it with one column per dataset.
md_frame = dfs["md"]
md_at_sample_len = md_frame[md_frame["seg_len"] == sample_seg_len].drop(columns="seg_len")

md_group_keys = ["model", "is_decoder", "dataset"]
smddf = md_at_sample_len.groupby(md_group_keys, as_index=False, observed=True)[
    ["best_f1"]
].std()

smddf.pivot(index=md_group_keys[:2], columns=md_group_keys[2], values="best_f1")

In [None]:
# Ranking used to order the subplots: the "model" column of the srp_mpq "mg"
# table, limited to main models (inner join with main_model_config) and to
# decoder rows at seg_len 16, ordered by descending best_f1.
srp_mg = pd.read_parquet("../output/srp_mpq/mg.parquet").merge(
    main_model_config, left_on="model", right_index=True
)
ranked_rows = srp_mg.query("is_decoder and seg_len == 16").sort_values(
    by="best_f1", ascending=False
)
sorted_model_keys = ranked_rows["model"]

sorted_model_keys

In [None]:
# Replace start_pos by seg_pos = start_pos + seg_len // 2, i.e. the start
# offset shifted by half the segment length (integer division).
mgdf = (
    dfs["mg"]
    .assign(seg_pos=lambda frame: frame["start_pos"] + frame["seg_len"] // 2)
    .drop(columns="start_pos")
)

mgdf

In [None]:
# Grid of line plots: one subplot per model (in ranking order), one line per
# segment length, x = seg_pos, y = best_f1. Seq2seq models get two lines per
# segment length: a dotted, half-opacity line for is_decoder == False and a
# solid line for is_decoder == True.

# Models are placed on the grid in the order given by sorted_model_keys.
# De-duplicate (keeping first occurrence) so each model owns exactly one cell.
plot_model_keys = list(dict.fromkeys(sorted_model_keys))

num_cols = 10
# Size the grid from the sequence that is actually laid out below, so the
# row index computed in the loop can never fall outside the grid. Keep at
# least one row so make_subplots always receives a valid shape.
num_rows = max(1, (len(plot_model_keys) - 1) // num_cols + 1)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.11,
    # Fall back to the raw key when no display name is registered, instead of
    # passing NaN as a subplot title.
    subplot_titles=[new_name.get(key, key) for key in plot_model_keys],
)

font_size = {"tick": 16, "legend": 18, "title": 20}

# Legend bookkeeping: every line name gets exactly one legend entry, claimed
# by the first solid line drawn under that name. A single on/off flag would
# duplicate entries when the first model is seq2seq and would lose an entry
# when the first model lacks one of the segment lengths.
legend_shown = set()

for i, key in enumerate(plot_model_keys):
    # Filter once per model; models without any rows keep an empty, titled cell.
    model_df = mgdf[mgdf["model"] == key]
    if model_df.empty:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    is_seq2seq = main_model_config.loc[key, "model_type"] == "seq2seq"

    for seg_len in seg_lens:
        seg_df = model_df[model_df["seg_len"] == seg_len]
        if seg_df.empty:
            continue

        line_name = f"$\\mathrm{{SRP}}(m={seg_len})$"

        # (rows to plot, solid?) pairs. Non-decoder rows are drawn first so the
        # solid decoder line ends up on top.
        if is_seq2seq:
            variants = [
                (seg_df[seg_df["is_decoder"] == flag], flag) for flag in (False, True)
            ]
        else:
            variants = [(seg_df, True)]

        for subdf, solid in variants:
            if subdf.empty:
                continue

            in_legend = solid and line_name not in legend_shown
            if in_legend:
                legend_shown.add(line_name)

            fig.add_scatter(
                x=subdf["seg_pos"],
                y=subdf["best_f1"],
                hoverinfo="skip",
                legendgroup=line_name,
                line=go.scatter.Line(
                    color=seg_len_colors[seg_len], dash="solid" if solid else "dot"
                ),
                mode="lines",
                name=line_name,
                opacity=1 if solid else 0.5,
                showlegend=in_legend,
                row=row,
                col=col,
            )

    # Axes are shared, so tick labels must be switched on explicitly per cell.
    fig.update_xaxes(
        showticklabels=True,
        tickfont=go.layout.xaxis.Tickfont(size=font_size["tick"]),
        row=row,
        col=col,
    )

    # Only the bottom row carries the x-axis title.
    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size["title"]),
                standoff=1,
                text="Position",
            ),
            row=row,
            col=col,
        )

    # Only the first column shows y tick labels and the y-axis title.
    fig.update_yaxes(showticklabels=col == 1, row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size["tick"]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size["title"]), text="SRP"
            ),
            row=row,
            col=col,
        )

# Subplot titles are annotations; give them the title font size.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size["title"]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size["legend"]),
        orientation="h",
        x=0.47,
        xanchor="center",
    ),
    margin=go.layout.Margin(l=60, r=15, t=30, b=90),
    width=2000,
    height=500,
)

# Make sure the output directory exists before exporting the figure.
Path("../plot").mkdir(parents=True, exist_ok=True)
fig.write_image("../plot/msrpp.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()