In [None]:
from pathlib import Path

import pandas as pd
from misc import model_config

In [2]:
# Keep only the rows whose boolean "main" column is truthy, drop that flag
# column, and prefix every remaining column label with "model_".
main_model_config = model_config.query("main").drop(columns="main")
main_model_config = main_model_config.rename(columns=lambda column: f"model_{column}")

# Mapping from model key to an alternative name (presumably the label used in
# figures/tables -- it is not consumed in the cells visible here).
new_name = {
    "baseline": "Baseline",
    "32expert": "FewerExp",
    "top16": "ActMore",
    "top2": "ActFewer",
    "share1": "1ShrExp",
    "share2": "2ShrExp",
    "skip1": "DenseFst",
    "sparse2": "DenseHlf",
    "nolb": "NoLB",
    "overlb": "OverLB",
}

# One Dark24 palette entry per main model, picked by the model's row position.
# NOTE(review): `px` is not imported by the import cell shown at the top of this
# notebook (presumably `plotly.express`) -- confirm and add the import, otherwise
# this cell raises NameError on a fresh kernel.
model_colors = {
    model_key: px.colors.qualitative.Dark24[position]
    for position, model_key in enumerate(main_model_config.index.values)
}

main_model_config

Unnamed: 0_level_0,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
baseline,baseline,BL,causal,1.43,8,64,8,sdpa
32expert,baseline,FE,causal,1.43,8,32,4,sdpa
top16,baseline,AM,causal,1.43,8,64,16,sdpa
top2,baseline,AF,causal,1.43,8,64,2,sdpa
share1,share1,SH1,causal,1.43,8,63,7,sdpa
share2,share2,SH2,causal,1.43,8,62,6,sdpa
skip1,skip1,DF,causal,1.43,8,64,8,sdpa
sparse2,sparse2,DH,causal,1.43,8,64,8,sdpa
nolb,baseline,NB,causal,1.43,8,64,8,sdpa
overlb,baseline,OB,causal,1.43,8,64,8,sdpa


In [3]:
def make_abbr(df):
    """Build the abbreviation for one record.

    Args:
        df: A mapping-like record (e.g. a DataFrame row or dict) providing
            "model_type" and "model_abbr"; "is_decoder" is read only when
            "model_type" equals "seq2seq".

    Returns:
        The value of "model_abbr", followed by "d" (truthy "is_decoder") or
        "e" (falsy "is_decoder") when "model_type" is "seq2seq"; otherwise the
        "model_abbr" value unchanged.
    """
    if df["model_type"] == "seq2seq":
        base = df["model_abbr"]
        suffix = "d" if df["is_decoder"] else "e"
        return f"{base}{suffix}"
    return df["model_abbr"]

In [None]:
root_dir = Path("../output/tp_mpq")

# Read the measurements from m.parquet and attach the per-model config columns
# by matching the "model" column against the config table's index.
df = pd.read_parquet(root_dir / "m.parquet").merge(
    main_model_config, left_on="model", right_index=True
)

# Give "model" the same dtype as the config index (presumably an ordered
# categorical so models keep the config-table order -- confirm in `misc`).
df = df.assign(model=df["model"].astype(model_config.index.dtype))
df

Unnamed: 0,model,dataset,cache_m,mode,prefill_thrp,decode_thrp,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,baseline,c4,1.0,old_doc,669.586534,10.345799,baseline,BL,causal,1.43,8,64,8,sdpa
1,baseline,c4,1.0,new_gen,669.762733,10.473391,baseline,BL,causal,1.43,8,64,8,sdpa
2,baseline,c4,2.0,old_doc,746.376371,11.608819,baseline,BL,causal,1.43,8,64,8,sdpa
3,baseline,c4,2.0,new_gen,746.203748,11.705259,baseline,BL,causal,1.43,8,64,8,sdpa
4,baseline,c4,3.0,old_doc,803.713408,12.458667,baseline,BL,causal,1.43,8,64,8,sdpa
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1051,overlb,science,2.0,new_gen,747.422945,11.373545,baseline,OB,causal,1.43,8,64,8,sdpa
1052,overlb,science,3.0,old_doc,805.651832,12.072060,baseline,OB,causal,1.43,8,64,8,sdpa
1053,overlb,science,3.0,new_gen,803.538583,12.103101,baseline,OB,causal,1.43,8,64,8,sdpa
1054,overlb,science,inf,old_doc,1408.173687,18.671890,baseline,OB,causal,1.43,8,64,8,sdpa


In [None]:
# Average prefill/decode throughput over every row sharing the same
# (model, cache_m, mode) combination; all other columns (e.g. dataset) are
# collapsed by the mean.
throughput_groups = df.groupby(["model", "cache_m", "mode"], as_index=False, observed=True)
mdf = throughput_groups[["prefill_thrp", "decode_thrp"]].mean()

# Wide view: one row per (model, mode), one column per metric and cache_m value.
mdf.pivot(
    values=["prefill_thrp", "decode_thrp"],
    columns="cache_m",
    index=["model", "mode"],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,prefill_thrp,prefill_thrp,prefill_thrp,prefill_thrp,decode_thrp,decode_thrp,decode_thrp,decode_thrp
Unnamed: 0_level_1,cache_m,1.0,2.0,3.0,inf,1.0,2.0,3.0,inf
model,mode,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
baseline,old_doc,671.211354,746.142357,806.701947,1440.988898,10.570958,12.041458,13.023481,19.036706
baseline,new_gen,671.10593,750.02648,805.326465,1442.08117,10.63958,12.135196,13.099395,19.029482
32expert,old_doc,687.171915,733.76502,785.751764,1445.474156,10.79539,11.803518,12.719432,19.093189
32expert,new_gen,687.158222,735.172242,786.745651,1446.81635,10.85318,11.872818,12.816151,19.086573
top16,old_doc,726.672124,850.446746,1049.297554,1370.380943,6.395263,7.618822,8.985702,10.777566
top16,new_gen,727.987611,851.126996,1050.449929,1371.243004,6.451738,7.689991,9.04669,10.769219
top2,old_doc,687.464493,687.195766,707.438857,1524.063953,27.46779,27.630046,28.598128,41.807579
top2,new_gen,690.125515,689.232838,707.698981,1523.566155,27.51034,27.657158,28.64345,41.795898
share1,old_doc,674.368386,732.038827,812.139154,1444.953554,10.647874,11.811415,13.060729,19.074727
share1,new_gen,672.579744,730.609584,813.440267,1446.292447,10.719114,11.873119,13.142425,19.069902


In [16]:
# Split the averaged throughputs into the finite-cache rows and the
# cache_m == inf rows; the latter serve as the per-(model, mode) reference and
# arrive in the merged frame with a "_b" suffix.
finite_cache_rows = mdf.query("cache_m != inf")
reference_rows = mdf.query("cache_m == inf").drop(columns="cache_m")

mmdf = finite_cache_rows.merge(
    reference_rows, on=["model", "mode"], suffixes=("", "_b")
).assign(
    # *_oh: difference of reciprocal throughputs (finite cache minus reference).
    prefill_oh=lambda t: 1 / t["prefill_thrp"] - 1 / t["prefill_thrp_b"],
    decode_oh=lambda t: 1 / t["decode_thrp"] - 1 / t["decode_thrp_b"],
    # *_ohr: the same overhead scaled by the reference throughput, i.e.
    # thrp_b / thrp - 1. These callables read the *_oh columns created just
    # above, which works because assign evaluates keyword arguments in order.
    prefill_ohr=lambda t: t["prefill_oh"] * t["prefill_thrp_b"],
    decode_ohr=lambda t: t["decode_oh"] * t["decode_thrp_b"],
)

# Wide view: one row per (model, mode), one column per metric and cache_m value.
mmdf.pivot(
    index=["model", "mode"],
    columns="cache_m",
    values=["prefill_oh", "decode_oh", "prefill_ohr", "decode_ohr"],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,prefill_oh,prefill_oh,prefill_oh,decode_oh,decode_oh,decode_oh,prefill_ohr,prefill_ohr,prefill_ohr,decode_ohr,decode_ohr,decode_ohr
Unnamed: 0_level_1,cache_m,1.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,3.0
model,mode,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
baseline,old_doc,0.000796,0.000646,0.000546,0.042069,0.030516,0.024254,1.146848,0.931252,0.786272,0.80085,0.58093,0.461722
baseline,new_gen,0.000797,0.00064,0.000548,0.041439,0.029855,0.023789,1.148813,0.922707,0.790679,0.788556,0.568123,0.452699
32expert,old_doc,0.000763,0.000671,0.000581,0.040257,0.032346,0.026245,1.103512,0.969941,0.839607,0.768643,0.617585,0.501104
32expert,new_gen,0.000764,0.000669,0.00058,0.039746,0.031833,0.025634,1.105507,0.967996,0.838989,0.758616,0.607586,0.489259
top16,old_doc,0.000646,0.000446,0.000223,0.06358,0.038469,0.018503,0.885831,0.611366,0.305998,0.685242,0.414597,0.199413
top16,new_gen,0.000644,0.000446,0.000223,0.06214,0.037182,0.01768,0.883608,0.611091,0.305386,0.669197,0.40042,0.190404
top2,old_doc,0.000798,0.000799,0.000757,0.012487,0.012273,0.011048,1.216935,1.217802,1.15434,0.522058,0.51312,0.461899
top2,new_gen,0.000793,0.000795,0.000757,0.012424,0.012231,0.010986,1.207665,1.210525,1.152845,0.51928,0.511215,0.459178
share1,old_doc,0.000791,0.000674,0.000539,0.04149,0.032238,0.02414,1.142677,0.973876,0.779195,0.791412,0.61494,0.460464
share1,new_gen,0.000795,0.000677,0.000538,0.040853,0.031785,0.023651,1.150366,0.979569,0.777995,0.779056,0.606141,0.451019


In [28]:
chr_dir = Path("../output/chr_mpq")

# Read the recall measurements from m.parquet and attach the per-model config
# columns by matching the "model" column against the config table's index.
chr_df = pd.read_parquet(chr_dir / "m.parquet").merge(
    main_model_config, left_on="model", right_index=True
)

# Give "model" the same dtype as the config index (presumably an ordered
# categorical so models keep the config-table order -- confirm in `misc`).
chr_df = chr_df.assign(model=chr_df["model"].astype(model_config.index.dtype))
chr_df

Unnamed: 0,model,is_decoder,dataset,method,cache_m,recall,ci_lb,ci_ub,model_name,model_abbr,model_type,model_num_params,model_num_layers,model_num_experts,model_top_k,model_attn
0,baseline,True,c4,LRU,0.015625,0.003556,0.003538,0.003574,baseline,BL,causal,1.43,8,64,8,sdpa
1,baseline,True,c4,LRU,0.031250,0.007462,0.007428,0.007494,baseline,BL,causal,1.43,8,64,8,sdpa
2,baseline,True,c4,LRU,0.046875,0.011688,0.011635,0.011736,baseline,BL,causal,1.43,8,64,8,sdpa
3,baseline,True,c4,LRU,0.062500,0.016184,0.016111,0.016251,baseline,BL,causal,1.43,8,64,8,sdpa
4,baseline,True,c4,LRU,0.078125,0.020975,0.020880,0.021062,baseline,BL,causal,1.43,8,64,8,sdpa
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
172387,overlb,True,science,Beladi,7.937500,0.999542,0.999535,0.999549,baseline,OB,causal,1.43,8,64,8,sdpa
172388,overlb,True,science,Beladi,7.953125,0.999684,0.999679,0.999689,baseline,OB,causal,1.43,8,64,8,sdpa
172389,overlb,True,science,Beladi,7.968750,0.999812,0.999809,0.999816,baseline,OB,causal,1.43,8,64,8,sdpa
172390,overlb,True,science,Beladi,7.984375,0.999922,0.999920,0.999924,baseline,OB,causal,1.43,8,64,8,sdpa


In [29]:
# Keep only the LRU rows measured at cache_m of exactly 1.0, 2.0 or 3.0 -- the
# same finite cache_m values that appear in the throughput tables.
is_lru = chr_df["method"] == "LRU"
is_selected_cache = chr_df["cache_m"].isin([1.0, 2.0, 3.0])
lru_rows = chr_df[is_lru & is_selected_cache]

# Mean recall per (model, cache_m); every other column is averaged away.
mcdf = lru_rows.groupby(["model", "cache_m"], observed=True, as_index=False)[["recall"]].mean()

mcdf.pivot(index="model", columns="cache_m", values="recall")

cache_m,1.0,2.0,3.0
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
baseline,0.293701,0.47414,0.614081
32expert,0.255014,0.435932,0.579074
top16,0.439243,0.699631,0.892248
top2,0.099905,0.188884,0.253743
share1,0.270697,0.438061,0.566572
share2,0.239381,0.396527,0.514828
skip1,0.313387,0.498588,0.638797
sparse2,0.301056,0.479912,0.615446
nolb,0.436388,0.672001,0.823592
overlb,0.201784,0.355713,0.489083


In [37]:
# Correlation between recall and each overhead metric, computed across models
# separately for every (cache_m, mode) pair. The "top16" and "top2" models are
# left out of the correlation.
(
    mmdf.merge(mcdf)  # joins on the columns the two frames share
    .loc[lambda t: ~t["model"].isin(["top16", "top2"])]
    .groupby(["cache_m", "mode"], observed=True)[
        ["prefill_oh", "decode_oh", "prefill_ohr", "decode_ohr", "recall"]
    ]
    .corr()["recall"]  # keep only each metric's correlation with recall
    .unstack(-1)  # metrics become columns
    .drop(columns="recall")  # recall-vs-recall carries no information
    .unstack(-1)  # mode becomes the inner column level
)

Unnamed: 0_level_0,prefill_oh,prefill_oh,decode_oh,decode_oh,prefill_ohr,prefill_ohr,decode_ohr,decode_ohr
mode,old_doc,new_gen,old_doc,new_gen,old_doc,new_gen,old_doc,new_gen
cache_m,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
1.0,0.094626,0.076122,-0.353779,-0.342702,0.18879,0.167119,-0.196177,-0.188924
2.0,0.337051,0.217339,-0.419011,-0.440524,0.394074,0.298722,-0.285584,-0.315164
3.0,0.131248,0.121281,-0.367251,-0.358451,0.243144,0.228427,-0.298174,-0.289628
