In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [2]:
# Keep only the rows of `model_config` selected by its "main" column, drop that
# column, and prefix every remaining column name with "model_".
main_model_config = model_config.query("main").drop(columns="main")
main_model_config = main_model_config.rename(
    columns={column: f"model_{column}" for column in model_config.columns}
)

# Display name for each model key (used for legend entries and subplot titles).
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# One colour per model, assigned in the row order of `main_model_config`.
model_palette = px.colors.qualitative.Dark24
model_colors = {
    model_key: model_palette[position]
    for position, model_key in enumerate(main_model_config.index.values)
}

# Segment lengths analysed throughout the notebook, with one colour each.
seg_lens = (4, 16, 64, 256)
seg_len_palette = px.colors.qualitative.Plotly
seg_len_colors = {
    seg_len: seg_len_palette[position] for position, seg_len in enumerate(seg_lens)
}
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [3]:
def make_abbr(df):
    """Return the text label for one result row.

    Args:
        df: A single row (anything indexable by column name, e.g. the Series
            passed by ``DataFrame.apply(..., axis=1)``) providing
            ``model_abbr`` and ``model_type``; ``is_decoder`` is read only
            when ``model_type`` is ``"seq2seq"``.

    Returns:
        ``model_abbr`` unchanged for non-seq2seq rows. For seq2seq rows the
        abbreviation gets a ``"d"`` suffix when ``is_decoder`` is truthy and
        an ``"e"`` suffix otherwise.
    """
    abbr = df["model_abbr"]
    if df["model_type"] == "seq2seq":
        suffix = "d" if df["is_decoder"] else "e"
        return f"{abbr}{suffix}"
    return abbr

In [None]:
root_dir = Path("../output/srp_mpq")

# Load every parquet file in `root_dir`, keyed by file stem, and attach the
# per-model metadata columns by matching "model" against the config index.
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    merged = pd.read_parquet(parquet_path).merge(
        main_model_config, left_on="model", right_index=True
    )
    # Give the "model" column the same dtype as the index of `model_config`.
    merged["model"] = merged["model"].astype(model_config.index.dtype)
    dfs[parquet_path.stem] = merged

# Model keys ordered by decoder-side best F1 at seg_len == 16, best first.
decoder_seg16 = dfs["mg"].query("is_decoder and seg_len == 16")
sorted_model_keys = decoder_seg16.sort_values("best_f1", ascending=False)["model"]

dfs["mg"].pivot(index=["model", "is_decoder"], columns="seg_len", values=["best_f1", "best_m"])

Unnamed: 0_level_0,Unnamed: 1_level_0,best_f1,best_f1,best_f1,best_f1,best_m,best_m,best_m,best_m
Unnamed: 0_level_1,seg_len,4,16,64,256,4,16,64,256
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
powermoe,True,0.673095,0.55172,0.505882,0.485872,1.103515,1.392749,1.488992,1.579756
llamamoe,True,0.557779,0.45288,0.416099,0.40618,1.029198,2.392017,2.923047,3.521086
olmoe,True,0.646949,0.509072,0.455343,0.426438,0.997344,1.056418,1.205091,1.187848
switch,False,0.416367,0.193265,0.123377,0.099904,3.803456,2.009145,2.649861,1.913562
switch,True,0.418338,0.192715,0.117753,,3.780828,2.059083,2.665879,
llamamoe2,True,0.83185,0.781586,0.764461,0.75481,1.12093,1.034675,1.058977,1.062069
jetmoe,True,0.602158,0.47454,0.427762,0.410949,1.093253,2.259037,2.687439,3.154335
openmoe,True,0.455252,0.287725,0.217618,0.187988,3.393169,1.588945,2.63504,2.177667
minicpm,True,0.625336,0.488538,0.437157,0.417272,1.105926,2.186703,2.57877,2.916163
qwen,True,0.472319,0.30706,0.225031,0.187072,3.234429,1.842482,2.105532,2.889928


In [5]:
# Model-level best F1 in wide form (one column per segment length), ordered by
# the seg_len == 16 column, highest first.
best_f1_wide = dfs["mg"].pivot(index=["model", "is_decoder"], columns="seg_len", values="best_f1")
best_f1_wide.sort_values(by=16, ascending=False)

Unnamed: 0_level_0,seg_len,4,16,64,256
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
llamamoe2,True,0.83185,0.781586,0.764461,0.75481
yuan,True,0.711998,0.634819,0.611333,0.59833
powermoe,True,0.673095,0.55172,0.505882,0.485872
qwen3,True,0.672163,0.54143,0.478911,0.438075
phi,True,0.651531,0.519771,0.466027,0.437805
olmoe,True,0.646949,0.509072,0.455343,0.426438
grin,True,0.637753,0.503854,0.450368,0.4212
mixtral,True,0.625494,0.493599,0.445111,0.424732
minicpm,True,0.625336,0.488538,0.437157,0.417272
jetmoe,True,0.602158,0.47454,0.427762,0.410949


In [6]:
# Confidence-interval bounds in wide form.
ci_wide = dfs["mg"].pivot(
    index=["model", "is_decoder"], columns="seg_len", values=["ci_lb", "ci_ub"]
)
# Move seg_len to the outer column level so both bounds sit under each length.
ci_wide = ci_wide.swaplevel(0, 1, axis=1).sort_index(axis=1)
ci_wide.sort_values(by=(16, "ci_lb"), ascending=False)

Unnamed: 0_level_0,seg_len,4,4,16,16,64,64,256,256
Unnamed: 0_level_1,Unnamed: 1_level_1,ci_lb,ci_ub,ci_lb,ci_ub,ci_lb,ci_ub,ci_lb,ci_ub
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
llamamoe2,True,0.831584,0.832086,0.781215,0.781927,0.764056,0.764843,0.75436,0.755251
yuan,True,0.711763,0.712224,0.634489,0.635146,0.610977,0.6117,0.597908,0.598744
powermoe,True,0.672926,0.673263,0.551487,0.551961,0.505597,0.50619,0.485549,0.486205
qwen3,True,0.671854,0.672453,0.540983,0.541858,0.478366,0.479414,0.437414,0.438741
phi,True,0.651136,0.651935,0.519196,0.520402,0.465376,0.466704,0.437034,0.438603
olmoe,True,0.646629,0.64729,0.508583,0.509616,0.454758,0.455937,0.425742,0.427139
grin,True,0.63737,0.638107,0.503271,0.504423,0.449715,0.451004,0.420436,0.421933
mixtral,True,0.625377,0.625612,0.493493,0.493707,0.444976,0.445228,0.424597,0.424857
minicpm,True,0.625198,0.62547,0.488427,0.48865,0.437062,0.437256,0.417183,0.417355
jetmoe,True,0.60204,0.602271,0.474448,0.474627,0.427677,0.427847,0.410866,0.411022


In [7]:
# Largest distance from best_f1 to either confidence bound, per model and
# segment length.
dfs["mg"].assign(
    ci_dist=lambda frame: np.maximum(
        frame["ci_ub"] - frame["best_f1"], frame["best_f1"] - frame["ci_lb"]
    )
).pivot(index=["model", "is_decoder"], columns="seg_len", values="ci_dist")

Unnamed: 0_level_0,seg_len,4,16,64,256
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
powermoe,True,0.000169,0.000241,0.000308,0.000333
llamamoe,True,0.000146,7e-05,6.4e-05,5e-05
olmoe,True,0.000341,0.000544,0.000593,0.000701
switch,False,8.8e-05,0.00026,0.000262,0.000341
switch,True,9e-05,0.000299,0.000294,
llamamoe2,True,0.000265,0.000371,0.000406,0.00045
jetmoe,True,0.000118,9.2e-05,8.5e-05,8.2e-05
openmoe,True,0.000162,0.000403,0.000338,0.00045
minicpm,True,0.000138,0.000112,9.9e-05,8.9e-05
qwen,True,9.2e-05,0.00018,0.000211,0.000203


In [8]:
# Mean loss per model.
# NOTE(review): this reads from "./output" while the SRP tables are loaded from
# "../output/srp_mpq" - confirm both paths are intended.
mean_loss = (
    pd.read_parquet("./output/loss.parquet")
    .groupby("model", as_index=False, observed=True)[["loss"]]
    .mean()
)

# Default merge: joins on the columns the two frames share and keeps only the
# rows that have a match in both.
mldf = dfs["mg"].merge(mean_loss)

mldf

Unnamed: 0,model,is_decoder,seg_len,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,loss
0,powermoe,True,4,0.200000,0.673095,0.672926,0.673263,1.103515,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.675532
1,powermoe,True,16,0.200000,0.551720,0.551487,0.551961,1.392749,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.675532
2,powermoe,True,64,0.200000,0.505882,0.505597,0.506190,1.488992,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.675532
3,powermoe,True,256,0.200000,0.485872,0.485549,0.486205,1.579756,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.675532
4,llamamoe,True,4,0.250000,0.557779,0.557638,0.557925,1.029198,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager,1.779338
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82,nllb,True,256,0.015625,0.219095,0.218652,0.219534,1.454345,NLLB-MoE-54B,NL,seq2seq,54.50,48,128,2,eager,2.530380
83,qwen2,True,4,0.125000,0.502406,0.502261,0.502550,0.818744,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,1.616623
84,qwen2,True,16,0.125000,0.367437,0.367379,0.367488,2.588414,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,1.616623
85,qwen2,True,64,0.125000,0.303384,0.303329,0.303440,2.460632,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,1.616623


In [9]:
# One panel per segment length, laid out in a single row.
fig = make_subplots(rows=1, cols=len(seg_lens), shared_xaxes="all", horizontal_spacing=0.01)

# Text-label placement per model, one dict per panel: text_pos[i] is used for
# seg_lens[i]. "switch" and "nllb" carry a two-element list instead of a single
# position - presumably one entry per plotted row (encoder / decoder); confirm
# against the row order of the plotted frame.
text_pos = [
    {
        "powermoe": "middle right",
        "llamamoe": "top center",
        "olmoe": "middle right",
        "switch": ["middle left", "middle right"],
        "llamamoe2": "bottom center",
        "jetmoe": "top center",
        "openmoe": "middle right",
        "minicpm": "middle right",
        "qwen": "bottom center",
        "deepseek2": "middle right",
        "deepseek": "middle left",
        "xverse": "bottom right",
        "qwen3": "middle right",
        "yuan": "bottom center",
        "phi": "bottom center",
        "grin": "middle left",
        "mixtral": "top center",
        "jamba": "top center",
        "nllb": ["middle left", "middle left"],
        "qwen2": "bottom center",
    },
    {
        "powermoe": "top center",
        "llamamoe": "middle left",
        "olmoe": "bottom center",
        "switch": ["bottom center", "top center"],
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "middle left",
        "minicpm": "bottom center",
        "qwen": "bottom center",
        "deepseek2": "middle right",
        "deepseek": "bottom center",
        "xverse": "bottom center",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "top center",
        "grin": "middle left",
        "mixtral": "middle right",
        "jamba": "top center",
        "nllb": ["bottom center", "bottom center"],
        "qwen2": "middle left",
    },
    {
        "powermoe": "top center",
        "llamamoe": "middle left",
        "olmoe": "bottom center",
        "switch": ["middle right", "middle left"],
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "middle right",
        "qwen": "middle left",
        "deepseek2": "middle right",
        "deepseek": "middle left",
        "xverse": "bottom center",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "middle right",
        "jamba": "top center",
        "nllb": ["bottom center", "bottom center"],
        "qwen2": "bottom center",
    },
    {
        "powermoe": "top center",
        "llamamoe": "middle left",
        "olmoe": "middle left",
        "switch": ["bottom center", "bottom center"],
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "middle left",
        "qwen": "bottom center",
        "deepseek2": "middle right",
        "deepseek": "middle right",
        "xverse": "bottom center",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "middle right",
        "jamba": "top center",
        "nllb": ["bottom center", "bottom center"],
        "qwen2": "middle right",
    },
]

# Font sizes, by index: 0 = point labels, 1 = tick labels and legend,
# 2 = axis titles, 3 = figure annotations.
font_size = [16, 20, 24, 28]
show_legend = True

for i, seg_len in enumerate(seg_lens):
    col = i + 1

    # One trace per model; marker size is derived from model_num_params.
    for j, key in enumerate(main_model_config.index.values):
        tmpdf = mldf.query(f"model == '{key}' and seg_len == {seg_len}")

        fig.add_scatter(
            x=tmpdf["best_f1"],
            y=tmpdf["best_m"],
            hoverinfo="skip",
            marker=go.scatter.Marker(
                color=model_colors[key],
                line=go.scatter.marker.Line(color="white", width=1),
                opacity=0.7,
                size=main_model_config.loc[key, "model_num_params"] ** 0.5 * 4,
            ),
            legendgroup=key,
            mode="markers+text",
            name=new_name[key],
            showlegend=show_legend,
            text=tmpdf.apply(make_abbr, axis=1),
            textfont=go.scatter.Textfont(size=font_size[0], shadow="auto"),
            textposition=text_pos[i][key],
            # Earlier rows of main_model_config get the higher zorder.
            zorder=100 - j,
            row=1,
            col=col,
        )

    # Axis styling depends only on the panel, not on the model, so it is
    # applied once per column instead of once per trace.
    fig.update_xaxes(
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size[2]), text=f"SRP(E,{seg_len})"
        ),
        row=1,
        col=col,
    )

    # Same y range in every panel; only the first panel shows tick labels.
    fig.update_yaxes(range=[0.5, 4], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="ρ(E,m)"
            ),
            row=1,
            col=col,
        )

    # Legend entries are emitted by the first panel only.
    show_legend = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=15, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/msrp.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [10]:
# One panel per segment length, laid out in a single row.
fig = make_subplots(rows=1, cols=len(seg_lens), shared_xaxes="all", horizontal_spacing=0.01)

# Text-label placement per model, one dict per panel: text_pos[i] is used for
# seg_lens[i]. There are no "switch" / "nllb" entries here; the loop that
# consumes this table iterates only over models with model_type == 'causal'.
text_pos = [
    {
        "powermoe": "middle right",
        "llamamoe": "middle right",
        "olmoe": "bottom center",
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "top center",
        "qwen": "middle left",
        "deepseek2": "bottom center",
        "deepseek": "top center",
        "xverse": "middle left",
        "qwen3": "middle right",
        "yuan": "middle right",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "middle left",
    },
    {
        "powermoe": "middle right",
        "llamamoe": "top center",
        "olmoe": "bottom center",
        "llamamoe2": "bottom center",
        "jetmoe": "top center",
        "openmoe": "bottom center",
        "minicpm": "top center",
        "qwen": "top center",
        "deepseek2": "middle right",
        "deepseek": "middle left",
        "xverse": "middle left",
        "qwen3": "middle right",
        "yuan": "middle right",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "bottom center",
    },
    {
        "powermoe": "middle right",
        "llamamoe": "middle left",
        "olmoe": "bottom center",
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "top center",
        "qwen": "top center",
        "deepseek2": "middle right",
        "deepseek": "middle left",
        "xverse": "middle left",
        "qwen3": "middle right",
        "yuan": "middle right",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "bottom center",
    },
    {
        "powermoe": "middle right",
        "llamamoe": "middle left",
        "olmoe": "bottom center",
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "top center",
        "qwen": "top center",
        "deepseek2": "bottom center",
        "deepseek": "middle left",
        "xverse": "middle left",
        "qwen3": "middle right",
        "yuan": "middle right",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "middle right",
    },
]

# Font sizes, by index: 0 = point labels, 1 = tick labels and legend,
# 2 = axis titles, 3 = figure annotations.
font_size = [16, 20, 24, 28]
show_legend = True

for i, seg_len in enumerate(seg_lens):
    col = i + 1

    # Only causal models are plotted against the loss.
    for j, key in enumerate(main_model_config.query("model_type == 'causal'").index.values):
        tmpdf = mldf.query(f"model == '{key}' and seg_len == {seg_len}")

        fig.add_scatter(
            x=tmpdf["best_f1"],
            y=tmpdf["loss"],
            hoverinfo="skip",
            marker=go.scatter.Marker(
                color=model_colors[key],
                line=go.scatter.marker.Line(color="white", width=1),
                opacity=0.7,
                size=main_model_config.loc[key, "model_num_params"] ** 0.5 * 4,
            ),
            legendgroup=key,
            mode="markers+text",
            name=new_name[key],
            showlegend=show_legend,
            text=tmpdf.apply(make_abbr, axis=1),
            textfont=go.scatter.Textfont(size=font_size[0], shadow="auto"),
            textposition=text_pos[i][key],
            # Earlier rows of the filtered config get the higher zorder.
            zorder=100 - j,
            row=1,
            col=col,
        )

    # Axis styling depends only on the panel, not on the model, so it is
    # applied once per column instead of once per trace.
    fig.update_xaxes(
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size[2]), text=f"SRP(E,{seg_len})"
        ),
        row=1,
        col=col,
    )

    # Same y range in every panel; only the first panel shows tick labels.
    fig.update_yaxes(range=[1, 4], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="log PPL"
            ),
            row=1,
            col=col,
        )

    # Legend entries are emitted by the first panel only.
    show_legend = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=15, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/msrpls.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [11]:
# Alias for the "lg" table loaded from lg.parquet; it has one row per
# (model, is_decoder, layer_idx, seg_len). No copy is made.
ldf = dfs["lg"]
ldf

Unnamed: 0,model,is_decoder,layer_idx,seg_len,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,0,4,0.200,0.661282,0.661011,0.661563,1.178276,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,0,16,0.200,0.581990,0.581716,0.582293,1.628545,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,0,64,0.200,0.560074,0.559798,0.560361,1.731949,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,0,256,0.200,0.552070,0.551776,0.552368,1.744225,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,1,4,0.200,0.877082,0.876993,0.877178,1.021213,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2713,qwen2,True,26,256,0.125,0.237058,0.236984,0.237134,4.488435,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
2714,qwen2,True,27,4,0.125,0.471690,0.471625,0.471753,3.240069,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
2715,qwen2,True,27,16,0.125,0.328063,0.327986,0.328135,2.584436,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
2716,qwen2,True,27,64,0.125,0.260114,0.260041,0.260187,3.259918,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [12]:
# Grid of per-model panels: `num_cols` per row, as many rows as needed.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.07,
    vertical_spacing=0.11,
    subplot_titles=sorted_model_keys.map(new_name).values,
    specs=[[{"secondary_y": True, "r": -0.06} for _ in range(num_cols)] for _ in range(num_rows)],
)

# Font sizes, by index: 1 = tick labels, 2 = axis titles and legend,
# 3 = subplot titles (index 0 is unused in this cell).
font_size = [16, 16, 18, 20]
show_legend = True

for i, key in enumerate(sorted_model_keys):
    if (ldf["model"] == key).sum() == 0:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]
    mid_layer = num_layers // 2
    is_seq2seq = model_config.loc[key, "type"] == "seq2seq"

    for seg_len in seg_lens:
        tmpdf = ldf.query(f"model == '{key}' and seg_len == {seg_len}")
        if len(tmpdf) == 0:
            continue

        # 1-based layer number; for seq2seq models the decoder rows are
        # shifted by num_layers // 2 so they follow the encoder rows.
        layer_idx = tmpdf["layer_idx"] + 1
        if is_seq2seq:
            layer_idx += tmpdf["is_decoder"] * num_layers // 2

        f1_name = f"SRP(E,{seg_len})"
        m_name = f"ρ(E,{seg_len})"

        # Solid line on the primary axis: best F1 per layer.
        fig.add_scatter(
            x=layer_idx / num_layers,
            y=tmpdf["best_f1"],
            hoverinfo="skip",
            legendgroup=f1_name,
            line=go.scatter.Line(color=seg_len_colors[seg_len]),
            mode="lines",
            name=f1_name,
            showlegend=show_legend,
            row=row,
            col=col,
        )

        # Dotted, half-transparent line on the secondary axis: best_m per layer.
        fig.add_scatter(
            x=layer_idx / num_layers,
            y=tmpdf["best_m"],
            customdata=layer_idx,
            hoverinfo="skip",
            legendgroup=m_name,
            line=go.scatter.Line(color=seg_len_colors[seg_len], dash="dot"),
            mode="lines",
            name=m_name,
            opacity=0.5,
            secondary_y=True,
            showlegend=show_legend,
            row=row,
            col=col,
        )

    # Legend entries are emitted by the first plotted model only.
    show_legend = False

    # Layer k is drawn at x = k / num_layers, so the middle tick has to sit at
    # mid_layer / num_layers. A fixed 0.5 is only correct for an even number
    # of layers (27 layers would put the "13" label at layer 13.5).
    fig.update_xaxes(
        showticklabels=True,
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        tickvals=[1 / num_layers, mid_layer / num_layers, 1],
        ticktext=[1, mid_layer, num_layers],
        row=row,
        col=col,
    )

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]), standoff=1, text="Layer"
            ),
            row=row,
            col=col,
        )

    # Primary-axis tick labels only in the first column, secondary-axis tick
    # labels only in the last column.
    fig.update_yaxes(showticklabels=col == 1, row=row, col=col, secondary_y=False)

    fig.update_yaxes(
        range=[0.5, 5.5], showticklabels=col == num_cols, row=row, col=col, secondary_y=True
    )

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="SRP(E,m)"
            ),
            row=row,
            col=col,
            secondary_y=False,
        )

    if col == num_cols:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="ρ(E,m)"
            ),
            row=row,
            col=col,
            secondary_y=True,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[2]), orientation="h", x=0.47, xanchor="center"
    ),
    margin=go.layout.Margin(l=60, r=15, t=30, b=15),
    width=2000,
    height=500,
)

fig.write_image("./plot/lsrp.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [13]:
# Standard deviation of act_r within each (model, is_decoder, seg_len) group
# of the "eg" table.
act_r_spread = (
    dfs["eg"]
    .groupby(["model", "is_decoder", "seg_len"], as_index=False, observed=True)
    .aggregate(act_r_std=("act_r", "std"))
)

# Default merge: joins on the columns the two frames share.
sddf = dfs["mg"].merge(act_r_spread)

sddf_wide = sddf.pivot(
    index=["model", "is_decoder"], columns="seg_len", values=["best_f1", "act_r_std"]
)
sddf_wide.sort_values(by=("best_f1", 16), ascending=False)

Unnamed: 0_level_0,Unnamed: 1_level_0,best_f1,best_f1,best_f1,best_f1,act_r_std,act_r_std,act_r_std,act_r_std
Unnamed: 0_level_1,seg_len,4,16,64,256,4,16,64,256
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
llamamoe2,True,0.83185,0.781586,0.764461,0.75481,0.289824,0.290156,0.290434,0.290639
yuan,True,0.711998,0.634819,0.611333,0.59833,0.138641,0.138634,0.13861,0.138607
powermoe,True,0.673095,0.55172,0.505882,0.485872,0.129049,0.129093,0.129119,0.129163
qwen3,True,0.672163,0.54143,0.478911,0.438075,0.031899,0.031877,0.031715,0.031542
phi,True,0.651531,0.519771,0.466027,0.437805,0.048878,0.048935,0.048964,0.04897
olmoe,True,0.646949,0.509072,0.455343,0.426438,0.067964,0.067876,0.067566,0.067115
grin,True,0.637753,0.503854,0.450368,0.4212,0.038926,0.03889,0.038726,0.038545
mixtral,True,0.625494,0.493599,0.445111,0.424732,0.027064,0.027038,0.026876,0.026741
minicpm,True,0.625336,0.488538,0.437157,0.417272,0.026123,0.025898,0.025459,0.025024
jetmoe,True,0.602158,0.47454,0.427762,0.410949,0.011271,0.01115,0.010808,0.010442


In [14]:
# One panel per segment length, laid out in a single row.
fig = make_subplots(rows=1, cols=len(seg_lens), shared_xaxes="all", horizontal_spacing=0.01)

# Text-label placement per model, one dict per panel: text_pos[i] is used for
# seg_lens[i]. Every label in this figure sits at "top center", so the table
# is generated instead of repeating the same literal for each panel and
# model. A scalar position applies to every point of a trace, which covers
# models that contribute more than one row. To move a single label, override
# it afterwards, e.g. text_pos[0]["switch"] = ["middle left", "top center"].
text_pos = [
    {key: "top center" for key in main_model_config.index.values} for _ in seg_lens
]

# Font sizes, by index: 0 = point labels, 1 = tick labels and legend,
# 2 = axis titles, 3 = figure annotations.
font_size = [16, 20, 24, 28]
show_legend = True

for i, seg_len in enumerate(seg_lens):
    col = i + 1

    # One trace per model; marker size is derived from model_num_params.
    for j, key in enumerate(main_model_config.index.values):
        tmpdf = sddf.query(f"model == '{key}' and seg_len == {seg_len}")

        fig.add_scatter(
            x=tmpdf["best_f1"],
            y=tmpdf["act_r_std"],
            hoverinfo="skip",
            marker=go.scatter.Marker(
                color=model_colors[key],
                line=go.scatter.marker.Line(color="white", width=1),
                opacity=0.7,
                size=main_model_config.loc[key, "model_num_params"] ** 0.5 * 4,
            ),
            legendgroup=key,
            mode="markers+text",
            name=new_name[key],
            showlegend=show_legend,
            text=tmpdf.apply(make_abbr, axis=1),
            textfont=go.scatter.Textfont(size=font_size[0], shadow="auto"),
            textposition=text_pos[i][key],
            # Earlier rows of main_model_config get the higher zorder.
            zorder=100 - j,
            row=1,
            col=col,
        )

    # Axis styling depends only on the panel, not on the model, so it is
    # applied once per column instead of once per trace.
    fig.update_xaxes(
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size[2]), text=f"SRP(E,{seg_len})"
        ),
        row=1,
        col=col,
    )

    # Same y range in every panel; only the first panel shows tick labels.
    fig.update_yaxes(range=[0, 0.35], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="Act. Freq. SD"
            ),
            row=1,
            col=col,
        )

    # Legend entries are emitted by the first panel only.
    show_legend = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=15, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/msrpsd.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [15]:
# Rows of the "eg" table for a single segment length; the now-constant seg_len
# column is dropped.
sample_seg_len = 16
seg_len_filter = f"seg_len == {sample_seg_len}"
edf = dfs["eg"].query(seg_len_filter).drop(columns="seg_len")
edf

Unnamed: 0,model,is_decoder,layer_idx,expert_idx,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
1,powermoe,True,0,0,0.568817,0.740981,0.740284,0.741636,1.563594,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
5,powermoe,True,0,1,0.500232,0.686511,0.685632,0.687438,1.669474,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
9,powermoe,True,0,2,0.351072,0.572114,0.571304,0.572940,1.843012,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
13,powermoe,True,0,3,0.380526,0.600333,0.599478,0.601122,1.901469,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
17,powermoe,True,0,4,0.354294,0.570837,0.570008,0.571744,1.919815,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
110573,qwen2,True,27,59,0.124139,0.326662,0.326257,0.327032,2.575495,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110577,qwen2,True,27,60,0.133005,0.334691,0.334344,0.335032,2.702744,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110581,qwen2,True,27,61,0.131245,0.334319,0.333973,0.334675,2.684064,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110585,qwen2,True,27,62,0.133081,0.324485,0.324125,0.324846,2.682198,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [16]:
# Grid of per-model panels: `num_cols` per row, as many rows as needed.
num_cols = 10
num_rows = 1 + (len(main_model_config) - 1) // num_cols

# Panel titles follow the ordering of `sorted_model_keys`, mapped to display names.
panel_titles = sorted_model_keys.map(new_name).values

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=panel_titles,
)


def tf(x):
    """Map a value in [0, 1] onto a smoothed log-odds scale.

    Computes ``log(x + 1e-5) - log(1 - x + 1e-5)``; the small offset keeps
    both logarithms finite at ``x == 0`` and ``x == 1``.

    Args:
        x: A scalar or NumPy-compatible array.

    Returns:
        The transformed value(s), same shape as ``x``.
    """
    eps = 1e-5
    log_x = np.log(x + eps)
    log_complement = np.log(1 - x + eps)
    return log_x - log_complement


# Reference curve y = 2r / (r + 1) over activation frequencies r in (0, 1),
# drawn in every panel on the transformed x scale.
# NOTE(review): this looks like the F1 of a predictor with precision r and
# recall 1 - confirm the intended interpretation.
r = np.arange(1, 10000) / 10000
rx = tf(r)
ry = 2 * r / (r + 1)

# Font sizes, by index: 0 = colorbar ticks, 1 = axis ticks and colorbar title,
# 2 = axis titles, 3 = subplot titles.
font_size = [14, 16, 18, 20]


for i, key in enumerate(sorted_model_keys):
    if (edf["model"] == key).sum() == 0:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    tmpdf = edf.query(f"model == '{key}'")
    num_layers = model_config.loc[key, "num_layers"]

    # 1-based layer number; for seq2seq models the decoder rows are shifted by
    # num_layers // 2 so they follow the encoder rows.
    layer_idx = tmpdf["layer_idx"] + 1

    if model_config.loc[key, "type"] == "seq2seq":
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        x=rx,
        y=ry,
        line=go.scatter.Line(color="DarkSlateGrey", dash="dot"),
        hoverinfo="skip",
        mode="lines",
        opacity=0.5,
        showlegend=False,
        row=row,
        col=col,
    )

    # Vertical marker at top_k / num_experts on the transformed x scale.
    fig.add_vline(
        x=tf(model_config.loc[key, "top_k"] / model_config.loc[key, "num_experts"]),
        line_dash="dash",
        line_color="green",
        opacity=0.5,
        row=row,
        col=col,
    )

    fig.add_scatter(
        x=tf(tmpdf["act_r"]),
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # `layer_idx` is already 1-based, matching cmin=1 / cmax=num_layers
            # and the colorbar ticks; adding another 1 would shift every point
            # one layer up the scale and push the last layer past cmax.
            color=layer_idx,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Hand-tuned placement of the colorbar next to its own panel.
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size[0]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size[1]), text="Ly."
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # Ticks are placed at transformed frequencies; labels only on the bottom row.
    fig.update_xaxes(
        showticklabels=row == num_rows,
        tickvals=tf(np.array([0.0001, 0.01, 0.5, 0.99, 0.9999])),
        row=row,
        col=col,
    )

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]), standoff=0, text="Act. Freq."
            ),
            tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
            ticktext=[".0001", ".01", ".5", ".99", ".9999"],
            row=row,
            col=col,
        )

    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text=f"SRP(e,{sample_seg_len})"
            ),
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=90), width=2000, height=500)
fig.write_image("./plot/esrp.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()