In [15]:
import wandb
import pandas as pd
import numpy as np
from typing import *
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import sys
# Apply seaborn's "talk" plotting context to figures drawn in this notebook.
sns.set_context("talk")

def loss_to_bpc(loss: float) -> float:
    """Rescale a natural-log loss to base 2 by dividing it by ln(2).

    Args:
        loss: Loss value; anything accepted by ``float()``.

    Returns:
        ``float(loss) / ln(2)`` (the function name suggests the unit is
        bits per character).
    """
    nats = float(loss)
    return nats / np.log(2)

def load_runs(*experiments: str) -> wandb.apis.public.Runs:
    """Query the W&B project for runs belonging to the given experiments.

    Args:
        *experiments: Values matched against each run's
            ``config.metadata.experiment`` field.

    Returns:
        The runs collection returned by ``wandb.Api.runs``. The number of
        matching runs is reported on stderr.
    """
    client = wandb.Api({"base_url": "https://wandb.sourcevertex.net"})
    # Mongo-style filter: keep runs whose experiment tag is any of the requested ones.
    experiment_filter = {"config.metadata.experiment": {"$in": experiments}}
    matching = client.runs(path="douglaso/scaled-matmuls", filters=experiment_filter)
    print(f"Loaded {len(matching)} runs", file=sys.stderr)
    return matching

In [19]:
# Experiment tags whose runs are analysed in the cells below.
large_experiments = (
    "20230115_large_p0",
    "20230115_large_p1",
    "20230115_large_p2",
)
runs = load_runs(*large_experiments)
def run_stats(run):
    """Flatten one run into a dict of settings and final metrics.

    Args:
        run: Object exposing ``state``, ``tags``, ``id``, ``config`` and
            ``summary`` (presumably a W&B run; verify against caller).

    Returns:
        ``None`` when the run is not in state ``"finished"`` or carries the
        ``"broken"`` tag; otherwise a dict with one entry per column.
    """
    # Only completed, non-broken runs contribute a row.
    if run.state != "finished":
        return None
    if "broken" in run.tags:
        return None

    model_cfg = run.config["model"]
    optimiser_cfg = run.config["training"]["optimiser"]

    row = {
        "id": run.id,
        "experiment": run.config["metadata"]["experiment"],
    }
    ## model
    row.update(
        variant="unit" if run.config["unit_scale"] else "regular",
        dtype=model_cfg["dtype"],
        depth=model_cfg["depth"],
        # A falsy norm setting is recorded as the string "none".
        norm=model_cfg["residual"]["norm"] or "none",
        model=model_cfg["sequence"]["kind"],
    )
    ## optimiser
    row.update(
        optimiser=optimiser_cfg["kind"],
        # Learning rate stored as its base-2 logarithm.
        lr=np.log(optimiser_cfg["learning_rate"])/np.log(2),
        loss_scale=run.config["training"]["loss_scale"],
    )
    ## stats
    row.update(
        weights=run.summary["n_weights_no_embedding"],
        test_bpc=loss_to_bpc(run.summary["test_loss"]),
        valid_bpc=loss_to_bpc(run.summary["valid_loss"]),
        train_bpc=loss_to_bpc(run.summary["train_loss"]),
    )
    return row

# One row per usable run; run_stats returns None for runs that should be skipped.
# The records go straight to the DataFrame constructor as a list of dicts:
# DataFrame.from_dict is documented for dict input, not for an iterator of records.
df = pd.DataFrame([stats for stats in map(run_stats, runs) if stats])

Loaded 510 runs


In [20]:
# Collapse the runs that share a full configuration (including learning rate)
# into one row: number of runs, median validation/test bpc, and test-bpc spread.
# Non-RNN models are restricted to depth 8; RNN rows are kept at any depth.
# Named aggregation replaces a per-group lambda building a Series, which keeps
# `count` as an integer and avoids applying a function to the grouping columns.
dfm = (df
   .pipe(lambda d: d[(d.model == "rnn") | (d.depth == 8)])
   .groupby(["model", "norm", "dtype", "loss_scale", "variant", "lr"])
   .agg(
       count=("valid_bpc", "size"),
       valid_bpc=("valid_bpc", "median"),
       test_bpc=("test_bpc", "median"),
       test_bpc_std=("test_bpc", "std"),
   )
   .reset_index()
)

In [21]:
# For each configuration (ignoring learning rate) keep the single dfm row whose
# median validation bpc is lowest, i.e. select the learning rate on validation
# and report the test numbers of that row.
# groupby(...).idxmin() returns the row label of each group's minimum, so the
# winning rows are picked with one .loc lookup instead of a per-group lambda.
dfb = (dfm
 .loc[lambda d: d.groupby(["model", "norm", "dtype", "loss_scale", "variant"])["valid_bpc"].idxmin()]
 [["model", "norm", "dtype", "loss_scale", "variant", "lr", "test_bpc", "test_bpc_std"]]
 .reset_index(drop=True)
)
# Rough interval half-width: twice the mean test-bpc standard deviation over sqrt(3).
# NOTE(review): the 3 is presumably the number of runs per configuration — confirm.
ci = 2 * dfb["test_bpc_std"].mean() / np.sqrt(3)
print(f"95% CI: {ci:.3f}")
# Results table: rows are (model, norm), columns are (dtype, variant, loss_scale).
# Within each row, entries less than `ci` above the row minimum are shown in bold.
(
    dfb.pivot(
        index=["model", "norm"],
        columns=["dtype", "variant", "loss_scale"],
        values="test_bpc",
    )
    .style.format("{:.3f}")
    .apply(
        lambda row: np.where(row < row.min() + ci, "font-weight: bold", ""),
        axis=1,
    )
)

95% CI: 0.010


Unnamed: 0_level_0,dtype,float16,float16,float16,float32,float32
Unnamed: 0_level_1,variant,regular,unit,regular,regular,unit
Unnamed: 0_level_2,loss_scale,1,1,2048,1,1
model,norm,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3
attention,post,1.66,1.54,1.555,1.548,1.54
attention,pre,1.661,1.567,1.587,1.582,1.562
conv,pre,1.654,1.622,1.627,1.625,1.62
rnn,pre,1.697,1.673,1.682,1.674,1.677
