In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [2]:
# Restrict the model table to the rows flagged "main", drop the flag itself and
# prefix every remaining column with "model_" so the names stay unambiguous
# once this table is merged into the result frames.
prefixed_columns = {column: f"model_{column}" for column in model_config.columns}
main_model_config = (
    model_config.query("main").drop(columns="main").rename(columns=prefixed_columns)
)

# Display names used in legends and subplot titles, keyed by model key.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# One Dark24 colour per main model, assigned in table order.
# NOTE: indexing raises IndexError if the table ever lists more models than
# the palette has colours.
model_colors = {
    model_key: px.colors.qualitative.Dark24[position]
    for position, model_key in enumerate(main_model_config.index.values)
}

# One colour of the default Plotly palette per segment length.
seg_lens = (4, 16, 64, 256)
seg_len_colors = dict(zip(seg_lens, px.colors.qualitative.Plotly))
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [3]:
def make_abbr(df):
    """Return the plot label for one result row.

    Args:
        df: A single row (anything indexable by column name) providing
            ``model_type`` and ``model_abbr``; ``is_decoder`` is only read
            for seq2seq rows.

    Returns:
        ``model_abbr`` unchanged for non-seq2seq rows. For seq2seq rows the
        abbreviation with a ``"d"`` (decoder) or ``"e"`` (encoder) suffix.
    """
    if df["model_type"] != "seq2seq":
        return df["model_abbr"]

    side = "d" if df["is_decoder"] else "e"
    return f"{df['model_abbr']}{side}"

In [None]:
root_dir = Path("../output/srp_mpq")

# One result frame per parquet file, keyed by file stem; each frame is joined
# with the prefixed model metadata through its "model" column.
dfs = {
    p.stem: pd.read_parquet(p).merge(main_model_config, left_on="model", right_index=True)
    for p in root_dir.glob("*.parquet")
}

# Give "model" the same dtype as the model_config index in every frame.
model_dtype = model_config.index.dtype
for df in dfs.values():
    df["model"] = df["model"].astype(model_dtype)

# Model keys ordered by descending best_f1 of the decoder rows at seg_len 16;
# later cells use this order to lay out per-model subplots.
sorted_model_keys = (
    dfs["mg"]
    .query("is_decoder and seg_len == 16")
    .sort_values("best_f1", ascending=False)
    .loc[:, "model"]
)

dfs["mg"].pivot(index=["model", "is_decoder"], columns="seg_len", values=["best_f1", "best_m"])

Unnamed: 0_level_0,Unnamed: 1_level_0,best_f1,best_f1,best_f1,best_f1,best_m,best_m,best_m,best_m
Unnamed: 0_level_1,seg_len,4,16,64,256,4,16,64,256
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
powermoe,True,0.673095,0.55172,0.505882,0.485872,1.103515,1.392749,1.488992,1.579756
llamamoe,True,0.557779,0.45288,0.416099,0.40618,1.029198,2.392017,2.923047,3.521086
olmoe,True,0.646949,0.509072,0.455343,0.426438,0.997344,1.056418,1.205091,1.187848
switch,False,0.416367,0.193265,0.123377,0.099904,3.803456,2.009145,2.649861,1.913562
switch,True,0.418338,0.192715,0.117753,,3.780828,2.059083,2.665879,
llamamoe2,True,0.83185,0.781586,0.764461,0.75481,1.12093,1.034675,1.058977,1.062069
jetmoe,True,0.602158,0.47454,0.427762,0.410949,1.093253,2.259037,2.687439,3.154335
openmoe,True,0.455252,0.287725,0.217618,0.187988,3.393169,1.588945,2.63504,2.177667
minicpm,True,0.625336,0.488538,0.437157,0.417272,1.105926,2.186703,2.57877,2.916163
qwen,True,0.472319,0.30706,0.225031,0.187072,3.234429,1.842482,2.105532,2.889928


In [None]:
# Vocabulary statistics table; the rendered output shows one row per
# (model, layer_idx, expert_idx, token_type).
vocab_dir = Path("../output/vocab_pq")
vdf = pd.read_parquet(vocab_dir.joinpath("gen.parquet"))
vdf

Unnamed: 0,model,layer_idx,expert_idx,token_type,freq,hitoken,hifreq
0,powermoe,0,0,in,0.718943,"[ el, others, otherwise, equ, exactly, em...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
1,powermoe,0,0,out,0.402893,"[;|&, ergency, ;/, copes, hasis, psilon, perat...","[1.0, 0.992126, 0.98913044, 0.96875, 0.9662921..."
2,powermoe,0,0,pred,0.411291,"[qw, ;|&, ergency, ;/, psilon, hasis, Mgr, enu...","[1.0, 1.0, 0.99230766, 0.9893617, 0.96, 0.9574..."
3,powermoe,0,1,in,0.771183,"[ю, оп, knowledge, considering, sell, quic...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
4,powermoe,0,1,out,0.358705,"[ (^)(, RELEASE, upd, ;-, ilities, pedia, ati...","[1.0, 0.97727275, 0.9285714, 0.9166667, 0.8913..."
...,...,...,...,...,...,...,...
72720,qwen2,27,62,out,0.392961,"[ToList, typeof, maxcdn, bitcoin, unfinished, ...","[1.0, 1.0, 1.0, 0.976, 0.9667171, 0.9589041, 0..."
72721,qwen2,27,62,pred,0.395528,"[developer, ToList, stackoverflow, NSDictionar...","[1.0, 1.0, 1.0, 1.0, 0.9705882, 0.96520424, 0...."
72722,qwen2,27,63,in,0.583993,"[ �, _un, (�, Cos, пон, zn, .has, ح, �, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
72723,qwen2,27,63,out,0.428998,"[_gender, DOCTYPE, alnum, -founder, IMATION, a...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."


In [6]:
# Segment length used for every single-segment-length table and figure below.
sample_seg_len = 16

# Rows of the "eg" frame at that segment length; the now-constant seg_len
# column is dropped.
sedf = (
    dfs["eg"]
    .loc[lambda frame: frame["seg_len"] == sample_seg_len]
    .drop(columns="seg_len")
)
sedf

Unnamed: 0,model,is_decoder,layer_idx,expert_idx,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
1,powermoe,True,0,0,0.568817,0.740981,0.740284,0.741636,1.563594,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
5,powermoe,True,0,1,0.500232,0.686511,0.685632,0.687438,1.669474,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
9,powermoe,True,0,2,0.351072,0.572114,0.571304,0.572940,1.843012,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
13,powermoe,True,0,3,0.380526,0.600333,0.599478,0.601122,1.901469,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
17,powermoe,True,0,4,0.354294,0.570837,0.570008,0.571744,1.919815,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
110573,qwen2,True,27,59,0.124139,0.326662,0.326257,0.327032,2.575495,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110577,qwen2,True,27,60,0.133005,0.334691,0.334344,0.335032,2.702744,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110581,qwen2,True,27,61,0.131245,0.334319,0.333973,0.334675,2.684064,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
110585,qwen2,True,27,62,0.133081,0.324485,0.324125,0.324846,2.682198,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [7]:
def _act_r_cv(group):
    """Coefficient of variation (std / mean) of ``act_r`` within one group."""
    act_r = group["act_r"]
    return pd.Series({"act_r_cv": act_r.std() / act_r.mean()})


# Rows of the "ed" frame at the sampled segment length, ignoring (near-)zero
# activation ratios.
ed_rows = dfs["ed"].query(f"seg_len == {sample_seg_len} and act_r > 1e-5").drop(columns="seg_len")

# One act_r_cv value per (model, is_decoder, layer_idx, expert_idx).
expert_cv = ed_rows.groupby(
    ["model", "is_decoder", "layer_idx", "expert_idx"], as_index=False, observed=True
).apply(_act_r_cv, include_groups=False)

# Attach it to the per-expert table (inner join on the shared key columns).
dedf = sedf.merge(expert_cv)

dedf

Unnamed: 0,model,is_decoder,layer_idx,expert_idx,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,act_r_cv
0,powermoe,True,0,0,0.568817,0.740981,0.740284,0.741636,1.563594,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.153358
1,powermoe,True,0,1,0.500232,0.686511,0.685632,0.687438,1.669474,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.152950
2,powermoe,True,0,2,0.351072,0.572114,0.571304,0.572940,1.843012,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.259543
3,powermoe,True,0,3,0.380526,0.600333,0.599478,0.601122,1.901469,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.178769
4,powermoe,True,0,4,0.354294,0.570837,0.570008,0.571744,1.919815,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.160276
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23603,qwen2,True,27,59,0.124139,0.326662,0.326257,0.327032,2.575495,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.028844
23604,qwen2,True,27,60,0.133005,0.334691,0.334344,0.335032,2.702744,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.024259
23605,qwen2,True,27,61,0.131245,0.334319,0.333973,0.334675,2.684064,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.027346
23606,qwen2,True,27,62,0.133081,0.324485,0.324125,0.324846,2.682198,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.035607


In [8]:
# Reshape the vocabulary table to one row per expert with one frequency column
# per token type, renamed to in_freq / pred_freq / out_freq.
freq_columns = {"in": "in_freq", "pred": "pred_freq", "out": "out_freq"}
expert_freq = vdf.pivot(
    index=["model", "layer_idx", "expert_idx"], columns="token_type", values="freq"
)

# Attach the frequencies to the per-expert table (inner join on shared columns).
vedf = sedf.merge(expert_freq.rename(columns=freq_columns).reset_index())

vedf

Unnamed: 0,model,is_decoder,layer_idx,expert_idx,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,in_freq,out_freq,pred_freq
0,powermoe,True,0,0,0.568817,0.740981,0.740284,0.741636,1.563594,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.718943,0.402893,0.411291
1,powermoe,True,0,1,0.500232,0.686511,0.685632,0.687438,1.669474,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.771183,0.358705,0.369789
2,powermoe,True,0,2,0.351072,0.572114,0.571304,0.572940,1.843012,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.940814,0.548292,0.528278
3,powermoe,True,0,3,0.380526,0.600333,0.599478,0.601122,1.901469,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.943600,0.388211,0.403397
4,powermoe,True,0,4,0.354294,0.570837,0.570008,0.571744,1.919815,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,0.582862,0.322898,0.336349
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20216,qwen2,True,27,59,0.124139,0.326662,0.326257,0.327032,2.575495,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.560779,0.416438,0.428814
20217,qwen2,True,27,60,0.133005,0.334691,0.334344,0.335032,2.702744,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.543648,0.346100,0.364221
20218,qwen2,True,27,61,0.131245,0.334319,0.333973,0.334675,2.684064,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.578887,0.333648,0.342862
20219,qwen2,True,27,62,0.133081,0.324485,0.324125,0.324846,2.682198,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.519721,0.392961,0.395528


In [9]:
# Rows of the "mg" frame at the sampled segment length; the now-constant
# seg_len column is dropped.
smdf = (
    dfs["mg"]
    .loc[lambda frame: frame["seg_len"] == sample_seg_len]
    .drop(columns="seg_len")
)
smdf.sort_values("best_f1", ascending=False)

Unnamed: 0,model,is_decoder,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
36,llamamoe2,True,0.25,0.781586,0.781215,0.781927,1.034675,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
84,yuan,True,0.0625,0.634819,0.634489,0.635146,0.884136,Yuan2.0-M32,Y2,causal,39.94,24,32,2,flash_attention_2
1,powermoe,True,0.2,0.55172,0.551487,0.551961,1.392749,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
80,qwen3,True,0.0625,0.54143,0.540983,0.541858,1.065951,Qwen3-30B-A3B,QW3,causal,30.53,48,128,8,flash_attention_2
88,phi,True,0.125,0.519771,0.519196,0.520402,1.14119,Phi-3.5-MoE,PH,causal,41.87,32,16,2,flash_attention_2
13,olmoe,True,0.125,0.509072,0.508583,0.509616,1.056418,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
92,grin,True,0.125,0.503854,0.503271,0.504423,1.107831,GRIN-MoE,GR,causal,41.87,32,16,2,flash_attention_2
96,mixtral,True,0.25,0.493599,0.493493,0.493707,2.175447,Mixtral-8x7B-v0.1,MX,causal,46.7,32,8,2,flash_attention_2
60,minicpm,True,0.25,0.488538,0.488427,0.48865,2.186703,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
44,jetmoe,True,0.25,0.47454,0.474448,0.474627,2.259037,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2


In [10]:
def _domain_stats(group):
    """Mean ``act_r_cv`` ("ds") and its correlation with ``best_f1`` for one group."""
    cv = group["act_r_cv"]
    return pd.Series({"ds": cv.mean(), "corr": group["best_f1"].corr(cv)})


# One (ds, corr) pair per (model, is_decoder), computed over that group's experts.
model_domain_stats = dedf.groupby(["model", "is_decoder"], as_index=False, observed=True).apply(
    _domain_stats, include_groups=False
)

# Attach the pair to the model-level table (inner join on shared columns).
dmcdf = smdf.merge(model_domain_stats)

dmcdf.sort_values("best_f1", ascending=False)

Unnamed: 0,model,is_decoder,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,ds,corr
5,llamamoe2,True,0.25,0.781586,0.781215,0.781927,1.034675,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2,0.840082,-0.396718
14,yuan,True,0.0625,0.634819,0.634489,0.635146,0.884136,Yuan2.0-M32,Y2,causal,39.94,24,32,2,flash_attention_2,0.832104,0.081963
0,powermoe,True,0.2,0.55172,0.551487,0.551961,1.392749,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2,0.372975,0.008011
13,qwen3,True,0.0625,0.54143,0.540983,0.541858,1.065951,Qwen3-30B-A3B,QW3,causal,30.53,48,128,8,flash_attention_2,1.015582,0.643646
15,phi,True,0.125,0.519771,0.519196,0.520402,1.14119,Phi-3.5-MoE,PH,causal,41.87,32,16,2,flash_attention_2,0.673283,0.765383
2,olmoe,True,0.125,0.509072,0.508583,0.509616,1.056418,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2,0.565762,0.675895
16,grin,True,0.125,0.503854,0.503271,0.504423,1.107831,GRIN-MoE,GR,causal,41.87,32,16,2,flash_attention_2,0.634935,0.823552
17,mixtral,True,0.25,0.493599,0.493493,0.493707,2.175447,Mixtral-8x7B-v0.1,MX,causal,46.7,32,8,2,flash_attention_2,0.108105,0.114573
8,minicpm,True,0.25,0.488538,0.488427,0.48865,2.186703,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2,0.096963,0.332728
6,jetmoe,True,0.25,0.47454,0.474448,0.474627,2.259037,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2,0.083977,0.643061


In [11]:
def _vocab_stats(group):
    """Per token type: mean frequency (``*_vs``) and its correlation with ``best_f1``."""
    stats = {}
    for k in ("in", "pred", "out"):
        freq = group[f"{k}_freq"]
        stats[f"{k}_vs"] = freq.mean()
        stats[f"{k}_corr"] = group["best_f1"].corr(freq)
    return pd.Series(stats)


# Six statistics per (model, is_decoder), computed over that group's experts.
model_vocab_stats = vedf.groupby(["model", "is_decoder"], as_index=False, observed=True).apply(
    _vocab_stats, include_groups=False
)

# Attach them to the model-level table (inner join on shared columns).
vmcdf = smdf.merge(model_vocab_stats)

vmcdf.sort_values("best_f1", ascending=False)

Unnamed: 0,model,is_decoder,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,model_type,...,model_num_layers,model_num_experts,model_top_k,model_attn,in_vs,in_corr,pred_vs,pred_corr,out_vs,out_corr
3,llamamoe2,True,0.25,0.781586,0.781215,0.781927,1.034675,LLaMA-MoE-v2-3.8B,LL2,causal,...,32,8,2,flash_attention_2,0.604184,0.335311,0.553795,0.373816,0.515217,0.518879
12,yuan,True,0.0625,0.634819,0.634489,0.635146,0.884136,Yuan2.0-M32,Y2,causal,...,24,32,2,flash_attention_2,0.700695,-0.014754,0.546642,0.396525,0.530193,0.388381
0,powermoe,True,0.2,0.55172,0.551487,0.551961,1.392749,PowerMoE-3B,PW,causal,...,32,40,8,flash_attention_2,0.493072,-0.128305,0.370982,-0.103815,0.348098,-0.141077
11,qwen3,True,0.0625,0.54143,0.540983,0.541858,1.065951,Qwen3-30B-A3B,QW3,causal,...,48,128,8,flash_attention_2,0.365378,-0.163869,0.300987,0.043979,0.277934,0.042567
13,phi,True,0.125,0.519771,0.519196,0.520402,1.14119,Phi-3.5-MoE,PH,causal,...,32,16,2,flash_attention_2,0.610751,-0.280532,0.47911,0.318297,0.454195,0.314817
2,olmoe,True,0.125,0.509072,0.508583,0.509616,1.056418,OLMoE-1B-7B-0125,OL,causal,...,16,64,8,flash_attention_2,0.417058,0.187233,0.319714,0.418033,0.28434,0.427685
14,grin,True,0.125,0.503854,0.503271,0.504423,1.107831,GRIN-MoE,GR,causal,...,32,16,2,flash_attention_2,0.610343,-0.281619,0.478315,0.316785,0.453627,0.316188
15,mixtral,True,0.25,0.493599,0.493493,0.493707,2.175447,Mixtral-8x7B-v0.1,MX,causal,...,32,8,2,flash_attention_2,0.55459,-0.21417,0.424376,-0.143317,0.404657,-0.155179
6,minicpm,True,0.25,0.488538,0.488427,0.48865,2.186703,MiniCPM-MoE-8x2B,MC,causal,...,40,8,2,flash_attention_2,0.586272,-0.213344,0.444673,0.163532,0.416286,0.163397
4,jetmoe,True,0.25,0.47454,0.474448,0.474627,2.259037,JetMoE-8B,JT,causal,...,24,8,2,flash_attention_2,0.667016,-0.593696,0.472356,0.256363,0.447456,0.241025


In [12]:
# Figure "msrpsc": one labelled bubble per model in each of four panels.
#   x      = model-level best_f1 at the sampled segment length
#   y      = within-model correlation between expert best_f1 and a
#            specialization measure (domain, input / predicted / ground-truth
#            output vocabulary)
#   bubble = model-level mean of that specialization measure
panel_titles = [
    "Domain Specialization",
    "Input Vocabulary Specialization",
    "Pred. Output Vocab. Spec.",
    "G. T. Output Vocab. Spec.",
]

# The column count follows the number of panels (previously it was taken from
# len(seg_lens), which only matched by coincidence).
fig = make_subplots(
    rows=1,
    cols=len(panel_titles),
    shared_xaxes="all",
    horizontal_spacing=0.01,
    subplot_titles=panel_titles,
)

# Label placement per panel and model. A list gives one position per point of
# that trace (the seq2seq models contribute an encoder and a decoder row).
text_pos = [
    {
        "powermoe": "bottom center",
        "llamamoe": "bottom center",
        "olmoe": "bottom center",
        "switch": ["top center", "bottom center"],
        "llamamoe2": "bottom center",
        "jetmoe": "bottom center",
        "openmoe": "middle right",
        "minicpm": "middle left",
        "qwen": "bottom center",
        "deepseek2": "top center",
        "deepseek": "middle left",
        "xverse": "middle right",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "middle left",
        "jamba": "bottom center",
        "nllb": ["bottom center", "middle right"],
        "qwen2": "bottom center",
    },
    {
        "powermoe": "middle right",
        "llamamoe": "bottom center",
        "olmoe": "bottom center",
        "llamamoe2": "bottom center",
        "jetmoe": "bottom center",
        "openmoe": "bottom center",
        "minicpm": "top center",
        "qwen": "bottom center",
        "deepseek2": "middle right",
        "deepseek": "middle right",
        "xverse": "bottom center",
        "qwen3": "top center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "middle left",
        "jamba": "bottom center",
        "qwen2": "bottom center",
    },
    {
        "powermoe": "bottom center",
        "llamamoe": "bottom center",
        "olmoe": "top center",
        "llamamoe2": "bottom center",
        "jetmoe": "middle left",
        "openmoe": "bottom center",
        "minicpm": "bottom center",
        "qwen": "top center",
        "deepseek2": "middle right",
        "deepseek": "bottom center",
        "xverse": "bottom center",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "top center",
    },
    {
        "powermoe": "bottom center",
        "llamamoe": "bottom center",
        "olmoe": "top center",
        "llamamoe2": "bottom center",
        "jetmoe": "bottom center",
        "openmoe": "bottom center",
        "minicpm": "middle right",
        "qwen": "top center",
        "deepseek2": "middle right",
        "deepseek": "bottom center",
        "xverse": "bottom center",
        "qwen3": "bottom center",
        "yuan": "bottom center",
        "phi": "middle right",
        "grin": "middle left",
        "mixtral": "bottom center",
        "jamba": "top center",
        "qwen2": "top center",
    },
]

# Fallback for models that have no hand-tuned entry in a panel's table, so a
# newly appearing model is plotted instead of raising KeyError.
default_text_pos = "bottom center"

# Font sizes: [point labels, tick / legend labels, axis titles, panel titles].
font_size = [16, 20, 24, 28]


def _add_model_trace(frame, key, corr_col, marker_size, panel, order, showlegend):
    """Add the labelled bubble trace of model ``key`` to panel ``panel`` (0-based) of ``fig``."""
    fig.add_scatter(
        x=frame["best_f1"],
        y=frame[corr_col],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            color=model_colors[key],
            line=go.scatter.marker.Line(color="white", width=1),
            opacity=0.7,
            size=marker_size,
        ),
        legendgroup=key,
        mode="markers+text",
        name=new_name[key],
        showlegend=showlegend,
        text=frame.apply(make_abbr, axis=1),
        textfont=go.scatter.Textfont(size=font_size[0], shadow="auto"),
        textposition=text_pos[panel].get(key, default_text_pos),
        # Earlier models in the table are drawn on top of later ones.
        zorder=100 - order,
        row=1,
        col=panel + 1,
    )


# Axis styling does not depend on the model, so it is applied once per panel
# rather than once per model and panel.
for panel_col in range(1, len(panel_titles) + 1):
    fig.update_xaxes(
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size[2]), text=f"SRP(E,{sample_seg_len})"
        ),
        row=1,
        col=panel_col,
    )

    if panel_col == 1:
        # Only the left-most panel carries the y tick labels and title.
        fig.update_yaxes(
            range=[-1, 1],
            showticklabels=True,
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]),
                text="Correlation",
            ),
            row=1,
            col=panel_col,
        )
    else:
        fig.update_yaxes(range=[-1, 1], showticklabels=False, row=1, col=panel_col)

# Panel 1: domain specialization. These traces own the legend entries.
for j, key in enumerate(main_model_config.index.values):
    tmpdf = dmcdf.query(f"model == '{key}'")
    if tmpdf.empty:
        continue

    _add_model_trace(
        tmpdf,
        key,
        corr_col="corr",
        marker_size=tmpdf["ds"] ** 0.5 * 30,
        panel=0,
        order=j,
        showlegend=True,
    )

# Panels 2-4: input / predicted-output / ground-truth-output vocabulary
# specialization. They share the legend group of panel 1.
for j, key in enumerate(main_model_config.index.values):
    tmpdf = vmcdf.query(f"model == '{key}'")
    if tmpdf.empty:
        continue

    for i, prefix in enumerate(("in", "pred", "out"), start=1):
        _add_model_trace(
            tmpdf,
            key,
            corr_col=f"{prefix}_corr",
            marker_size=tmpdf[f"{prefix}_vs"] * 30,
            panel=i,
            order=j,
            showlegend=False,
        )

# Subplot titles are layout annotations; enlarge them together.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/msrpsc.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [16]:
# Figure "esrpds": one scatter per model (ordered by sorted_model_keys) of
# per-expert best_f1 against domain specialization (act_r_cv), coloured by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1  # ceil division

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes: [colorbar ticks, axis ticks / colorbar title, axis titles, subplot titles].
font_size = [14, 16, 18, 20]

for i, key in enumerate(sorted_model_keys):
    tmpdf = dedf.query(f"model == '{key}'")
    if tmpdf.empty:
        # Keep the grid slot (and its title) but draw nothing.
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]

    # 1-based layer number, matching the colour scale [1, num_layers].
    layer_idx = tmpdf["layer_idx"] + 1

    if model_config.loc[key, "type"] == "seq2seq":
        # Decoder rows are shifted by num_layers // 2 so that encoder and
        # decoder layers share one scale.
        # NOTE(review): assumes num_layers counts encoder + decoder layers - confirm.
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        x=tmpdf["act_r_cv"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # layer_idx is already 1-based; adding 1 again (as before) pushed
            # every colour one step up and clipped the last layer at cmax.
            color=layer_idx,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Paper coordinates placing the bar beside its own subplot;
                # presumably hand-tuned for this 10-column grid.
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size[0]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size[1]), text="Ly."
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # Only the bottom row shows x tick labels and an x title.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]), standoff=5, text="Domain Spec."
            ),
            tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

    # Only the left-most column shows y tick labels and a y title.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text=f"SRP(e,{sample_seg_len})"
            ),
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=500)
fig.write_image("./plot/esrpds.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [17]:
# Figure "esrpivs": one scatter per model (ordered by sorted_model_keys) of
# per-expert best_f1 against input vocabulary specialization (in_freq),
# coloured by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1  # ceil division

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes: [colorbar ticks, axis ticks / colorbar title, axis titles, subplot titles].
font_size = [14, 16, 18, 20]

for i, key in enumerate(sorted_model_keys):
    tmpdf = vedf.query(f"model == '{key}'")
    if tmpdf.empty:
        # Keep the grid slot (and its title) but draw nothing.
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]

    # 1-based layer number, matching the colour scale [1, num_layers].
    layer_idx = tmpdf["layer_idx"] + 1

    if model_config.loc[key, "type"] == "seq2seq":
        # Decoder rows are shifted by num_layers // 2 so that encoder and
        # decoder layers share one scale.
        # NOTE(review): assumes num_layers counts encoder + decoder layers - confirm.
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        x=tmpdf["in_freq"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # layer_idx is already 1-based; adding 1 again (as before) pushed
            # every colour one step up and clipped the last layer at cmax.
            color=layer_idx,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Paper coordinates placing the bar beside its own subplot;
                # presumably hand-tuned for this 10-column grid.
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size[0]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size[1]), text="Ly."
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # Only the bottom row shows x tick labels and an x title.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]),
                standoff=5,
                text="In. Vocab. Spec.",
            ),
            tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

    # Only the left-most column shows y tick labels and a y title.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text=f"SRP(e,{sample_seg_len})"
            ),
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=500)
fig.write_image("./plot/esrpivs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [18]:
# Figure "esrpovs": one scatter per model (ordered by sorted_model_keys) of
# per-expert best_f1 against ground-truth output vocabulary specialization
# (out_freq), coloured by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1  # ceil division

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes: [colorbar ticks, axis ticks / colorbar title, axis titles, subplot titles].
font_size = [14, 16, 18, 20]

for i, key in enumerate(sorted_model_keys):
    tmpdf = vedf.query(f"model == '{key}'")
    if tmpdf.empty:
        # Keep the grid slot (and its title) but draw nothing.
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]

    # 1-based layer number, matching the colour scale [1, num_layers].
    layer_idx = tmpdf["layer_idx"] + 1

    if model_config.loc[key, "type"] == "seq2seq":
        # Decoder rows are shifted by num_layers // 2 so that encoder and
        # decoder layers share one scale.
        # NOTE(review): assumes num_layers counts encoder + decoder layers - confirm.
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        # Ground-truth output frequencies. This previously read "pred_freq",
        # which made the figure a duplicate of the predicted-output one
        # despite its axis title and file name.
        x=tmpdf["out_freq"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # layer_idx is already 1-based; adding 1 again (as before) pushed
            # every colour one step up and clipped the last layer at cmax.
            color=layer_idx,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Paper coordinates placing the bar beside its own subplot;
                # presumably hand-tuned for this 10-column grid.
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size[0]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size[1]), text="Ly."
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # Only the bottom row shows x tick labels and an x title.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]),
                standoff=5,
                text="G. T. O. Vocab. Spec.",
            ),
            tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

    # Only the left-most column shows y tick labels and a y title.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text=f"SRP(e,{sample_seg_len})"
            ),
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=500)
fig.write_image("./plot/esrpovs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [19]:
# Figure "esrppvs": one scatter per model (ordered by sorted_model_keys) of
# per-expert best_f1 against predicted output vocabulary specialization
# (pred_freq), coloured by layer.
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1  # ceil division

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.02,
    vertical_spacing=0.08,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes: [colorbar ticks, axis ticks / colorbar title, axis titles, subplot titles].
font_size = [14, 16, 18, 20]

for i, key in enumerate(sorted_model_keys):
    tmpdf = vedf.query(f"model == '{key}'")
    if tmpdf.empty:
        # Keep the grid slot (and its title) but draw nothing.
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]

    # 1-based layer number, matching the colour scale [1, num_layers].
    layer_idx = tmpdf["layer_idx"] + 1

    if model_config.loc[key, "type"] == "seq2seq":
        # Decoder rows are shifted by num_layers // 2 so that encoder and
        # decoder layers share one scale.
        # NOTE(review): assumes num_layers counts encoder + decoder layers - confirm.
        layer_idx += tmpdf["is_decoder"] * num_layers // 2

    fig.add_scatter(
        x=tmpdf["pred_freq"],
        y=tmpdf["best_f1"],
        hoverinfo="skip",
        marker=go.scatter.Marker(
            cmax=num_layers,
            cmin=1,
            # layer_idx is already 1-based; adding 1 again (as before) pushed
            # every colour one step up and clipped the last layer at cmax.
            color=layer_idx,
            colorbar=go.scatter.marker.ColorBar(
                len=0.49,
                thickness=10,
                # Paper coordinates placing the bar beside its own subplot;
                # presumably hand-tuned for this 10-column grid.
                x=(col - 0.23) * 0.102,
                y=(num_rows - 1 - row + 1.45) * 0.54,
                tickfont=go.scatter.marker.colorbar.Tickfont(size=font_size[0]),
                tickvals=[1, num_layers // 2, num_layers],
                title=go.scatter.marker.colorbar.Title(
                    font=go.scatter.marker.colorbar.title.Font(size=font_size[1]), text="Ly."
                ),
            ),
            size=5,
        ),
        mode="markers",
        showlegend=False,
        row=row,
        col=col,
    )

    # Only the bottom row shows x tick labels and an x title.
    fig.update_xaxes(showticklabels=row == num_rows, row=row, col=col)

    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]),
                standoff=5,
                text="P. O. Vocab. Spec.",
            ),
            tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

    # Only the left-most column shows y tick labels and a y title.
    if col == 1:
        fig.update_yaxes(
            showticklabels=True,
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text=f"SRP(e,{sample_seg_len})"
            ),
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            row=row,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=60), width=2000, height=500)
fig.write_image("./plot/esrppvs.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [20]:
def _top_expert(group):
    """Return the single highest-``best_f1`` row among experts with ``act_r < 0.8``."""
    return group.query("act_r < 0.8").sort_values("best_f1", ascending=False).head(1)


# One selected expert per (model, is_decoder); the group keys are moved back
# from the index into columns and the leftover row index is discarded.
top_experts = (
    dedf.groupby(["model", "is_decoder"], observed=True)
    .apply(_top_expert, include_groups=False)
    .reset_index(["model", "is_decoder"])
    .reset_index(drop=True)
)

# Attach the vocabulary rows of each selected expert (inner join on shared
# columns), then expand the parallel hitoken / hifreq lists to one row per token.
gedf = top_experts.merge(vdf).drop(columns="is_decoder").explode(["hitoken", "hifreq"])

gedf

Unnamed: 0,model,layer_idx,expert_idx,act_r,best_f1,ci_lb,ci_ub,best_m,model_name,model_abbr,...,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,act_r_cv,token_type,freq,hitoken,hifreq
0,powermoe,22,39,0.523734,0.883847,0.882681,0.885058,1.058946,PowerMoE-3B,PW,...,3.30,32,40,8,flash_attention_2,0.614856,in,0.448054,Computes,1.0
0,powermoe,22,39,0.523734,0.883847,0.882681,0.885058,1.058946,PowerMoE-3B,PW,...,3.30,32,40,8,flash_attention_2,0.614856,in,0.448054,Creates,1.0
0,powermoe,22,39,0.523734,0.883847,0.882681,0.885058,1.058946,PowerMoE-3B,PW,...,3.30,32,40,8,flash_attention_2,0.614856,in,0.448054,Inser,1.0
0,powermoe,22,39,0.523734,0.883847,0.882681,0.885058,1.058946,PowerMoE-3B,PW,...,3.30,32,40,8,flash_attention_2,0.614856,in,0.448054,Constructs,1.0
0,powermoe,22,39,0.523734,0.883847,0.882681,0.885058,1.058946,PowerMoE-3B,PW,...,3.30,32,40,8,flash_attention_2,0.614856,in,0.448054,файл,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53,qwen2,1,38,0.639774,0.781575,0.781366,0.781770,1.520319,Qwen2-57B-A14B,QW2,...,57.41,28,64,8,flash_attention_2,0.004133,pred,0.140742,\tname,0.227273
53,qwen2,1,38,0.639774,0.781575,0.781366,0.781770,1.520319,Qwen2-57B-A14B,QW2,...,57.41,28,64,8,flash_attention_2,0.004133,pred,0.140742,-footer,0.227273
53,qwen2,1,38,0.639774,0.781575,0.781366,0.781770,1.520319,Qwen2-57B-A14B,QW2,...,57.41,28,64,8,flash_attention_2,0.004133,pred,0.140742,icates,0.227273
53,qwen2,1,38,0.639774,0.781575,0.781366,0.781770,1.520319,Qwen2-57B-A14B,QW2,...,57.41,28,64,8,flash_attention_2,0.004133,pred,0.140742,erton,0.227273
