In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import model_config

In [2]:
# Rows of `model_config` selected by the "main" flag; the flag column is dropped
# and every remaining column name is prefixed with "model_".
main_model_config = model_config.query("main").drop(columns="main").add_prefix("model_")

# Display name for each model key (used for legend entries and subplot titles).
new_name = {
    "powermoe": "PowerMoE",
    "llamamoe": "LLaMA-MoE-v1",
    "olmoe": "OLMoE",
    "switch": "SwitchTransformers",
    "llamamoe2": "LLaMA-MoE-v2",
    "jetmoe": "JetMoE",
    "openmoe": "OpenMoE",
    "minicpm": "MiniCPM-MoE",
    "qwen": "Qwen1.5-MoE",
    "deepseek2": "DeepSeek-V2-Lite",
    "deepseek": "DeepSeekMoE",
    "xverse": "XVERSE-MoE",
    "qwen3": "Qwen3",
    "yuan": "Yuan2.0",
    "phi": "Phi-3.5-MoE",
    "grin": "GRIN-MoE",
    "mixtral": "Mixtral-8x7B",
    "jamba": "Jamba-Mini",
    "nllb": "NLLB-MoE",
    "qwen2": "Qwen2",
}

# One Dark24 colour per model, assigned in the row order of `main_model_config`.
model_colors = dict(
    (key, px.colors.qualitative.Dark24[pos])
    for pos, key in enumerate(main_model_config.index.values)
)

# seg_len values that are plotted, each with its own colour from the Plotly palette.
seg_lens = (4, 16, 64, 256)
seg_len_colors = dict(
    (seg_len, px.colors.qualitative.Plotly[pos]) for pos, seg_len in enumerate(seg_lens)
)
main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
powermoe,PowerMoE-3B,PW,causal,3.3,32,40,8,flash_attention_2
llamamoe,LLaMA-MoE-v1-3.5B,LL1,causal,6.74,32,16,4,eager
olmoe,OLMoE-1B-7B-0125,OL,causal,6.92,16,64,8,flash_attention_2
switch,SwitchTransformers-Base-128,ST,seq2seq,7.42,24,128,1,eager
llamamoe2,LLaMA-MoE-v2-3.8B,LL2,causal,8.03,32,8,2,flash_attention_2
jetmoe,JetMoE-8B,JT,causal,8.52,24,8,2,flash_attention_2
openmoe,OpenMoE-8B,OP,causal,11.86,24,32,2,eager
minicpm,MiniCPM-MoE-8x2B,MC,causal,13.87,40,8,2,flash_attention_2
qwen,Qwen1.5-MoE-A2.7B,QW1,causal,14.32,24,60,4,flash_attention_2
deepseek2,DeepSeek-V2-Lite,DS2,causal,15.71,27,64,6,flash_attention_2


In [3]:
def make_abbr(df):
    """Return the abbreviation of a single record.

    `df` is indexed by key: it must provide "model_type" and "model_abbr",
    and additionally "is_decoder" when "model_type" is "seq2seq".

    For a "seq2seq" record the abbreviation gets a one-letter suffix: "d" when
    "is_decoder" is truthy, "e" otherwise. Any other record returns its
    "model_abbr" value unchanged.
    """
    if df["model_type"] == "seq2seq":
        abbr = df["model_abbr"]
        suffix = "d" if df["is_decoder"] else "e"
        return f"{abbr}{suffix}"

    return df["model_abbr"]

In [None]:
root_dir = Path("../output/sch_mpq")

# One frame per parquet file in `root_dir`, keyed by the file stem. Each frame is
# joined with the main-model config on its "model" column, and that column is
# then cast to the dtype of the `model_config` index.
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    merged = pd.merge(
        pd.read_parquet(parquet_path), main_model_config, left_on="model", right_index=True
    )
    merged["model"] = merged["model"].astype(model_config.index.dtype)
    dfs[parquet_path.stem] = merged

dfs["m"]

Unnamed: 0,model,is_decoder,dataset,seg_len,cache_m,recall,ci_lb,ci_ub,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,c4,4,0.003906,0.003906,0.003906,0.003906,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,c4,4,0.007812,0.007812,0.007812,0.007812,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,c4,4,0.011719,0.011719,0.011719,0.011719,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,c4,4,0.015625,0.015625,0.015625,0.015625,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,c4,4,0.019531,0.019531,0.019531,0.019531,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1149435,qwen2,True,science,256,7.982143,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149436,qwen2,True,science,256,7.986607,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149437,qwen2,True,science,256,7.991071,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
1149438,qwen2,True,science,256,7.995536,1.000000,1.000000,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [None]:
# Mean recall at cache_m == 2 for every (model, is_decoder, seg_len) combination.
recall_at_2 = (
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "seg_len"], as_index=False, observed=True)[["recall"]]
    .mean()
)

# One column per seg_len, rows ordered by the seg_len == 16 column (highest first).
(
    recall_at_2.pivot(index=["model", "is_decoder"], columns="seg_len", values="recall")
    .sort_values(16, ascending=False)
)

Unnamed: 0_level_0,seg_len,4,16,64,256
model,is_decoder,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
llamamoe2,True,0.999365,0.985314,0.974056,0.967488
yuan,True,0.950463,0.824813,0.785438,0.766915
powermoe,True,0.928987,0.802172,0.746407,0.720497
qwen3,True,0.916601,0.75618,0.67217,0.618575
phi,True,0.895927,0.744333,0.673407,0.635522
mixtral,True,0.882406,0.739457,0.663284,0.623013
minicpm,True,0.881495,0.731386,0.650123,0.608004
olmoe,True,0.893379,0.728658,0.653038,0.612885
grin,True,0.88291,0.726498,0.653082,0.61352
jetmoe,True,0.856773,0.708372,0.631002,0.591265


In [5]:
# Mean ci_lb / ci_ub at cache_m == 2 for every (model, is_decoder, seg_len) combination.
ci_at_2 = (
    dfs["m"]
    .query("cache_m == 2")
    .groupby(["model", "is_decoder", "seg_len"], as_index=False, observed=True)[
        ["ci_lb", "ci_ub"]
    ]
    .mean()
)

# Pivot to (seg_len, bound) columns: the pivot yields (bound, seg_len), so the two
# column levels are swapped and sorted. Rows are ordered by ci_lb at seg_len == 16.
(
    ci_at_2.pivot(index=["model", "is_decoder"], columns="seg_len")
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=1)
    .sort_values((16, "ci_lb"), ascending=False)
)

Unnamed: 0_level_0,seg_len,4,4,16,16,64,64,256,256
Unnamed: 0_level_1,Unnamed: 1_level_1,ci_lb,ci_ub,ci_lb,ci_ub,ci_lb,ci_ub,ci_lb,ci_ub
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
llamamoe2,True,0.999347,0.999384,0.98506,0.98557,0.973635,0.97447,0.966944,0.968033
yuan,True,0.949923,0.950992,0.823946,0.825665,0.784436,0.78643,0.76573,0.768095
powermoe,True,0.92848,0.929475,0.801326,0.802982,0.745301,0.747516,0.719176,0.721841
qwen3,True,0.915791,0.917399,0.754804,0.757526,0.670337,0.673966,0.616241,0.620939
phi,True,0.894849,0.896997,0.742534,0.7462,0.670998,0.675721,0.632623,0.638448
mixtral,True,0.881986,0.88281,0.738863,0.740028,0.662478,0.664052,0.622023,0.623975
minicpm,True,0.881089,0.881895,0.730869,0.731885,0.64953,0.650702,0.607313,0.60868
olmoe,True,0.892341,0.894402,0.726935,0.730353,0.650859,0.65524,0.610375,0.615384
grin,True,0.881818,0.883972,0.724701,0.728306,0.65073,0.655431,0.610738,0.61629
jetmoe,True,0.856381,0.857159,0.707904,0.708842,0.630402,0.63159,0.590519,0.591988


In [6]:
# Per row: the larger of the two distances from `recall` to its ci_ub / ci_lb bound.
m_df = dfs["m"]
ci_half_width = np.maximum(m_df["ci_ub"] - m_df["recall"], m_df["recall"] - m_df["ci_lb"])

# Largest such distance per (model, is_decoder, dataset, seg_len), laid out with
# one column per (dataset, seg_len) pair.
(
    m_df.assign(ci_dist=ci_half_width)
    .groupby(["model", "is_decoder", "dataset", "seg_len"], as_index=False, observed=True)[
        ["ci_dist"]
    ]
    .max()
    .pivot(index=["model", "is_decoder"], columns=["dataset", "seg_len"], values="ci_dist")
)

Unnamed: 0_level_0,dataset,c4,c4,c4,c4,cc2306,cc2306,cc2306,cc2306,book,book,...,math,math,code,code,code,code,science,science,science,science
Unnamed: 0_level_1,seg_len,4,16,64,256,4,16,64,256,4,16,...,64,256,4,16,64,256,4,16,64,256
model,is_decoder,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
powermoe,True,0.000387,0.000627,0.000856,0.00115,0.000355,0.000637,0.000871,0.001039,0.000454,0.000584,...,0.001385,0.001714,0.000399,0.000553,0.000596,0.000706,0.000827,0.001435,0.001645,0.0018
llamamoe,True,0.000379,0.000376,0.000464,0.000592,0.000357,0.000367,0.000421,0.000447,0.000414,0.000281,...,0.00041,0.000501,0.000405,0.000349,0.000376,0.000427,0.000413,0.000341,0.000455,0.000546
olmoe,True,0.000826,0.001288,0.001716,0.002064,0.000769,0.001345,0.001553,0.001887,0.000726,0.001247,...,0.002304,0.002984,0.000655,0.00089,0.000944,0.001084,0.001791,0.002945,0.003475,0.003592
switch,False,0.000362,0.000982,0.001645,0.002277,0.000522,0.001166,0.001717,0.002184,0.000415,0.001234,...,0.002346,0.003322,0.000453,0.000929,0.001218,0.001816,0.000469,0.001871,0.002897,0.004111
switch,True,0.000492,0.001289,0.001811,,0.000523,0.001327,0.001732,,0.000524,0.001526,...,0.002632,,0.000669,0.001399,0.001661,,0.000781,0.002555,0.003084,
llamamoe2,True,0.000541,0.000713,0.000786,0.000933,0.000505,0.00068,0.00072,0.000801,0.000513,0.000671,...,0.001463,0.001643,0.001233,0.001688,0.001931,0.002109,0.000624,0.000785,0.000953,0.001165
jetmoe,True,0.000289,0.000445,0.000579,0.000791,0.000314,0.000398,0.000556,0.000693,0.000301,0.000321,...,0.000562,0.000726,0.000358,0.000454,0.000465,0.000529,0.000403,0.00046,0.00067,0.000807
openmoe,True,0.000783,0.001225,0.001698,0.00201,0.000809,0.001271,0.001601,0.001853,0.001073,0.001704,...,0.002415,0.002855,0.000974,0.001276,0.001544,0.00186,0.001119,0.00179,0.002395,0.002555
minicpm,True,0.00031,0.000477,0.000608,0.000755,0.000419,0.000697,0.000702,0.000749,0.000303,0.000367,...,0.00047,0.00058,0.000386,0.000501,0.000612,0.000683,0.000386,0.000393,0.000545,0.000689
qwen,True,0.000591,0.000848,0.001106,0.001411,0.000631,0.000844,0.001172,0.001455,0.000546,0.000793,...,0.001019,0.001277,0.000537,0.000767,0.000931,0.000976,0.000673,0.000862,0.001233,0.001367


In [12]:
# Mean recall over all remaining columns (e.g. dataset) for each
# (model, is_decoder, seg_len, cache_m) combination.
mean_recall_m = (
    dfs["m"]
    .groupby(["model", "is_decoder", "seg_len", "cache_m"], as_index=False, observed=True)[
        ["recall"]
    ]
    .mean()
)

# Re-attach the model_* config columns, which the aggregation dropped.
mdf = mean_recall_m.merge(main_model_config, left_on="model", right_index=True)

mdf

Unnamed: 0,model,is_decoder,seg_len,cache_m,recall,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,4,0.003906,0.003906,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,4,0.007812,0.007812,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,4,0.011719,0.011719,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,4,0.015625,0.015625,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,4,0.019531,0.019531,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...
93819,qwen2,True,256,7.982143,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93820,qwen2,True,256,7.986607,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93821,qwen2,True,256,7.991071,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93822,qwen2,True,256,7.995536,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [None]:
srp_dir = Path("../output/srp_mpq")

# Contents of mg.parquet joined with the main-model config on the "model" column.
rdf = pd.read_parquet(srp_dir / "mg.parquet").merge(
    main_model_config, left_on="model", right_index=True
)

# Model keys ordered by best_f1 (highest first) among decoder rows with seg_len == 16.
decoder_rows_at_16 = rdf.query("is_decoder and seg_len == 16")
sorted_model_keys = decoder_rows_at_16.sort_values("best_f1", ascending=False)["model"]


def _recall_f1_corr(group):
    """Return the correlation between `recall` and `best_f1` within one group."""
    return pd.Series({"corr": group["recall"].corr(group["best_f1"])})


# Join `mdf` and `rdf` on their shared columns, keep a handful of cache_m values,
# and correlate recall with best_f1 inside every (seg_len, cache_m) group.
cdf = (
    mdf.merge(rdf)
    .query("cache_m.isin((0.5, 1, 1.5, 2, 2.5, 3))")
    .groupby(["seg_len", "cache_m"], as_index=False)
    .apply(_recall_f1_corr, include_groups=False)
)

cdf.pivot(index="seg_len", columns="cache_m", values="corr")

cache_m,0.5,1.0,1.5,2.0,2.5,3.0
seg_len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
4,0.965829,0.979832,0.982089,0.960389,0.89065,0.750535
16,0.978689,0.995168,0.99755,0.984886,0.956964,0.923314
64,0.96746,0.990054,0.997799,0.990243,0.968081,0.936023
256,0.954106,0.983185,0.995826,0.99197,0.971122,0.936997


In [20]:
# Line dash pattern per level; position in this list is the level number.
dash_style = ["solid", "longdash", "dashdot", "dot"]

# Model keys grouped by dash level (position in the outer list is the level).
models_by_dash_level = [
    ["llamamoe2", "yuan", "powermoe", "qwen3", "phi", "olmoe", "grin"],
    ["mixtral", "minicpm", "jetmoe", "llamamoe"],
    ["xverse", "jamba", "deepseek2", "deepseek", "qwen2"],
    ["nllb", "qwen", "openmoe", "switch"],
]

# model key -> dash level
dash_level = {
    key: level for level, keys in enumerate(models_by_dash_level) for key in keys
}

# model key -> dash pattern name
dash_type = {key: dash_style[level] for key, level in dash_level.items()}

In [None]:
# Recall-vs-cache_m curves from `mdf`: one panel per seg_len, one line per model
# (two lines, encoder and decoder, for seq2seq models).
fig = make_subplots(
    rows=1,
    cols=len(seg_lens),
    shared_xaxes="all",
    horizontal_spacing=0.01,
    subplot_titles=[f"m={seg_len}" for seg_len in seg_lens],
)

# Font sizes; index 1 = ticks and legend, 2 = axis titles, 3 = subplot titles.
font_size = [16, 20, 24, 28]

for col, seg_len in enumerate(seg_lens, start=1):
    # Only the first panel emits legend entries; traces in later panels share
    # the same legendgroup and would otherwise duplicate them.
    show_legend = col == 1

    for key in main_model_config.index.values:
        tmpdf = mdf.query(f"model == '{key}' and seg_len == {seg_len}")

        # Each curve is (rows to plot, legend name, opacity). An opacity of None
        # leaves the plotly default in place.
        if model_config.loc[key, "type"] == "seq2seq":
            curves = [
                (
                    tmpdf.query(f"is_decoder == {is_decoder}"),
                    f"{new_name[key]} ({'De' if is_decoder else 'En'}coder)",
                    1 if is_decoder else 0.5,
                )
                for is_decoder in (False, True)
            ]
        else:
            curves = [(tmpdf, new_name[key], None)]

        for subdf, trace_name, opacity in curves:
            fig.add_scatter(
                x=subdf["cache_m"],
                y=subdf["recall"],
                hoverinfo="skip",
                legendgroup=key,
                line=go.scatter.Line(color=model_colors[key], dash=dash_type[key], width=2),
                mode="lines",
                name=trace_name,
                opacity=opacity,
                showlegend=show_legend,
                row=1,
                col=col,
            )

    # Axis styling depends only on the panel, so it is applied once per panel
    # rather than once per model.
    fig.update_xaxes(
        range=[0, 4],
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(font=go.layout.xaxis.title.Font(size=font_size[2]), text="ρ"),
        row=1,
        col=col,
    )

    # Tick labels are shown on the left-most panel only.
    fig.update_yaxes(range=[0, 1], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="SCH(E,m,ρ)"
            ),
            row=1,
            col=col,
        )

# Subplot titles are annotations; enlarge them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/msch.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [18]:
# Mean recall over all remaining columns for each
# (model, is_decoder, layer_idx, seg_len, cache_m) combination.
mean_recall_l = (
    dfs["l"]
    .groupby(
        ["model", "is_decoder", "layer_idx", "seg_len", "cache_m"], as_index=False, observed=True
    )[["recall"]]
    .mean()
)

# Re-attach the model_* config columns, which the aggregation dropped.
ldf = mean_recall_l.merge(main_model_config, left_on="model", right_index=True)

ldf

Unnamed: 0,model,is_decoder,layer_idx,seg_len,cache_m,recall,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,powermoe,True,0,4,0.125,0.105572,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
1,powermoe,True,0,4,0.250,0.200081,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
2,powermoe,True,0,4,0.375,0.286259,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
3,powermoe,True,0,4,0.500,0.364742,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
4,powermoe,True,0,4,0.625,0.436532,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
93819,qwen2,True,27,256,7.500,0.961675,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93820,qwen2,True,27,256,7.625,0.972169,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93821,qwen2,True,27,256,7.750,0.982209,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2
93822,qwen2,True,27,256,7.875,0.991632,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2


In [19]:
# Per-layer recall from `ldf`: one panel per model (in `sorted_model_keys` order),
# one line per (seg_len, cache_m) pair with cache_m in (1, 2).
num_cols = 10
num_rows = (len(main_model_config) - 1) // num_cols + 1

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.11,
    subplot_titles=sorted_model_keys.map(new_name).values,
)

# Font sizes; index 1 = ticks, 2 = axis titles and legend, 3 = subplot titles.
font_size = [16, 16, 18, 20]

# Line names that already own a legend entry. Tracking them individually (rather
# than switching the legend off after the first plotted model) guarantees every
# line gets an entry even when the first model has no rows for some combination.
legend_shown = set()

for i, key in enumerate(sorted_model_keys):
    # Models without any rows in `ldf` keep their (empty) panel slot.
    if not (ldf["model"] == key).any():
        continue

    row = i // num_cols + 1
    col = i % num_cols + 1
    num_layers = model_config.loc[key, "num_layers"]
    is_seq2seq = model_config.loc[key, "type"] == "seq2seq"

    for seg_len in seg_lens:
        for j, cache_m in enumerate((1, 2)):
            tmpdf = ldf.query(f"model == '{key}' and seg_len == {seg_len} and cache_m == {cache_m}")

            if len(tmpdf) == 0:
                continue

            # 1-based layer position.
            layer_idx = tmpdf["layer_idx"] + 1
            if is_seq2seq:
                # Decoder rows are shifted by half of num_layers so they follow
                # the encoder rows on the same axis (`*` binds before `//`).
                layer_idx = layer_idx + tmpdf["is_decoder"] * num_layers // 2

            line_name = f"SCH(E,{seg_len},{cache_m})"

            fig.add_scatter(
                # x is the layer position as a fraction of num_layers.
                x=layer_idx / num_layers,
                y=tmpdf["recall"],
                hoverinfo="skip",
                legendgroup=line_name,
                # cache_m == 1 (j == 0): solid, opaque, with markers;
                # cache_m == 2 (j == 1): dotted, half-transparent, lines only.
                line=go.scatter.Line(color=seg_len_colors[seg_len], dash="dot" if j else "solid"),
                marker=go.scatter.Marker(size=4),
                mode="lines" if j else "lines+markers",
                name=line_name,
                opacity=0.5 if j else 1,
                showlegend=line_name not in legend_shown,
                row=row,
                col=col,
            )
            legend_shown.add(line_name)

    # Ticks at the first, middle and last layer, labelled with layer numbers.
    fig.update_xaxes(
        showticklabels=True,
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        tickvals=[1 / num_layers, 0.5, 1],
        ticktext=[1, num_layers // 2, num_layers],
        row=row,
        col=col,
    )

    # Only the bottom row carries the x-axis title.
    if row == num_rows:
        fig.update_xaxes(
            title=go.layout.xaxis.Title(
                font=go.layout.xaxis.title.Font(size=font_size[2]), standoff=1, text="Layer"
            ),
            row=row,
            col=col,
        )

    # Only the left-most column carries y tick labels and the y-axis title.
    fig.update_yaxes(showticklabels=col == 1, row=row, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="SCH(E,m,ρ)"
            ),
            row=row,
            col=col,
        )

# Subplot titles are annotations; enlarge them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[2]), orientation="h", x=0.47, xanchor="center"
    ),
    margin=go.layout.Margin(l=60, r=15, t=30, b=15),
    width=2000,
    height=500,
)

fig.write_image("./plot/lsch.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [21]:
# Row-to-row change of cache_m and recall inside each (model, is_decoder, seg_len)
# curve; the first row of every curve has no predecessor and becomes NaN.
# Relies on the rows of `mdf` being ordered by cache_m within each curve.
mdf_step = mdf.groupby(["model", "is_decoder", "seg_len"], observed=True)[
    ["cache_m", "recall"]
].diff()

# diff1 = Δrecall / Δcache_m, computed as one vectorised column division instead
# of a Python-level function call per row.
d1mdf = mdf.assign(diff1=mdf_step["recall"] / mdf_step["cache_m"])

d1mdf

Unnamed: 0,model,is_decoder,seg_len,cache_m,recall,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,diff1
0,powermoe,True,4,0.003906,0.003906,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,
1,powermoe,True,4,0.007812,0.007812,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0
2,powermoe,True,4,0.011719,0.011719,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0
3,powermoe,True,4,0.015625,0.015625,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0
4,powermoe,True,4,0.019531,0.019531,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
93819,qwen2,True,256,7.982143,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0
93820,qwen2,True,256,7.986607,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0
93821,qwen2,True,256,7.991071,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0
93822,qwen2,True,256,7.995536,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0


In [22]:
# diff1-vs-cache_m curves from `d1mdf`: one panel per seg_len, one line per model
# (two lines, encoder and decoder, for seq2seq models).
fig = make_subplots(
    rows=1,
    cols=len(seg_lens),
    shared_xaxes="all",
    horizontal_spacing=0.01,
    subplot_titles=[f"m={seg_len}" for seg_len in seg_lens],
)

# Font sizes; index 1 = ticks and legend, 2 = axis titles, 3 = subplot titles.
font_size = [16, 20, 24, 28]

for col, seg_len in enumerate(seg_lens, start=1):
    # Only the first panel emits legend entries; traces in later panels share
    # the same legendgroup and would otherwise duplicate them.
    show_legend = col == 1

    for key in main_model_config.index.values:
        tmpdf = d1mdf.query(f"model == '{key}' and seg_len == {seg_len}")

        # Each curve is (rows to plot, legend name, opacity). An opacity of None
        # leaves the plotly default in place.
        if model_config.loc[key, "type"] == "seq2seq":
            curves = [
                (
                    tmpdf.query(f"is_decoder == {is_decoder}"),
                    f"{new_name[key]} ({'De' if is_decoder else 'En'}coder)",
                    1 if is_decoder else 0.5,
                )
                for is_decoder in (False, True)
            ]
        else:
            curves = [(tmpdf, new_name[key], None)]

        for subdf, trace_name, opacity in curves:
            fig.add_scatter(
                x=subdf["cache_m"],
                y=subdf["diff1"],
                hoverinfo="skip",
                legendgroup=key,
                line=go.scatter.Line(color=model_colors[key], dash=dash_type[key], width=2),
                mode="lines",
                name=trace_name,
                opacity=opacity,
                showlegend=show_legend,
                row=1,
                col=col,
            )

    # Axis styling depends only on the panel, so it is applied once per panel
    # rather than once per model.
    fig.update_xaxes(
        range=[0, 4],
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(font=go.layout.xaxis.title.Font(size=font_size[2]), text="ρ"),
        row=1,
        col=col,
    )

    # Tick labels are shown on the left-most panel only.
    fig.update_yaxes(range=[0, 1], showticklabels=col == 1, row=1, col=col)

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="dSCH(E,m,ρ)/dρ"
            ),
            row=1,
            col=col,
        )

# Subplot titles are annotations; enlarge them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/mschd1.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()

In [23]:
# Row-to-row change of cache_m and diff1 inside each (model, is_decoder, seg_len)
# curve; rows without a predecessor, or whose diff1 inputs are NaN, become NaN.
# Relies on the rows of `d1mdf` being ordered by cache_m within each curve.
d1mdf_step = d1mdf.groupby(["model", "is_decoder", "seg_len"], observed=True)[
    ["cache_m", "diff1"]
].diff()

# diff2 = Δdiff1 / Δcache_m, computed as one vectorised column division instead
# of a Python-level function call per row.
d2mdf = d1mdf.assign(diff2=d1mdf_step["diff1"] / d1mdf_step["cache_m"])

d2mdf

Unnamed: 0,model,is_decoder,seg_len,cache_m,recall,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,diff1,diff2
0,powermoe,True,4,0.003906,0.003906,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,,
1,powermoe,True,4,0.007812,0.007812,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0,
2,powermoe,True,4,0.011719,0.011719,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0,0.0
3,powermoe,True,4,0.015625,0.015625,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0,0.0
4,powermoe,True,4,0.019531,0.019531,PowerMoE-3B,PW,causal,3.30,32,40,8,flash_attention_2,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
93819,qwen2,True,256,7.982143,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0,0.0
93820,qwen2,True,256,7.986607,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0,0.0
93821,qwen2,True,256,7.991071,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0,0.0
93822,qwen2,True,256,7.995536,1.000000,Qwen2-57B-A14B,QW2,causal,57.41,28,64,8,flash_attention_2,0.0,0.0


In [24]:
# diff2-vs-cache_m curves from `d2mdf`: one panel per seg_len, one line per model
# (two lines, encoder and decoder, for seq2seq models).
fig = make_subplots(
    rows=1,
    cols=len(seg_lens),
    shared_xaxes="all",
    horizontal_spacing=0.01,
    subplot_titles=[f"m={seg_len}" for seg_len in seg_lens],
)

# Font sizes; index 1 = ticks and legend, 2 = axis titles, 3 = subplot titles.
font_size = [16, 20, 24, 28]

for col, seg_len in enumerate(seg_lens, start=1):
    # Only the first panel emits legend entries; traces in later panels share
    # the same legendgroup and would otherwise duplicate them.
    show_legend = col == 1

    for key in main_model_config.index.values:
        tmpdf = d2mdf.query(f"model == '{key}' and seg_len == {seg_len}")

        # Each curve is (rows to plot, legend name, opacity). An opacity of None
        # leaves the plotly default in place.
        if model_config.loc[key, "type"] == "seq2seq":
            curves = [
                (
                    tmpdf.query(f"is_decoder == {is_decoder}"),
                    f"{new_name[key]} ({'De' if is_decoder else 'En'}coder)",
                    1 if is_decoder else 0.5,
                )
                for is_decoder in (False, True)
            ]
        else:
            curves = [(tmpdf, new_name[key], None)]

        for subdf, trace_name, opacity in curves:
            fig.add_scatter(
                x=subdf["cache_m"],
                y=subdf["diff2"],
                hoverinfo="skip",
                legendgroup=key,
                line=go.scatter.Line(color=model_colors[key], dash=dash_type[key], width=2),
                mode="lines",
                name=trace_name,
                opacity=opacity,
                showlegend=show_legend,
                row=1,
                col=col,
            )

    # Axis styling depends only on the panel, so it is applied once per panel
    # rather than once per model. The y range is left to plotly's autorange.
    fig.update_xaxes(
        range=[0, 4],
        tickfont=go.layout.xaxis.Tickfont(size=font_size[1]),
        title=go.layout.xaxis.Title(font=go.layout.xaxis.title.Font(size=font_size[2]), text="ρ"),
        row=1,
        col=col,
    )

    if col == 1:
        fig.update_yaxes(
            tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
            title=go.layout.yaxis.Title(
                font=go.layout.yaxis.title.Font(size=font_size[2]), text="d^2SCH(E,m,ρ)/dρ^2"
            ),
            row=1,
            col=col,
        )

# Subplot titles are annotations; enlarge them all at once.
fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))

fig.update_layout(
    legend=go.layout.Legend(
        font=go.layout.legend.Font(size=font_size[1]),
        itemsizing="constant",
        orientation="h",
        y=-0.15,
        yanchor="top",
    ),
    margin=go.layout.Margin(l=60, r=30, t=30, b=15),
    width=2000,
    height=600,
)

fig.write_image("./plot/mschd2.pdf", width=fig.layout.width, height=fig.layout.height)
fig.show()