In [1]:
from pathlib import Path

import pandas as pd

from misc import model_config

In [None]:
# Load every parquet file found directly under ../case into a dict keyed by the
# file stem (name without the .parquet suffix), then display the "tokens" frame.
root_dir = Path("../case")
dfs = dict(
    (parquet_path.stem, pd.read_parquet(parquet_path))
    for parquet_path in root_dir.glob("*.parquet")
)
dfs["tokens"]

Unnamed: 0,model,dataset,pos,token
0,powermoe,arxiv,0,Introduction
1,powermoe,arxiv,1,Ġ\
2,powermoe,arxiv,2,label
3,powermoe,arxiv,3,{
4,powermoe,arxiv,4,sec
...,...,...,...,...
32251,qwen2,wikipedia,251,ĠThe
32252,qwen2,wikipedia,252,ĠSan
32253,qwen2,wikipedia,253,ĠMill
32254,qwen2,wikipedia,254,Ã¡n


In [None]:
# Build `rdf`: for every (model, dataset) pair, the single layer whose weighted
# mean best_f1 is highest, sorted from best to worst.
rdf = (
    (
        dfs["srp"]
        # Keep only rows measured with seg_len == 16; the column is constant afterwards.
        .query("seg_len == 16")
        .drop(columns=["seg_len"])
        # Step 1: collapse each (model, dataset, layer_idx) group to one best_f1 value.
        .groupby(["model", "dataset", "layer_idx"], as_index=False, observed=True)
        .apply(
            lambda df: pd.Series(
                {
                    # Weighted mean of best_f1 with weights act_r * (1 + best_m).
                    # NOTE(review): the meaning of act_r / best_m is not visible here — confirm.
                    "best_f1": (df["best_f1"] * df["act_r"] * (1 + df["best_m"])).sum()
                    / (df["act_r"] * (1 + df["best_m"])).sum(),
                }
            ),
            include_groups=False,
        )
    )
    # Step 2: within each (model, dataset), keep the one row with the largest best_f1.
    .groupby(["model", "dataset"], observed=True)
    # iloc[[...]] (list indexer) keeps a one-row DataFrame; argmax returns the first maximum.
    .apply(lambda df: df.iloc[[df["best_f1"].argmax()]], include_groups=False)
    # Drop the third index level (the leftover row label from the inner frame) ...
    .reset_index(2, drop=True)
    # ... and turn the remaining index levels (model, dataset) back into columns.
    .reset_index()
    .sort_values("best_f1", ascending=False)
)

# Show only the github rows of the per-(model, dataset) best-layer table.
rdf.query("dataset == 'github'")

Unnamed: 0,model,dataset,layer_idx,best_f1
27,llamamoe2,github,31,0.918436
6,powermoe,github,1,0.767257
104,grin,github,23,0.653532
97,phi,github,21,0.651272
48,minicpm,github,19,0.592131
83,qwen3,github,6,0.590894
90,yuan,github,8,0.590616
111,mixtral,github,9,0.572393
20,olmoe,github,13,0.565924
34,jetmoe,github,10,0.513457


In [4]:
def make_matrix(model_key, data_key, layer_idx):
    """Build a token-by-expert boolean matrix for one model, dataset and layer.

    Args:
        model_key: Value matched against the ``model`` column of ``dfs["tokens"]``
            and ``dfs["logits"]``; also the row label looked up in ``model_config``.
        data_key: Value matched against the ``dataset`` column.
        layer_idx: Value matched against the ``layer_idx`` column of ``dfs["logits"]``.

    Returns:
        A DataFrame indexed by ``token`` with one column per ``expert_idx``.
        A cell is True when that expert's logit rank at the position (0 = lowest)
        is at least ``num_experts - top_k``.
        NOTE(review): this equals "among the top_k logits" only if every position
        has exactly ``num_experts`` logit rows — confirm against the data.
    """
    n_experts = model_config.loc[model_key, "num_experts"]
    n_active = model_config.loc[model_key, "top_k"]
    rank_cutoff = n_experts - n_active

    def _rank_at_least_cutoff(logit_col):
        # argsort of an argsort yields each entry's rank within the column.
        ranks = logit_col.values.argsort().argsort()
        return ranks >= rank_cutoff

    # Tokens of the requested sequence, keyed by position.
    token_by_pos = (
        dfs["tokens"]
        .query(f"model == '{model_key}' and dataset == '{data_key}'")[["pos", "token"]]
        .set_index("pos")
    )

    # Logits of the requested layer as an (expert_idx x pos) table.
    logit_table = (
        dfs["logits"]
        .query(f"model == '{model_key}' and dataset == '{data_key}' and layer_idx == {layer_idx}")
        .drop(columns=["model", "dataset", "layer_idx"])
        .pivot(index="expert_idx", columns="pos", values="logit")
    )

    # Rank per position (column-wise), then flip to (pos x expert_idx).
    selected_by_pos = logit_table.apply(_rank_at_least_cutoff, axis=0).transpose()

    # Align tokens and selections on position, then label rows by token text.
    joined = pd.merge(
        token_by_pos,
        selected_by_pos,
        left_index=True,
        right_index=True,
    )
    return joined.set_index("token")

In [5]:
# Token-by-expert selection matrix for grin on the github sample, layer 21.
# NOTE(review): the rdf output above lists layer 23 as grin/github's best-F1 layer
# (21 is phi's) — confirm that layer 21 is the intended choice here.
tdf1 = make_matrix("grin", "github", 21)
tdf1

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
token,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
▁using,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False
▁System,False,False,True,False,False,False,False,True,False,False,False,False,False,False,False,False
;,False,False,False,False,False,True,False,True,False,False,False,False,False,False,False,False
<0x0A>,False,True,False,False,False,False,False,True,False,False,False,False,False,False,False,False
using,False,False,False,False,False,False,False,True,False,False,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
▁_,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False
sign,False,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False
In,True,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False
Manager,False,False,True,False,False,False,False,True,False,False,False,False,False,False,False,False


In [6]:
# Token-by-expert selection matrix for jamba on the github sample, layer 25.
# NOTE(review): jamba is not among the github rows of rdf displayed above, so the
# layer is presumably hand-picked — confirm.
tdf2 = make_matrix("jamba", "github", 25)
tdf2

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
token,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
<|startoftext|>,False,False,True,False,True,False,False,False,False,False,False,False,False,False,False,False
using,False,False,False,False,False,False,False,True,True,False,False,False,False,False,False,False
▁System,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False
;,False,False,False,False,False,False,False,True,False,False,False,False,False,True,False,False
<0x0A>,False,False,False,False,False,False,False,False,True,False,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
▁,False,False,False,False,False,False,False,True,False,False,False,False,True,False,False,False
<0x0A>,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True
▁▁▁▁▁▁▁▁▁▁▁,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False
▁},True,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False


In [None]:
def print_router(tdf, offset, plen):
    """Print each column of ``tdf`` as one line of moon glyphs.

    Every value in a column is used as an index into a two-character glyph
    string, so a falsy/0 value prints as 🌚 and a truthy/1 value as 🌝.
    Only the characters in ``[offset, offset + plen)`` of each line are printed.

    Args:
        tdf: DataFrame whose values are usable as 0/1 indices (e.g. booleans).
        offset: Index of the first row to show.
        plen: Maximum number of rows to show per line.
    """
    glyphs = "🌚🌝"
    window = slice(offset, offset + plen)
    for col in tdf.columns:
        # Render the whole column first, then cut out the requested window.
        rendered = "".join([glyphs[flag] for flag in tdf[col]])
        print(rendered[window])

In [8]:
# Rows 0-63 of tdf1, one printed line per column; 🌝 marks a True cell.
print_router(tdf1, 0, 64)

🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌝🌚🌚🌚🌝🌚🌝🌝🌝🌚🌝🌚🌚🌝🌝🌚🌚🌚🌚🌚🌝🌝🌝🌝🌝🌝🌚🌝🌚🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌚🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝🌝
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌝🌝🌝🌝🌝🌝🌚🌝🌝🌝🌚🌚🌝🌝🌝🌝🌝🌝🌚🌝🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌝🌚🌝🌚🌚🌚🌝🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌝🌝🌚🌝🌚🌚🌝🌝🌝🌚🌝🌚🌝🌚🌚🌚🌚🌝🌝🌚🌝🌚🌝🌝🌝🌚🌝🌝🌝🌝🌝🌚🌝🌝🌝🌝🌝🌝🌚🌝🌝🌝🌝🌝🌚🌚🌚🌝🌚🌝🌝🌝🌚🌝
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌝

In [9]:
# Rows 0-63 of tdf2, one printed line per column; 🌝 marks a True cell.
print_router(tdf2, 0, 64)

🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌝🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌚🌚
🌚🌚🌝🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌚🌝🌚🌝🌚🌚🌚🌝🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌝🌚
🌝🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚
🌚🌝🌚🌝🌚🌚🌚🌚🌚🌝🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌝🌝🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌝🌚🌚🌝🌚🌚🌚🌚🌚🌚🌝🌚🌚🌝🌚🌚🌝🌚🌚🌝🌚🌚🌚🌚🌚🌝🌝🌚🌝🌚🌚🌚🌚🌝🌝🌚🌝🌚🌚🌝🌝🌚🌝🌚🌝🌚🌚🌝🌝🌚🌚🌚🌚🌚🌚🌝🌝🌚🌚🌚🌝🌚🌝
🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌝🌚🌚🌚🌝🌚🌝🌝🌚🌚🌚🌝🌚🌝🌚🌚🌚🌝🌚🌝🌚🌝🌝🌚🌚🌚🌝🌚🌝🌚🌝🌚🌚🌚🌝🌚🌝🌝🌚🌚🌝🌚🌝🌚🌝🌚🌚🌚🌝🌝🌝🌝🌝🌚🌚🌚🌝🌝🌝🌝🌝🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌝🌝🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌝🌝🌚🌚🌚🌝🌝🌚🌝🌚🌝🌝🌚🌚🌚🌝🌚🌚🌝🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚🌚
🌚🌚🌚🌚🌚🌚🌚🌚🌝🌚🌚🌚🌝🌝🌚🌚🌚🌚🌚🌝🌚🌚🌚🌚🌚