In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [2]:
# Keep only the models flagged as "main", drop the flag column, and prefix every
# remaining column with "model_". The frame is merged into the result tables in
# later cells, where these columns appear as model_name, model_abbr, ...
main_model_config = model_config.query("main").drop(columns="main").add_prefix("model_")

# Display names keyed by model key; used as subplot titles when plotting.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# One colour per main model. The palette is indexed modulo its length so the
# lookup cannot raise IndexError when there are more models than colours
# (colours simply repeat in that case).
model_palette = px.colors.qualitative.Dark24
model_colors = {
    key: model_palette[i % len(model_palette)]
    for i, key in enumerate(main_model_config.index.values)
}

# Segment lengths that are plotted, each with a fixed colour (same wrap-around rule).
seg_lens = (4, 16, 64, 256)
seg_palette = px.colors.qualitative.Plotly
seg_len_colors = {key: seg_palette[i % len(seg_palette)] for i, key in enumerate(seg_lens)}
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [3]:
def make_abbr(df):
    """Return the abbreviation for one model record.

    Args:
        df: A single record indexable by "model_type" and "model_abbr" (and by
            "is_decoder" when the type is "seq2seq"). Despite the name this is
            one row/mapping, not a whole table — presumably a DataFrame row
            passed via ``apply(axis=1)``; confirm against the caller.

    Returns:
        For "seq2seq" models, the abbreviation followed by "d" when
        ``is_decoder`` is truthy and "e" otherwise. For every other model
        type, the abbreviation unchanged.
    """
    if df["model_type"] == "seq2seq":
        abbr = df["model_abbr"]
        side = "d" if df["is_decoder"] else "e"
        return f"{abbr}{side}"

    return df["model_abbr"]

In [None]:
root_dir = Path("../output/srp_pos_mpq")
model_dtype = model_config.index.dtype

# Load every parquet file in root_dir, attach the main-model metadata columns,
# and key the result by file stem (e.g. "mg", "md").
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    df = pd.read_parquet(parquet_path).merge(
        main_model_config, left_on="model", right_index=True
    )
    # Give "model" the same dtype as the model_config index.
    df["model"] = df["model"].astype(model_dtype)
    dfs[parquet_path.stem] = df

dfs["mg"]

Unnamed: 0,model,is_decoder,seg_len,start_pos,act_r,best_f1,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,4,0,0.200,0.616218,1.025195,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,4,1,0.200,0.668296,1.089514,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,4,2,0.200,0.673125,1.096706,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,4,3,0.200,0.674550,1.098278,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,4,4,0.200,0.675213,1.100563,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48045,qwen2,True,256,252,0.125,0.275368,2.530671,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
48046,qwen2,True,256,253,0.125,0.275366,2.530599,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
48047,qwen2,True,256,254,0.125,0.275364,2.530506,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
48048,qwen2,True,256,255,0.125,0.275366,2.530530,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [5]:
# Standard deviation of best_f1 within each (model, is_decoder, seg_len) group
# of the "mg" table.
mg_groups = dfs["mg"].groupby(
    ["model", "is_decoder", "seg_len"], as_index=False, observed=True
)
smgdf = mg_groups[["best_f1"]].std()

# Wide view: one row per (model, is_decoder), one column per segment length.
smgdf.pivot(index=["model", "is_decoder"], columns="seg_len", values="best_f1")

Unnamed: 0_level_0,seg_len,4,16,64,256
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
powermoe,True,0.002574,0.000756,0.000314,0.000138
llamamoe,True,0.002251,0.001232,0.000644,0.000197
olmoe,True,0.004865,0.006085,0.005223,0.002647
switch,False,0.002649,0.002792,0.002069,0.000584
switch,True,0.013866,0.008928,0.000835,
llamamoe2,True,0.006526,0.003851,0.001854,0.000809
jetmoe,True,0.001859,0.001641,0.001258,0.000504
openmoe,True,0.007175,0.014517,0.009706,0.004261
minicpm,True,0.002201,0.001795,0.001372,0.000552
qwen,True,0.002048,0.003713,0.002968,0.001334


In [6]:
sample_seg_len = 16

# Restrict the "md" table to the sampled segment length; the column is constant
# after filtering, so it is dropped.
md_sample = dfs["md"].query(f"seg_len == {sample_seg_len}").drop(columns="seg_len")

# Standard deviation of best_f1 within each (model, is_decoder, dataset) group.
md_groups = md_sample.groupby(
    ["model", "is_decoder", "dataset"], as_index=False, observed=True
)
smddf = md_groups[["best_f1"]].std()

# Wide view: one row per (model, is_decoder), one column per dataset.
smddf.pivot(index=["model", "is_decoder"], columns="dataset", values="best_f1")

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
powermoe,True,0.000962,0.000894,0.000945,0.000931,0.001202,0.000969,0.001341,0.001008,0.001214,0.00103,0.001095
llamamoe,True,0.000784,0.000811,0.000856,0.00095,0.001065,0.001368,0.002167,0.001707,0.001497,0.001142,0.001798
olmoe,True,0.001936,0.001909,0.002129,0.00213,0.009202,0.007762,0.014043,0.00571,0.00816,0.008573,0.005582
switch,False,0.00267,0.00288,0.003012,0.00425,0.003537,0.003223,0.003628,0.003465,0.002074,0.00187,0.002746
switch,True,0.011998,0.011939,0.012141,0.007719,0.008127,0.008674,0.007738,0.007896,0.007286,0.009359,0.009324
llamamoe2,True,0.004573,0.004275,0.004853,0.002754,0.004945,0.003722,0.004105,0.003023,0.00288,0.002784,0.00569
jetmoe,True,0.000857,0.000813,0.001045,0.00184,0.001525,0.001649,0.003165,0.002207,0.001871,0.001622,0.002372
openmoe,True,0.015577,0.01417,0.013227,0.011982,0.016763,0.014949,0.01302,0.012163,0.019367,0.016962,0.012137
minicpm,True,0.000858,0.000973,0.001209,0.001851,0.001937,0.001738,0.002963,0.002559,0.002161,0.001984,0.002436
qwen,True,0.001428,0.002061,0.00206,0.003159,0.004384,0.004701,0.006945,0.004562,0.004049,0.004163,0.004512


In [7]:
# NOTE(review): this file is read relative to "./output", whereas the result
# tables above are loaded from "../output/srp_pos_mpq" — confirm both paths
# resolve from the same working directory.
mg_overall = pd.read_parquet("./output/srp_mpq/mg.parquet").merge(
    main_model_config, left_on="model", right_index=True
)

# Decoder rows at segment length 16 only.
# NOTE(review): 16 equals sample_seg_len from the earlier cell; confirm whether
# the two are meant to stay in sync.
decoder_rows = mg_overall.query("is_decoder and seg_len == 16")

# Model keys ordered from highest to lowest best_f1.
sorted_model_keys = decoder_rows.sort_values("best_f1", ascending=False)["model"]

sorted_model_keys

36     llamamoe2
84          yuan
1       powermoe
80         qwen3
88           phi
13         olmoe
92          grin
96       mixtral
60       minicpm
44        jetmoe
5       llamamoe
76        xverse
100        jamba
68     deepseek2
72      deepseek
112        qwen2
108         nllb
64          qwen
56       openmoe
33        switch
Name: model, dtype: category
Categories (27, object): ['powermoe' < 'llamamoe' < 'llamamoes' < 'olmoe' ... 'mixtral' < 'jamba' < 'nllb' < 'qwen2']

In [8]:
# Replace start_pos with seg_pos = start_pos + seg_len // 2, i.e. the start
# position shifted by half the segment length.
mgdf = (
    dfs["mg"]
    .assign(seg_pos=lambda frame: frame["start_pos"] + frame["seg_len"] // 2)
    .drop(columns="start_pos")
)

mgdf

Unnamed: 0,model,is_decoder,seg_len,act_r,best_f1,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,seg_pos
0,powermoe,True,4,0.200,0.616218,1.025195,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,2
1,powermoe,True,4,0.200,0.668296,1.089514,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,3
2,powermoe,True,4,0.200,0.673125,1.096706,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,4
3,powermoe,True,4,0.200,0.674550,1.098278,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,5
4,powermoe,True,4,0.200,0.675213,1.100563,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48045,qwen2,True,256,0.125,0.275368,2.530671,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,380
48046,qwen2,True,256,0.125,0.275366,2.530599,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,381
48047,qwen2,True,256,0.125,0.275364,2.530506,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,382
48048,qwen2,True,256,0.125,0.275366,2.530530,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,383


In [10]:
# Grid of small multiples: one panel per model, ordered by sorted_model_keys,
# with one line per segment length showing best_f1 against seg_pos.
num_cols = 10
# Ceiling division of the number of main models by the number of columns.
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.11,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes: [1] tick labels, [2] axis titles and legend, [3] panel titles.
font_size = [16, 16, 18, 20]

# Line names that already own a legend entry. Tracking this per segment length
# (instead of one global flag flipped after the first model) guarantees exactly
# one entry per segment length even if the first model lacks some lengths or is
# a seq2seq model with two traces per length.
legend_shown = set()

for i, key in enumerate(sorted_model_keys):
    # Filter once per model rather than once per (model, seg_len).
    model_df = mgdf[mgdf["model"] == key]
    if model_df.empty:
        continue

    # Panel position follows the rank in sorted_model_keys so it lines up with
    # the subplot titles.
    row = i // num_cols + 1
    col = i % num_cols + 1
    is_seq2seq = main_model_config.loc[key, "model_type"] == "seq2seq"

    for seg_len in seg_lens:
        tmpdf = model_df[model_df["seg_len"] == seg_len]
        if tmpdf.empty:
            continue

        line_name = f"SRP(E,{seg_len})"

        # Each variant is (data, extra line style, extra trace style, is_primary).
        # seq2seq models get a dotted, half-opacity encoder line followed by a
        # solid decoder line; all other models get a single default-styled line.
        if is_seq2seq:
            variants = [
                (
                    tmpdf[tmpdf["is_decoder"] == is_decoder],
                    {"dash": "solid" if is_decoder else "dot"},
                    {"opacity": 1 if is_decoder else 0.5},
                    is_decoder,
                )
                for is_decoder in (False, True)
            ]
        else:
            variants = [(tmpdf, {}, {}, True)]

        for subdf, line_style, trace_style, is_primary in variants:
            if subdf.empty:
                continue

            # Only a solid (decoder / causal) line may represent the segment
            # length in the legend, and only the first such line does.
            show_in_legend = is_primary and line_name not in legend_shown

            fig.add_scatter(
                x=subdf["seg_pos"],
                y=subdf["best_f1"],
                hoverinfo="skip",
                legendgroup=line_name,
                line=go.scatter.Line(color=seg_len_colors[seg_len], **line_style),
                mode="lines",
                name=line_name,
                showlegend=show_in_legend,
                row=row,
                col=col,
                **trace_style,
            )

            if show_in_legend:
                legend_shown.add(line_name)

    # Shared x axes hide tick labels on inner panels by default; show them everywhere.
    fig.update_xaxes(
        showticklabels=True,
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        row=row,
        col=col,
    )

    # Only the bottom row carries the x-axis title.
    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]), standoff=1, text="Position"
            ),
            row=row,
            col=col,
        )

    # Only the first column shows y tick labels and the y-axis title.
    fig.update_yaxes(showticklabels=col == 1, row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="SRP(E,m)"
            ),
            row=row,
            col=col,
        )

# Subplot titles are layout annotations; enlarge them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[2]), orientation="h", x=0.47, xanchor="center"
    ),
    margin=go.layout.Margin(l=60, r=15, t=30, b=15),
    width=2000,
    height=500,
)

# Create the output directory first so the export does not fail on a fresh checkout.
out_file = "./plot/msrpp.pdf"
Path(out_file).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(out_file, width=fig.layout.width, height=fig.layout.height)
fig.show()