In [1]:
import copy
from functools import partial
from typing import *

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import wandb
import wandb.apis

api = wandb.Api()

In [34]:
def load_one(run: wandb.apis.public.Run) -> Dict[str, Any]:
    if "count" in run.summary:
        task = run.config["task"]["name"]
        # results = run.summary["results"]
        # bpc = (sum(r["bpc"] for r in results) / len(run.summary["results"])
        #        if isinstance(results[0]["bpc"], float) else
        #         float("nan"))
        return dict(
            task=task,
            model_scale=run.config["model"].replace("EleutherAI/pythia-", ""),
            **run.config["sparsity"],
            bpc=run.summary["bpc"],
        )

def load_sweep(experiment_name: str, *single_runs: str) -> pd.DataFrame:
    return pd.DataFrame.from_dict(filter(None, (
        load_one(run)
        for run in api.runs(
            "research/sparse-attention", {"$or": [
                {"config.name": experiment_name},
                *({"display_name": n} for n in single_runs)
            ]}
        )
    )))

def technique(s: pd.Series) -> str:
    return "_".join(str(n) for n in [s["name"], s["strategy"], s["score"]] if not pd.isna(n))

df = (
    load_sweep("RMOE-63-perplexity-v1")
    .pipe(lambda d: d.assign(technique=d.apply(technique, axis=1)))
    [["model_scale", "technique", "k", "rank", "bpc"]]
)
(df.groupby("model_scale").apply(lambda d: d.sort_values("bpc").drop(columns="model_scale"))
 .style.format(dict(
     k=lambda x: "" if pd.isna(x) else f"{x:.0f}",
     rank=lambda x: "" if pd.isna(x) else f"{x:.0f}",
     bpc="{:.3f}",
)))

Unnamed: 0_level_0,Unnamed: 1_level_0,technique,k,rank,bpc
model_scale,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1.4b,9,dense,,,0.774
1.4b,2,ann_sparse_q,32.0,64.0,0.799
1.4b,1,eviction_sum_weight,256.0,,0.801
1.4b,0,ann_sparse_q,32.0,32.0,0.818
1.4b,5,eviction_sum_weight,128.0,,0.833
1.4b,3,ann_sparse_q,32.0,16.0,0.835
1.4b,4,eviction_sum_weight,64.0,,0.891
1b,13,dense,,,0.807
1b,6,ann_sparse_q,32.0,64.0,0.844
1b,11,eviction_sum_weight,256.0,,0.851
