In [None]:
from pathlib import Path

import pandas as pd

from misc import model_config

In [None]:
root_dir = Path("../case")
# One DataFrame per parquet file under root_dir, keyed by file stem
# (e.g. ../case/tokens.parquet -> dfs["tokens"]).
dfs = {p.stem: pd.read_parquet(p) for p in root_dir.glob("*.parquet")}

# Fail fast when run from the wrong working directory: glob() on a missing
# folder yields nothing, and the first symptom would otherwise be an opaque
# KeyError in a later cell. These are the tables the cells below read.
missing_tables = sorted({"tokens", "srp", "logits"} - set(dfs))
if missing_tables:
    raise FileNotFoundError(
        f"missing parquet tables {missing_tables} in {root_dir.resolve()}"
    )

dfs["tokens"]

In [None]:
# Score every (model, dataset, layer_idx) by a weighted mean of best_f1 over
# the seg_len == SEG_LEN rows, then keep the single best layer per
# (model, dataset), sorted from best to worst.
SEG_LEN = 16

srp_seg = dfs["srp"].query("seg_len == @SEG_LEN")

# Per-row weight act_r * (1 + best_m); layer score is
# sum(best_f1 * weight) / sum(weight) within each layer group.
srp_weight = srp_seg["act_r"] * (1 + srp_seg["best_m"])
layer_sums = (
    srp_seg.assign(_wf1=srp_seg["best_f1"] * srp_weight, _w=srp_weight)
    .groupby(["model", "dataset", "layer_idx"], observed=True)[["_wf1", "_w"]]
    .sum()
)
layer_scores = (
    (layer_sums["_wf1"] / layer_sums["_w"])
    .rename("best_f1")
    .reset_index()
    # A layer whose total weight is 0 scores NaN (0/0); it cannot be the best.
    .dropna(subset=["best_f1"])
)

# idxmax returns the row label of the first maximum within each group.
best_rows = layer_scores.groupby(["model", "dataset"], observed=True)["best_f1"].idxmax()
rdf = (
    layer_scores.loc[best_rows]
    .reset_index(drop=True)
    .sort_values("best_f1", ascending=False)
)

rdf.query("dataset == 'github'")

In [None]:
def make_matrix(model_key, data_key, layer_idx):
    """Build a token-by-expert top-k routing mask for one model/dataset/layer.

    For each position ``pos`` the router logits of ``model_key`` at
    ``layer_idx`` (from ``dfs["logits"]``) are ranked across experts. An
    expert is marked ``True`` when its logit is among the ``top_k`` largest,
    where ``top_k`` and ``num_experts`` are read from ``model_config``.

    Args:
        model_key: Value of the ``model`` column; also the row label in
            ``model_config``.
        data_key: Value of the ``dataset`` column.
        layer_idx: Value of the ``layer_idx`` column in ``dfs["logits"]``.

    Returns:
        DataFrame indexed by ``token`` with one boolean column per
        ``expert_idx``. Rows are positions present in both the tokens and
        logits tables (inner join on ``pos``).

    Raises:
        ValueError: if the filtered logits do not contain exactly
            ``num_experts`` experts (e.g. a wrong dataset or layer was given),
            in which case the rank threshold below would be meaningless.
    """
    num_experts = model_config.loc[model_key, "num_experts"]
    top_k = model_config.loc[model_key, "top_k"]
    # With 0-based ascending ranks, the top_k largest logits have
    # rank >= num_experts - top_k.
    threshold = num_experts - top_k

    # `@name` makes query() read the local variables directly, so keys that
    # contain quotes cannot break (or inject into) the expression string.
    tokens = (
        dfs["tokens"]
        .query("model == @model_key and dataset == @data_key")[["pos", "token"]]
        .set_index("pos")
    )
    logits = dfs["logits"].query(
        "model == @model_key and dataset == @data_key and layer_idx == @layer_idx"
    )
    # Rows: expert_idx, columns: pos.
    logit_grid = logits.pivot(index="expert_idx", columns="pos", values="logit")
    if logit_grid.shape[0] != num_experts:
        raise ValueError(
            f"expected {num_experts} experts for model={model_key!r}, "
            f"dataset={data_key!r}, layer_idx={layer_idx!r}; "
            f"found {logit_grid.shape[0]}"
        )

    # argsort().argsort() converts each column (position) into 0-based
    # ascending ranks over experts.
    selected = logit_grid.apply(
        lambda col: col.to_numpy().argsort().argsort() >= threshold, axis=0
    ).transpose()

    return pd.merge(
        tokens,
        selected,
        left_index=True,
        right_index=True,
    ).set_index("token")

In [None]:
# Routing mask for model "grin" on dataset "github" at layer 21.
# NOTE(review): layer 21 is hard-coded; presumably the best github layer per
# `rdf` above — confirm.
tdf1 = make_matrix(model_key="grin", data_key="github", layer_idx=21)
tdf1

In [None]:
# Routing mask for model "jamba" on dataset "github" at layer 25.
# NOTE(review): layer 25 is hard-coded; presumably the best github layer per
# `rdf` above — confirm.
tdf2 = make_matrix(model_key="jamba", data_key="github", layer_idx=25)
tdf2

In [None]:
def print_router(tdf, offset, plen, symbols="🌚🌝"):
    """Print one line of routing glyphs per column (expert) of ``tdf``.

    Args:
        tdf: Boolean DataFrame whose rows are token positions and whose
            columns are experts (e.g. the output of ``make_matrix``).
        offset: Index of the first row (position) to show; Python slice
            semantics apply.
        plen: Number of rows (positions) to show.
        symbols: Two-character string. ``symbols[0]`` is printed where the
            value is False (expert not selected), ``symbols[1]`` where it is
            True. Defaults to the new-moon / full-moon faces.
    """
    for _, selected in tdf.items():
        # Slice the column first instead of building the full string and
        # cutting it afterwards.
        window = selected.iloc[offset : offset + plen]
        print("".join(symbols[int(flag)] for flag in window))

In [None]:
# First 64 positions of each expert's routing row for tdf1.
print_router(tdf1, offset=0, plen=64)

In [None]:
# First 64 positions of each expert's routing row for tdf2.
print_router(tdf2, offset=0, plen=64)