In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import data_config, model_config

In [3]:
# Rows of ``model_config`` flagged ``main``, with every remaining column
# prefixed ``model_`` so the table can be merged onto result frames without
# column-name clashes.
main_model_config = model_config.query("main").drop(columns="main").add_prefix("model_")

# Short display names, used as subplot titles in the figures below.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# Palettes are cycled, so having more keys than colors (Dark24 has 24 entries,
# Plotly has 10) reuses colors instead of raising IndexError.
_model_palette = px.colors.qualitative.Dark24
model_colors = {
    key: _model_palette[i % len(_model_palette)]
    for i, key in enumerate(main_model_config.index.values)
}

seg_lens = (4, 16, 64, 256)
_seg_len_palette = px.colors.qualitative.Plotly
seg_len_colors = {
    key: _seg_len_palette[i % len(_seg_len_palette)] for i, key in enumerate(seg_lens)
}
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [4]:
# Dataset metadata with every column prefixed ``data_`` (e.g. ``data_abbr``),
# ready to be merged onto result tables.
main_data_config = data_config.add_prefix("data_")
main_data_config

Unnamed: 0_level_0,data_name,data_abbr
key,Unnamed: 1_level_1,Unnamed: 2_level_1
c4,C4,C4
cc2306,CommonCrawl,CC
book,Books,BK
wikipedia,Wikipedia,WK
arxiv,ArXiv,AX
stackexchange,StackExchange,SE
github,GitHub,GH
lmarena,LMArena,LM
math,OpenMath,OM
code,OpenCode,OC


In [5]:
def make_abbr(df):
    """Return the display abbreviation for one model row.

    Non-seq2seq models use ``model_abbr`` as-is. Seq2seq models get an ``e``
    (encoder) or ``d`` (decoder) suffix chosen by the row's ``is_decoder`` flag,
    so the two stacks can be told apart.
    """
    if df["model_type"] != "seq2seq":
        return df["model_abbr"]
    suffix = "d" if df["is_decoder"] else "e"
    return f"{df['model_abbr']}{suffix}"

In [None]:
root_dir = Path("../output/srp_mpq")


def _load_result_table(path):
    """Load one result parquet file and join the metadata tables onto it.

    ``main_model_config`` is inner-joined on ``model``. ``main_data_config`` is
    inner-joined on ``dataset`` only when that column exists. The key columns
    are then cast to the dtypes of the ``model_config`` / ``data_config``
    indexes.
    """
    table = pd.read_parquet(path).merge(main_model_config, left_on="model", right_index=True)
    with_dataset = "dataset" in table.columns
    if with_dataset:
        table = table.merge(main_data_config, left_on="dataset", right_index=True)
    table["model"] = table["model"].astype(model_config.index.dtype)
    if with_dataset:
        table["dataset"] = table["dataset"].astype(data_config.index.dtype)
    return table


# One frame per parquet file, keyed by file stem (e.g. "mg", "md", "ed").
dfs = {path.stem: _load_result_table(path) for path in root_dir.glob("*.parquet")}

# Model order used by every grid below: decoder-side ``mg`` F1 at seg_len 16, best first.
mg_decoder_16 = dfs["mg"].query("is_decoder and seg_len == 16")
sorted_model_keys = mg_decoder_16.sort_values("best_f1", ascending=False)["model"]

dfs["md"].query("seg_len == 16").pivot(
    index=["model", "is_decoder"], columns="dataset", values="best_f1"
)

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
powermoe,True,0.528772,0.527592,0.530765,0.546199,0.566806,0.56316,0.56398,0.551285,0.571675,0.562916,0.554246
llamamoe,True,0.450365,0.450565,0.448192,0.454389,0.454172,0.45514,0.456557,0.456482,0.451833,0.452922,0.451186
olmoe,True,0.455232,0.454706,0.45541,0.52697,0.535874,0.550433,0.565604,0.523845,0.52972,0.547402,0.488658
switch,False,0.168423,0.169951,0.166393,0.200539,0.202298,0.198313,0.20962,0.201653,0.212789,0.199277,0.179113
switch,True,0.163001,0.162464,0.164019,0.207586,0.199107,0.196396,0.210864,0.203753,0.20776,0.195783,0.185368
llamamoe2,True,0.809891,0.807653,0.81154,0.81796,0.793941,0.767787,0.763904,0.776578,0.748637,0.719554,0.783998
jetmoe,True,0.475049,0.474884,0.470432,0.480482,0.475177,0.475215,0.470728,0.477087,0.473551,0.473441,0.474037
openmoe,True,0.269967,0.275923,0.287056,0.280349,0.297329,0.288142,0.297447,0.289828,0.291917,0.284828,0.300237
minicpm,True,0.491174,0.490462,0.486214,0.49759,0.489688,0.486447,0.483491,0.494263,0.48606,0.482844,0.486034
qwen,True,0.308066,0.308014,0.301221,0.325069,0.305697,0.309131,0.303476,0.318335,0.291613,0.302159,0.303838


In [None]:
sample_seg_len = 16

# ``mg`` F1 (renamed to ``gen_best_f1``) next to the per-dataset ``md`` F1.
# ``merge`` is called without ``on``, so all shared columns act as join keys.
gen_scores = dfs["mg"].drop(columns=["best_m", "ci_lb", "ci_ub"]).rename(
    columns={"best_f1": "gen_best_f1"}
)
dataset_scores = dfs["md"].drop(columns=["act_r", "best_m", "ci_lb", "ci_ub"])
mdf = (
    gen_scores.merge(dataset_scores)
    .query("seg_len == @sample_seg_len")
    .drop(columns="seg_len")
)

# Relative change of the per-dataset F1 with respect to ``gen_best_f1``.
mdf["f1_diff"] = mdf["best_f1"].sub(mdf["gen_best_f1"]).div(mdf["gen_best_f1"])
mdf.pivot(index=["model", "is_decoder"], columns="dataset", values="f1_diff")

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
powermoe,True,-0.041592,-0.043732,-0.03798,-0.010006,0.027344,0.020737,0.022222,-0.000788,0.036169,0.020293,0.004579
llamamoe,True,-0.005554,-0.005112,-0.010352,0.003333,0.002852,0.00499,0.008118,0.007955,-0.002313,9.3e-05,-0.003741
olmoe,True,-0.105762,-0.106794,-0.105412,0.035159,0.052648,0.081247,0.111048,0.029019,0.04056,0.075293,-0.040101
switch,False,-0.128535,-0.120633,-0.139042,0.03764,0.046741,0.026119,0.084627,0.043404,0.101023,0.031107,-0.073224
switch,True,-0.154185,-0.156974,-0.148906,0.077168,0.033168,0.0191,0.094177,0.057274,0.078067,0.015921,-0.038125
llamamoe2,True,0.036215,0.033352,0.038325,0.046539,0.015807,-0.017655,-0.022623,-0.006408,-0.042156,-0.079367,0.003086
jetmoe,True,0.001072,0.000725,-0.008656,0.012522,0.001344,0.001422,-0.008034,0.005367,-0.002085,-0.002316,-0.00106
openmoe,True,-0.061715,-0.041017,-0.002322,-0.025634,0.033379,0.001449,0.033791,0.007311,0.014573,-0.010069,0.043488
minicpm,True,0.005396,0.003939,-0.004757,0.018529,0.002355,-0.004278,-0.01033,0.01172,-0.005071,-0.011655,-0.005124
qwen,True,0.003274,0.003107,-0.019016,0.05865,-0.004441,0.006743,-0.011672,0.036717,-0.050308,-0.015963,-0.010495


In [11]:
# Bar grid of ``f1_diff`` (relative per-dataset F1 change from ``mdf``), one
# subplot per model in ``sorted_model_keys`` order. Rows with
# ``is_decoder == False`` (encoder side) are drawn at half opacity.
num_cols = 10
# Size the grid from the models actually placed on it. A grid sized from
# ``main_model_config`` could leave an empty bottom row, and x tick labels
# are only shown on ``row == num_rows``.
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.1,
    # Fall back to the raw key when a model has no display name.
    subplot_titles=sorted_model_keys.map(lambda k: new_name.get(k, k)).values,
)

font_size = [12, 16, 18, 20]  # [x ticks, y ticks, unused here, subplot titles]

for i, key in enumerate(sorted_model_keys):
    model_rows = mdf[mdf["model"] == key]
    if model_rows.empty:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1

    for is_decoder in (False, True):
        # Boolean masks instead of an f-string ``query``: no quoting pitfalls
        # with the interpolated key.
        tmpdf = model_rows[model_rows["is_decoder"] == is_decoder]
        if tmpdf.empty:
            continue

        fig.add_bar(
            x=tmpdf["data_abbr"],
            y=tmpdf["f1_diff"],
            hoverinfo="skip",
            marker=go.bar.Marker(color=model_colors[key]),
            opacity=1 if is_decoder else 0.5,
            showlegend=False,
            row=row,
            col=col,
        )

    fig.update_xaxes(
        showticklabels=row == num_rows,
        tickangle=0,
        tickfont=go.layout.xaxis.Tickfont(size=font_size[0]),
        row=row,
        col=col,
    )

    fig.update_yaxes(showticklabels=col == 1, tickvals=[-0.1, 0, 0.1], row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            # y = 0 means the per-dataset F1 equals ``gen_best_f1``. Keep the
            # label in sync with the segment length selected for ``mdf``.
            ticktext=["-10%", f"SRP<br>(E,{sample_seg_len})", "+10%"],
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=15), width=2400, height=500)

out_path = Path("./plot/msrpdd.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)  # write_image does not create directories
fig.write_image(str(out_path), width=fig.layout.width, height=fig.layout.height)
fig.show()

In [13]:
# Per-expert ``ed`` results restricted to the sample segment length.
ed_results = dfs["ed"]
sedf = ed_results[ed_results["seg_len"] == sample_seg_len].drop(columns="seg_len")
sedf

Unnamed: 0,model,is_decoder,layer_idx,expert_idx,dataset,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,data_name,data_abbr
1,powermoe,True,0,0,c4,0.694848,0.820701,0.819219,0.822131,1.423875,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,C4,C4
5,powermoe,True,0,0,cc2306,0.691056,0.818060,0.816703,0.819371,1.431446,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,CommonCrawl,CC
9,powermoe,True,0,0,book,0.646680,0.787783,0.786134,0.789406,1.486747,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,Books,BK
13,powermoe,True,0,0,wikipedia,0.563178,0.741508,0.738549,0.744411,1.563096,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,Wikipedia,WK
17,powermoe,True,0,0,arxiv,0.562731,0.729254,0.726513,0.732068,1.626349,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,ArXiv,AX
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1216493,qwen2,True,27,63,github,0.129654,0.331546,0.330146,0.332850,2.651764,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,GitHub,GH
1216497,qwen2,True,27,63,lmarena,0.126167,0.330782,0.329645,0.331981,2.611124,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,LMArena,LM
1216501,qwen2,True,27,63,math,0.132522,0.319218,0.318248,0.320188,2.704185,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,OpenMath,OM
1216505,qwen2,True,27,63,code,0.131186,0.325957,0.325040,0.326883,2.679304,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,OpenCode,OC


In [14]:
expert_keys = ["model", "is_decoder", "layer_idx", "expert_idx"]

# Wide table with one row per expert and one ``act_r`` column per dataset,
# then the dataset-by-dataset correlation computed separately within each model.
act_r_by_dataset = sedf.pivot(index=expert_keys, columns="dataset", values="act_r")
dsedf = act_r_by_dataset.groupby("model", observed=True).corr()

dsedf

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
powermoe,c4,1.000000,0.978744,0.937798,0.729134,0.679949,0.686345,0.564100,0.823599,0.585133,0.593146,0.800213
powermoe,cc2306,0.978744,1.000000,0.929031,0.787864,0.670279,0.663258,0.559120,0.809376,0.571252,0.568623,0.774152
powermoe,book,0.937798,0.929031,1.000000,0.674666,0.604702,0.610619,0.488287,0.755546,0.540597,0.519877,0.755625
powermoe,wikipedia,0.729134,0.787864,0.674666,1.000000,0.609821,0.557400,0.537901,0.813285,0.492111,0.454688,0.590608
powermoe,arxiv,0.679949,0.670279,0.604702,0.609821,1.000000,0.914722,0.883749,0.856527,0.923905,0.897171,0.915194
...,...,...,...,...,...,...,...,...,...,...,...,...
qwen2,github,0.995873,0.995659,0.993598,0.991678,0.996925,0.998279,1.000000,0.994579,0.992560,0.993134,0.990244
qwen2,lmarena,0.995634,0.995559,0.994369,0.993262,0.994733,0.995635,0.994579,1.000000,0.987454,0.991225,0.990327
qwen2,math,0.990996,0.990550,0.987398,0.983966,0.992528,0.993363,0.992560,0.987454,1.000000,0.992970,0.989587
qwen2,code,0.991311,0.991136,0.987676,0.983942,0.991664,0.993705,0.993134,0.991225,0.992970,1.000000,0.987842


In [17]:
# Heatmap grid of the per-model dataset-by-dataset ``act_r`` correlations
# (``dsedf``), one subplot per model in ``sorted_model_keys`` order.
num_cols = 10
# Size the grid from the models actually placed on it, so no empty trailing
# row is created (x tick styling is only applied on ``row == num_rows``).
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.1,
    # Fall back to the raw key when a model has no display name.
    subplot_titles=sorted_model_keys.map(lambda k: new_name.get(k, k)).values,
)

font_size = [12, 16, 18, 20]  # [axis ticks, colorbar ticks, unused here, subplot titles]
show_scale = True  # draw the shared [-1, 1] colorbar only on the first heatmap

# ``MultiIndex.levels`` can keep values that no longer occur in the index.
# Test against the values actually present so ``.loc[key]`` cannot raise KeyError.
available_models = dsedf.index.unique(level="model")

for i, key in enumerate(sorted_model_keys):
    if key not in available_models:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    tmpdf = dsedf.loc[key]

    fig.add_heatmap(
        z=tmpdf.values,
        x=tmpdf.columns.map(main_data_config["data_abbr"]),
        y=tmpdf.index.map(main_data_config["data_abbr"]),
        colorbar=go.heatmap.ColorBar(tickfont=go.heatmap.colorbar.Tickfont(size=font_size[1])),
        colorscale="RdBu_r",
        hoverinfo="skip",
        zmin=-1,
        zmax=1,
        showscale=show_scale,
        row=row,
        col=col,
    )

    if row == num_rows:
        fig.update_xaxes(
            tickangle=0, tickfont=go.layout.xaxis.Tickfont(size=font_size[0]), row=row, col=col
        )

    if col == 1:
        fig.update_yaxes(tickfont=go.layout.yaxis.Tickfont(size=font_size[0]), row=row, col=col)

    show_scale = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=30, r=15, t=30, b=15), width=2500, height=500)

out_path = Path("./plot/mactdr.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)  # write_image does not create directories
fig.write_image(str(out_path), width=fig.layout.width, height=fig.layout.height)
fig.show()

In [18]:
expert_keys = ["model", "is_decoder", "layer_idx", "expert_idx"]

# Wide table with one row per expert and one ``best_f1`` column per dataset,
# then the dataset-by-dataset correlation computed separately within each model.
best_f1_by_dataset = sedf.pivot(index=expert_keys, columns="dataset", values="best_f1")
csedf = best_f1_by_dataset.groupby("model", observed=True).corr()

csedf

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
powermoe,c4,1.000000,0.978656,0.920813,0.778592,0.732542,0.745321,0.650831,0.819693,0.655827,0.624161,0.794798
powermoe,cc2306,0.978656,1.000000,0.927406,0.816800,0.728992,0.722863,0.642793,0.800040,0.640447,0.602126,0.778243
powermoe,book,0.920813,0.927406,1.000000,0.732665,0.678925,0.684056,0.587517,0.756090,0.630282,0.569597,0.759754
powermoe,wikipedia,0.778592,0.816800,0.732665,1.000000,0.645499,0.608995,0.594913,0.849671,0.539541,0.499980,0.645280
powermoe,arxiv,0.732542,0.728992,0.678925,0.645499,1.000000,0.909202,0.874565,0.815883,0.915770,0.884068,0.913671
...,...,...,...,...,...,...,...,...,...,...,...,...
qwen2,github,0.982161,0.983024,0.975862,0.975690,0.987926,0.995971,1.000000,0.989164,0.986760,0.988253,0.985717
qwen2,lmarena,0.987586,0.988611,0.983376,0.985426,0.989844,0.991351,0.989164,1.000000,0.981573,0.986330,0.983254
qwen2,math,0.981423,0.981522,0.976042,0.969141,0.984649,0.988929,0.986760,0.981573,1.000000,0.986450,0.988566
qwen2,code,0.976833,0.978263,0.971134,0.966884,0.982610,0.988228,0.988253,0.986330,0.986450,1.000000,0.980311


In [20]:
# Heatmap grid of the per-model dataset-by-dataset ``best_f1`` correlations
# (``csedf``), one subplot per model in ``sorted_model_keys`` order.
num_cols = 10
# Size the grid from the models actually placed on it, so no empty trailing
# row is created (x tick styling is only applied on ``row == num_rows``).
num_rows = max(1, (len(sorted_model_keys) + num_cols - 1) // num_cols)

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.1,
    # Fall back to the raw key when a model has no display name.
    subplot_titles=sorted_model_keys.map(lambda k: new_name.get(k, k)).values,
)

font_size = [12, 16, 18, 20]  # [axis ticks, colorbar ticks, unused here, subplot titles]
show_scale = True  # draw the shared [-1, 1] colorbar only on the first heatmap

# ``MultiIndex.levels`` can keep values that no longer occur in the index.
# Test against the values actually present so ``.loc[key]`` cannot raise KeyError.
available_models = csedf.index.unique(level="model")

for i, key in enumerate(sorted_model_keys):
    if key not in available_models:
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    tmpdf = csedf.loc[key]

    fig.add_heatmap(
        z=tmpdf.values,
        x=tmpdf.columns.map(main_data_config["data_abbr"]),
        y=tmpdf.index.map(main_data_config["data_abbr"]),
        colorbar=go.heatmap.ColorBar(tickfont=go.heatmap.colorbar.Tickfont(size=font_size[1])),
        colorscale="RdBu_r",
        hoverinfo="skip",
        zmin=-1,
        zmax=1,
        showscale=show_scale,
        row=row,
        col=col,
    )

    if row == num_rows:
        fig.update_xaxes(
            tickangle=0, tickfont=go.layout.xaxis.Tickfont(size=font_size[0]), row=row, col=col
        )

    if col == 1:
        fig.update_yaxes(tickfont=go.layout.yaxis.Tickfont(size=font_size[0]), row=row, col=col)

    show_scale = False

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=30, r=15, t=30, b=15), width=2500, height=500)

out_path = Path("./plot/msrpdr.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)  # write_image does not create directories
fig.write_image(str(out_path), width=fig.layout.width, height=fig.layout.height)
fig.show()