In [None]:
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from misc import data_config, model_config

In [2]:
# Rows of model_config selected by the "main" column. The flag column itself is
# dropped and every remaining column gets a "model_" prefix.
main_rows = model_config.query("main")
main_model_config = main_rows.drop(columns="main").rename(
    columns=lambda col: f"model_{col}"
)

# One colour per main model, keyed by the model's index label.
# The palette is indexed modulo its length so that having more models than
# palette entries reuses colours instead of raising IndexError.
palette = px.colors.qualitative.Dark24
model_colors = {
    key: palette[i % len(palette)]
    for i, key in enumerate(main_model_config.index.values)
}

# Groups of models that are compared with each other: group key -> model keys.
# The plotting cell draws one subplot row per group and one column per member,
# in the order given here.
cmp_groups = {
    "llamamoe": ["llamamoe", "llamamoes"],
    "olmoe": ["olmoe", "olmoesft", "olmoedpo", "olmoeins"],
    "jetmoe": ["jetmoe", "jetmoesft", "jetmoechat"],
}

# Display name for every compared model key.
new_name = {
    "llamamoe": "LLaMA-MoE-v1",
    "llamamoes": "LLaMA-MoE-v1-SFT",
    "olmoe": "OLMoE",
    "olmoesft": "OLMoE-SFT",
    "olmoedpo": "OLMoE-DPO",
    "olmoeins": "OLMoE-Instruct",
    "jetmoe": "JetMoE",
    "jetmoesft": "JetMoE-SFT",
    "jetmoechat": "JetMoE-Chat",
}

# All compared model keys, flattened group by group in declaration order.
cmp_keys = []
for members in cmp_groups.values():
    cmp_keys.extend(members)

# Reverse lookup: member model key -> key of the comparison group it is listed under.
group_of = {member: group for group, members in cmp_groups.items() for member in members}

# Config rows of the compared models (in cmp_keys order), without the "main"
# flag column and with every remaining column prefixed by "model_".
cmp_model_config = (
    model_config.loc[cmp_keys]
    .drop(columns="main")
    .rename(columns=lambda col: f"model_{col}")
)

# Tag each row with its comparison group and overwrite its name with the
# display name from new_name.
for member, group in group_of.items():
    cmp_model_config.loc[member, "model_group"] = group
    cmp_model_config.loc[member, "model_name"] = new_name[member]

# Give the group column the same dtype as the model index.
cmp_model_config["model_group"] = cmp_model_config["model_group"].astype(model_config.index.dtype)
cmp_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn,model_group
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
llamamoe,LLaMA-MoE-v1,LL1,causal,6.74,32,16,4,eager,llamamoe
llamamoes,LLaMA-MoE-v1-SFT,LL1-S,causal,6.74,32,16,4,eager,llamamoe
olmoe,OLMoE,OL,causal,6.92,16,64,8,flash_attention_2,olmoe
olmoesft,OLMoE-SFT,OL-S,causal,6.92,16,64,8,flash_attention_2,olmoe
olmoedpo,OLMoE-DPO,OL-D,causal,6.92,16,64,8,flash_attention_2,olmoe
olmoeins,OLMoE-Instruct,OL-I,causal,6.92,16,64,8,flash_attention_2,olmoe
jetmoe,JetMoE,JT,causal,8.52,24,8,2,flash_attention_2,jetmoe
jetmoesft,JetMoE-SFT,JT-S,causal,8.52,24,8,2,flash_attention_2,jetmoe
jetmoechat,JetMoE-Chat,JT-C,causal,8.52,24,8,2,flash_attention_2,jetmoe


In [3]:
# Dataset metadata with every column prefixed by "data_" (e.g. name -> data_name).
main_data_config = data_config.rename(columns=lambda col: f"data_{col}")
main_data_config

Unnamed: 0_level_0,data_name,data_abbr
key,Unnamed: 1_level_1,Unnamed: 2_level_1
c4,C4,C4
cc2306,CommonCrawl,CC
book,Books,BK
wikipedia,Wikipedia,WK
arxiv,ArXiv,AX
stackexchange,StackExchange,SE
github,GitHub,GH
lmarena,LMArena,LM
math,OpenMath,OM
code,OpenCode,OC


In [None]:
root_dir = Path("../output/srp_mpq")

model_dtype = model_config.index.dtype
dataset_dtype = data_config.index.dtype

# One frame per parquet file in root_dir, keyed by the file stem.
# Both merges use pd.merge's default inner join, so rows whose model (or
# dataset) key is missing from the respective config frame are dropped.
dfs = {}
for parquet_path in root_dir.glob("*.parquet"):
    frame = pd.read_parquet(parquet_path).merge(
        cmp_model_config, left_on="model", right_index=True
    )
    # Only some result files carry a per-dataset breakdown.
    if "dataset" in frame.columns:
        frame = frame.merge(main_data_config, left_on="dataset", right_index=True)
        frame["dataset"] = frame["dataset"].astype(dataset_dtype)
    # Cast the key columns to the dtypes of the corresponding config indexes.
    frame["model"] = frame["model"].astype(model_dtype)
    dfs[parquet_path.stem] = frame

# best_f1 and best_m of the "mg" results, one row per model and one column per seg_len.
mg_results = dfs["mg"]
mg_results.pivot(
    columns="seg_len",
    index=["model_group", "model_name"],
    values=["best_f1", "best_m"],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,best_f1,best_f1,best_f1,best_f1,best_m,best_m,best_m,best_m
Unnamed: 0_level_1,seg_len,4,16,64,256,4,16,64,256
model_group,model_name,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
llamamoe,LLaMA-MoE-v1,0.557779,0.45288,0.416099,0.40618,1.029198,2.392017,2.923047,3.521086
llamamoe,LLaMA-MoE-v1-SFT,0.557927,0.452827,0.416027,0.406179,1.028538,2.390347,2.922548,3.520766
olmoe,OLMoE,0.646949,0.509072,0.455343,0.426438,0.997344,1.056418,1.205091,1.187848
olmoe,OLMoE-DPO,0.650923,0.513726,0.460383,0.43247,1.001774,1.065452,1.221628,1.171647
olmoe,OLMoE-Instruct,0.650671,0.513379,0.460033,0.432102,1.001423,1.064703,1.220646,1.170233
olmoe,OLMoE-SFT,0.651482,0.514721,0.461546,0.433809,1.002181,1.065605,1.220275,1.170159
jetmoe,JetMoE,0.602158,0.47454,0.427762,0.410949,1.093253,2.259037,2.687439,3.154335
jetmoe,JetMoE-Chat,0.600186,0.473155,0.426525,0.410049,1.090559,2.264925,2.702572,3.184174
jetmoe,JetMoE-SFT,0.600133,0.473083,0.426441,0.409977,1.09051,2.26528,2.703776,3.186965


In [5]:
# best_f1 of the "md" results at seg_len 16: one row per model, one column per dataset.
md_seg16 = dfs["md"].query("seg_len == 16")
md_seg16.pivot(
    columns="dataset",
    index=["model_group", "model_name"],
    values="best_f1",
)

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model_group,model_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
llamamoe,LLaMA-MoE-v1,0.450365,0.450565,0.448192,0.454389,0.454172,0.45514,0.456557,0.456482,0.451833,0.452922,0.451186
llamamoe,LLaMA-MoE-v1-SFT,0.450131,0.450435,0.448012,0.454107,0.454214,0.455006,0.456492,0.456385,0.452112,0.452437,0.451884
olmoe,OLMoE,0.455232,0.454706,0.45541,0.52697,0.535874,0.550433,0.565604,0.523845,0.52972,0.547402,0.488658
olmoe,OLMoE-DPO,0.453336,0.453536,0.454904,0.524818,0.545393,0.558625,0.578345,0.523817,0.541753,0.557496,0.492086
olmoe,OLMoE-Instruct,0.453118,0.453374,0.454745,0.524696,0.545053,0.558308,0.577966,0.523748,0.54103,0.557106,0.491492
olmoe,OLMoE-SFT,0.453791,0.45369,0.454565,0.525979,0.546128,0.559705,0.579066,0.524998,0.543172,0.559711,0.493525
jetmoe,JetMoE,0.475049,0.474884,0.470432,0.480482,0.475177,0.475215,0.470728,0.477087,0.473551,0.473441,0.474037
jetmoe,JetMoE-Chat,0.473541,0.473573,0.469329,0.479443,0.47399,0.473865,0.469668,0.475743,0.471043,0.472132,0.472529
jetmoe,JetMoE-SFT,0.473475,0.473515,0.469279,0.479428,0.473923,0.473807,0.469609,0.475665,0.470966,0.472059,0.472348


In [None]:
sample_seg_len = 16

# "mg" scores: keep best_f1 under the name gen_best_f1 so it does not collide
# with the best_f1 column of "md".
mg_scores = (
    dfs["mg"]
    .rename(columns={"best_f1": "gen_best_f1"})
    .drop(columns=["best_m", "ci_lb", "ci_ub"])
)
md_scores = dfs["md"].drop(columns=["act_r", "best_m", "ci_lb", "ci_ub"])

# No explicit join keys: pd.merge joins on every column name the two frames share.
# f1_diff is the relative difference of md's best_f1 to mg's best_f1 (gen_best_f1)
# at the sampled segment length.
mdf = (
    mg_scores.merge(md_scores)
    .query(f"seg_len == {sample_seg_len}")
    .drop(columns="seg_len")
    .assign(
        f1_diff=lambda frame: (frame["best_f1"] - frame["gen_best_f1"]) / frame["gen_best_f1"]
    )
)
mdf.pivot(columns="dataset", index=["model_group", "model_name"], values="f1_diff")

Unnamed: 0_level_0,dataset,c4,cc2306,book,wikipedia,arxiv,stackexchange,github,lmarena,math,code,science
model_group,model_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
llamamoe,LLaMA-MoE-v1,-0.005554,-0.005112,-0.010352,0.003333,0.002852,0.00499,0.008118,0.007955,-0.002313,9.3e-05,-0.003741
llamamoe,LLaMA-MoE-v1-SFT,-0.005953,-0.005281,-0.010632,0.002826,0.003063,0.004814,0.008095,0.007859,-0.001577,-0.00086,-0.002081
olmoe,OLMoE,-0.105762,-0.106794,-0.105412,0.035159,0.052648,0.081247,0.111048,0.029019,0.04056,0.075293,-0.040101
olmoe,OLMoE-DPO,-0.117552,-0.117165,-0.114501,0.021591,0.061641,0.087398,0.125784,0.019643,0.054555,0.085201,-0.042125
olmoe,OLMoE-Instruct,-0.11738,-0.116883,-0.114211,0.022045,0.061698,0.087517,0.125808,0.020199,0.053861,0.085176,-0.042633
olmoe,OLMoE-SFT,-0.118375,-0.118571,-0.116871,0.021872,0.061018,0.087396,0.12501,0.019967,0.055275,0.087407,-0.04118
jetmoe,JetMoE,0.001072,0.000725,-0.008656,0.012522,0.001344,0.001422,-0.008034,0.005367,-0.002085,-0.002316,-0.00106
jetmoe,JetMoE-Chat,0.000816,0.000884,-0.008086,0.013291,0.001765,0.001502,-0.007369,0.00547,-0.004462,-0.002162,-0.001322
jetmoe,JetMoE-SFT,0.000827,0.000912,-0.008043,0.013411,0.001775,0.001529,-0.007343,0.005456,-0.004475,-0.002165,-0.001554


In [23]:
# Grid of bar charts: one row per comparison group, one column per group member.
# Each subplot shows f1_diff per dataset (x = data_abbr) for one model.
num_rows = len(cmp_groups)
num_cols = max(len(v) for v in cmp_groups.values())

fig = make_subplots(
    rows=num_rows,
    cols=num_cols,
    shared_xaxes="all",
    shared_yaxes="all",
    horizontal_spacing=0.005,
    vertical_spacing=0.1,
    # Titles are listed row by row; groups shorter than num_cols get empty
    # titles for their unused cells.
    subplot_titles=[
        "" if i >= len(group) else cmp_model_config.loc[group[i], "model_name"]
        for group in cmp_groups.values()
        for i in range(num_cols)
    ],
)

# Font sizes: [0] x tick labels, [1] y tick labels, [3] subplot titles.
# Entry [2] is not used in this cell.
font_size = [12, 16, 18, 20]

for i, (group_key, group) in enumerate(cmp_groups.items()):
    row = i + 1
    for j, key in enumerate(group):
        col = j + 1

        tmpdf = mdf.query(f"model == '{key}'")
        # Nothing to draw for models without rows in mdf.
        if len(tmpdf) == 0:
            continue

        # All members of a group share the colour registered for the group key.
        fig.add_bar(
            x=tmpdf["data_abbr"],
            y=tmpdf["f1_diff"],
            hoverinfo="skip",
            marker=go.bar.Marker(color=model_colors[group_key]),
            showlegend=False,
            row=row,
            col=col,
        )

        # Dataset labels only under the bottom row.
        fig.update_xaxes(
            showticklabels=row == num_rows,
            tickangle=0,
            tickfont=go.layout.xaxis.Tickfont(size=font_size[0]),
            row=row,
            col=col,
        )

        # Y tick labels only on the leftmost column.
        fig.update_yaxes(showticklabels=col == 1, tickvals=[-0.1, 0, 0.1], row=row, col=col)

        if col == 1:
            # NOTE(review): the "16" in the zero-line label presumably mirrors
            # sample_seg_len; confirm and keep the two in sync if it changes.
            fig.update_yaxes(
                tickfont=go.layout.yaxis.Tickfont(size=font_size[1]),
                ticktext=["-10%", "SRP<br>(E,16)", "+10%"],
                row=row,
                col=col,
            )

fig.update_annotations(font=go.layout.annotation.Font(size=font_size[3]))
fig.update_layout(margin=go.layout.Margin(l=60, r=15, t=30, b=15), width=1000, height=600)

# Create the output directory first so the export does not fail on a fresh
# checkout where ./plot does not exist yet.
out_file = "./plot/msrpddp.pdf"
Path(out_file).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(out_file, width=fig.layout.width, height=fig.layout.height)
fig.show()