In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [2]:
# Keep only the models flagged as "main" and prefix every column with `model_`
# so the metadata does not collide with result columns after a merge.
main_model_config = (
    model_config.query("main")
    .drop(columns="main")
    .rename(columns={k: f"model_{k}" for k in model_config.columns})
)

# Human-readable model names used in figure legends, keyed by model key.
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# Every main model is looked up in `new_name` when plotting; fail here with a
# clear message rather than with a KeyError halfway through building a figure.
assert set(main_model_config.index) <= set(new_name), (
    f"no display name for: {sorted(set(main_model_config.index) - set(new_name))}"
)

# One colour per main model. The index wraps around the palette so that more
# models than palette colours repeat colours instead of raising IndexError.
model_colors = {
    key: px.colors.qualitative.Dark24[i % len(px.colors.qualitative.Dark24)]
    for i, key in enumerate(main_model_config.index.values)
}

methods = ("LRU", "LFU", "static")
# One colour per method, wrapping around the palette in the same way.
method_colors = {
    key: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)]
    for i, key in enumerate(methods)
}
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [3]:
def make_abbr(df):
    """Return the short label for one result row.

    `df` is a single row (anything indexable by column name, e.g. a Series or
    a dict) with `model_abbr` and `model_type`; seq2seq rows also need
    `is_decoder`. Seq2seq models get a one-letter suffix telling the decoder
    ("d") from the encoder ("e"); any other model type keeps its plain
    abbreviation unchanged.
    """
    if df["model_type"] != "seq2seq":
        return df["model_abbr"]
    abbr = df["model_abbr"]
    side = "d" if df["is_decoder"] else "e"
    return f"{abbr}{side}"

In [None]:
root_dir = Path("../output/chr_mpq")

# Load every result file in root_dir into `dfs`, keyed by file stem. Each frame
# is inner-joined with the main-model metadata (rows of non-main models drop
# out), and its `model` column is cast to the dtype of model_config's index.
dfs = {}
for p in root_dir.glob("*.parquet"):
    df = pd.read_parquet(p).merge(main_model_config, left_on="model", right_index=True)
    df["model"] = df["model"].astype(model_config.index.dtype)
    dfs[p.stem] = df

dfs["m"]

Unnamed: 0,model,is_decoder,dataset,method,cache_m,recall,ci_lb,ci_ub,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,c4,LRU,0.003906,0.0,0.0,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,c4,LRU,0.007812,0.0,0.0,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,c4,LRU,0.011719,0.0,0.0,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,c4,LRU,0.015625,0.0,0.0,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,c4,LRU,0.019531,0.0,0.0,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
918715,qwen2,True,science,static,7.982143,1.0,1.0,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
918716,qwen2,True,science,static,7.986607,1.0,1.0,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
918717,qwen2,True,science,static,7.991071,1.0,1.0,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
918718,qwen2,True,science,static,7.995536,1.0,1.0,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [None]:
# Mean recall per (model, side) and method at ρ = 2, averaged over datasets,
# with models ordered by their LRU recall.
(
    dfs["m"]
    # Tolerant match instead of `cache_m == 2`: cache_m lies on a fractional
    # float grid, so an exact comparison could silently drop a model whose
    # grid point lands on 2 ± rounding error.
    .loc[lambda d: (d["cache_m"] - 2).abs() <= 1e-9]
    .groupby(["model", "is_decoder", "method"], as_index=False, observed=True)[["recall"]]
    .mean()
    .pivot(index=["model", "is_decoder"], columns="method", values="recall")
    .sort_values("LRU", ascending=False)
)

Unnamed: 0_level_0,method,LRU,LFU,static
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
llamamoe2,True,0.949757,0.955828,0.94884
yuan,True,0.681606,0.74742,0.724708
powermoe,True,0.637425,0.696776,0.670534
qwen3,True,0.610948,0.582071,0.450158
phi,True,0.584722,0.602486,0.535029
olmoe,True,0.571786,0.583027,0.522805
grin,True,0.559411,0.580581,0.51362
mixtral,True,0.547155,0.582992,0.551275
minicpm,True,0.542341,0.568518,0.551301
jetmoe,True,0.502463,0.547362,0.535108


In [6]:
# Mean CI bounds per (model, side) and method at ρ = 2, averaged over datasets.
# Columns become (method, bound) pairs; rows are ordered by the LRU lower bound.
(
    dfs["m"]
    # Tolerant match instead of `cache_m == 2`: cache_m lies on a fractional
    # float grid, so an exact comparison could silently drop a model whose
    # grid point lands on 2 ± rounding error.
    .loc[lambda d: (d["cache_m"] - 2).abs() <= 1e-9]
    .groupby(["model", "is_decoder", "method"], as_index=False, observed=True)[["ci_lb", "ci_ub"]]
    .mean()
    .pivot(index=["model", "is_decoder"], columns="method")
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=1)
    .sort_values(("LRU", "ci_lb"), ascending=False)
)

Unnamed: 0_level_0,method,LRU,LRU,LFU,LFU,static,static
Unnamed: 0_level_1,Unnamed: 1_level_1,ci_lb,ci_ub,ci_lb,ci_ub,ci_lb,ci_ub
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
llamamoe2,True,0.949125,0.950377,0.955224,0.956424,0.948087,0.9496
yuan,True,0.68031,0.682863,0.746226,0.748564,0.723218,0.726171
powermoe,True,0.636188,0.638677,0.695389,0.698209,0.668974,0.672048
qwen3,True,0.609235,0.612627,0.57947,0.584615,0.445747,0.454515
phi,True,0.582278,0.587088,0.599465,0.605443,0.531204,0.538787
olmoe,True,0.569629,0.573883,0.580312,0.585683,0.5194,0.526128
grin,True,0.557141,0.56166,0.577691,0.583447,0.509779,0.517428
mixtral,True,0.546289,0.547999,0.581884,0.584069,0.550496,0.552046
minicpm,True,0.541509,0.543173,0.567741,0.569247,0.550626,0.551983
jetmoe,True,0.501751,0.503169,0.546518,0.548158,0.534231,0.535958


In [7]:
# Average recall over datasets for every (model, side, method, cache_m) point,
# then re-attach the main-model metadata columns.
mdf = (
    dfs["m"]
    .groupby(["model", "is_decoder", "method", "cache_m"], as_index=False, observed=True)[["recall"]]
    .mean()
    .merge(main_model_config, left_on="model", right_index=True)
)

mdf

Unnamed: 0,model,is_decoder,method,cache_m,recall,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,LRU,0.003906,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,LRU,0.007812,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,LRU,0.011719,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,LRU,0.015625,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,LRU,0.019531,0.0,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...
70939,qwen2,True,static,7.982143,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
70940,qwen2,True,static,7.986607,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
70941,qwen2,True,static,7.991071,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
70942,qwen2,True,static,7.995536,1.0,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [None]:
sch_dir = Path("../output/sch_mpq")

# Segment-level results, inner-joined with the main-model metadata; `model`
# is then cast to the dtype of model_config's index, as for the chr results.
rdf = pd.read_parquet(sch_dir / "m.parquet").merge(
    main_model_config, left_on="model", right_index=True
)
rdf["model"] = rdf["model"].astype(model_config.index.dtype)

rdf

Unnamed: 0,model,is_decoder,dataset,seg_len,cache_m,recall,ci_lb,ci_ub,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,c4,4,0.003906,0.003906,0.003906,0.003906,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,c4,4,0.007812,0.007812,0.007812,0.007812,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,c4,4,0.011719,0.011719,0.011719,0.011719,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,c4,4,0.015625,0.015625,0.015625,0.015625,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,c4,4,0.019531,0.019531,0.019531,0.019531,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1149435,qwen2,True,science,256,7.982143,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149436,qwen2,True,science,256,7.986607,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149437,qwen2,True,science,256,7.991071,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149438,qwen2,True,science,256,7.995536,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [9]:
# Pair each "chr" point (recall from chr_mpq, one per method) with the "sch"
# points (recall from sch_mpq, one per seg_len) for the same model, side,
# dataset and cache_m. The join keys are listed explicitly instead of being
# inferred from whatever columns the two selections happen to share, so adding
# a column to either selection cannot silently change the join.
# NOTE(review): cache_m is a float join key, so this relies on both runs
# producing bit-identical grid values — confirm.
bdf = pd.merge(
    dfs["m"][["model", "is_decoder", "dataset", "cache_m", "method", "recall"]].rename(
        columns={"recall": "chr"}
    ),
    rdf[["model", "is_decoder", "dataset", "cache_m", "seg_len", "recall"]].rename(
        columns={"recall": "sch"}
    ),
    on=["model", "is_decoder", "dataset", "cache_m"],
)

# True where the sch value is at least the chr value for the same point.
bdf["flag"] = bdf["sch"] >= bdf["chr"]
bdf

Unnamed: 0,model,is_decoder,dataset,cache_m,method,chr,seg_len,sch,flag
0,powermoe,True,c4,0.003906,LRU,0.0,4,0.003906,True
1,powermoe,True,c4,0.003906,LRU,0.0,16,0.003906,True
2,powermoe,True,c4,0.003906,LRU,0.0,64,0.003906,True
3,powermoe,True,c4,0.003906,LRU,0.0,256,0.003906,True
4,powermoe,True,c4,0.007812,LRU,0.0,4,0.007812,True
...,...,...,...,...,...,...,...,...,...
3096187,qwen2,True,science,7.995536,static,1.0,256,1.000000,True
3096188,qwen2,True,science,8.000000,static,1.0,4,1.000000,True
3096189,qwen2,True,science,8.000000,static,1.0,16,1.000000,True
3096190,qwen2,True,science,8.000000,static,1.0,64,1.000000,True


In [10]:
# Pearson correlation between chr and sch over all points in each
# (seg_len, method) group, laid out with one column per method.
(
    bdf.groupby(["seg_len", "method"], as_index=False, observed=True)
    .apply(lambda grp: pd.Series({"corr": grp["chr"].corr(grp["sch"])}), include_groups=False)
    .pivot(index="seg_len", columns="method", values="corr")
)

method,LRU,LFU,static
seg_len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
4,0.81197,0.773861,0.76262
16,0.9043,0.886984,0.897946
64,0.931049,0.92825,0.955032
256,0.975234,0.991971,0.979096


In [11]:
# Fraction of points with sch >= chr in each (seg_len, method) group,
# laid out with one column per method.
(
    bdf.groupby(["seg_len", "method"], as_index=False, observed=True)[["flag"]]
    .mean()
    .pivot(index="seg_len", columns="method", values="flag")
)

method,LRU,LFU,static
seg_len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
4,1.0,1.0,1.0
16,1.0,1.0,1.0
64,0.999808,1.0,1.0
256,0.971865,1.0,0.999988


In [13]:
bdf[(bdf["method"] == "static") & ~bdf["flag"]]

Unnamed: 0,model,is_decoder,dataset,cache_m,method,chr,seg_len,sch,flag
173059,llamamoe,True,c4,0.007812,static,0.003992,256,0.003973,False
179203,llamamoe,True,cc2306,0.007812,static,0.003885,256,0.003871,False
185347,llamamoe,True,book,0.007812,static,0.003991,256,0.003959,False


In [12]:
bdf[bdf["seg_len"].lt(256) & ~bdf["flag"]]

Unnamed: 0,model,is_decoder,dataset,cache_m,method,chr,seg_len,sch,flag
652354,minicpm,True,lmarena,1.0125,LRU,0.390382,64,0.388755,False
878078,deepseek2,True,cc2306,1.025641,LRU,0.288128,64,0.286467,False
878082,deepseek2,True,cc2306,1.032051,LRU,0.289119,64,0.28772,False
878086,deepseek2,True,cc2306,1.038462,LRU,0.289119,64,0.28897,False
1077782,deepseek,True,c4,1.024691,LRU,0.281404,64,0.279526,False
1077786,deepseek,True,c4,1.030864,LRU,0.282507,64,0.280705,False
1077790,deepseek,True,c4,1.037037,LRU,0.282507,64,0.281881,False
1098514,deepseek,True,cc2306,1.018519,LRU,0.276918,64,0.273194,False
1098518,deepseek,True,cc2306,1.024691,LRU,0.281327,64,0.274373,False
1098522,deepseek,True,cc2306,1.030864,LRU,0.282425,64,0.27555,False


In [15]:
# Line-dash tier per model: every model in the i-th tuple gets tier i, and
# tier i is drawn with dash_style[i]. The tiers appear to follow the
# descending LRU recall ranking at cache_m == 2 — NOTE(review): confirm intent.
dash_level = {
    key: level
    for level, keys in enumerate(
        (
            ("llamamoe2", "yuan", "powermoe", "qwen3", "phi", "olmoe", "grin"),
            ("mixtral", "minicpm", "jetmoe", "llamamoe"),
            ("xverse", "jamba", "deepseek2", "deepseek", "qwen2"),
            ("nllb", "qwen", "openmoe", "switch"),
        )
    )
    for key in keys
}

dash_style = ["solid", "longdash", "dashdot", "dot"]
# Model key -> plotly dash name.
dash_type = {key: dash_style[level] for key, level in dash_level.items()}

In [16]:
# One subplot per method showing mean recall against ρ (cache_m) for every
# main model; the figure is exported to ./plot/mchr.pdf and also displayed.
fig = make_subplots(
    rows=1,
    cols=len(methods),
    shared_xaxes="all",
    horizontal_spacing=0.01,
    subplot_titles=methods,
)

# Font sizes used below: [1] ticks and legend, [2] axis titles,
# [3] subplot titles ([0] is unused in this figure).
font_size = [16, 20, 24, 28]

for i, method in enumerate(methods):
    col = i + 1
    # Only the first subplot adds legend entries; traces in the other
    # subplots share them through `legendgroup=key`.
    show_legend = i == 0

    for key in main_model_config.index.values:
        tmpdf = mdf[(mdf["model"] == key) & (mdf["method"] == method)]

        if model_config.loc[key, "type"] == "seq2seq":
            # Encoder and decoder are drawn as separate curves; the encoder
            # curve is drawn at half opacity.
            variants = [
                (
                    tmpdf[tmpdf["is_decoder"] == is_decoder],
                    f"{new_name[key]} ({'De' if is_decoder else 'En'}coder)",
                    {"opacity": 1 if is_decoder else 0.5},
                )
                for is_decoder in (False, True)
            ]
        else:
            variants = [(tmpdf, new_name[key], {})]

        for subdf, trace_name, extra in variants:
            fig.add_scatter(
                x=subdf["cache_m"],
                y=subdf["recall"],
                hoverinfo="skip",
                legendgroup=key,
                line=go.scatter.Line(color=model_colors[key], dash=dash_type[key], width=2),
                mode="lines",
                name=trace_name,
                showlegend=show_legend,
                row=1,
                col=col,
                **extra,
            )

    # Axis styling depends only on the column, so apply it once per subplot
    # instead of once per model.
    fig.update_xaxes(
        range=[0, 4],
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(
            font=go.layout.xaxis.title.Font(size=font_size[2]), text="ρ"
        ),
        row=1,
        col=col,
    )

    # Only the leftmost subplot shows y tick labels and the y-axis title.
    fig.update_yaxes(range=[0, 1], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="CHR(E,ρ)"
            ),
            row=1,
            col=col,
        )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=1400,
    height=600,
)

# Make sure the output directory exists so the export works on a fresh checkout.
Path("./plot").mkdir(parents=True, exist_ok=True)
fig.write_image("./plot/mchr.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()