In [None]:
import math
import pandas as pd

In [None]:
# Load both benchmark result sets and trim them to one shared five-column schema.
# NOTE(review): going by the file names and the report caption, `fp` looks like the
# fp32 runs and `mp` the mixed-precision (bf16) runs — confirm.
fp = pd.read_csv("raw_data/benchmark_results.csv")
# Only this file is restricted to the rows recorded with 5 warm-up steps.
fp = fp.loc[fp["num_warmup_steps"] == 5]
fp = fp.rename(columns={"forward_mean": "forward", "backward_mean": "backward", "model.context_length": "context"})
fp = fp[["model", "context", "forward", "backward", "forward_only"]]

mp = pd.read_csv("raw_data/benchmark_results_mp.csv")
mp = mp.rename(columns={"forward_mean": "forward", "backward_mean": "backward", "model.context_length": "context"})
mp = mp[["model", "context", "forward", "backward", "forward_only"]]

In [None]:
# Model sizes in display order; later cells use this list as the ordered
# categories of the "model" column.
MODEL_ORDER = ["small", "medium", "large", "xl", "2.7B"]


def _format_pair(first, second):
    """Render two timings as the string ``"first / second"``.

    Each present value is multiplied by 1000 and printed with one decimal
    place; a missing value (NaN/None) is printed as ``OOM``.
    """
    left = f"{first * 1000:.1f}" if pd.notna(first) else "OOM"
    right = f"{second * 1000:.1f}" if pd.notna(second) else "OOM"
    return f"{left} / {right}"


def combine(df1, df2, col, required_models=("2.7B",)):
    """Build a model x context table of ``"df2 / df1"`` timing strings.

    Both inputs are long-format frames with ``model`` and ``context`` columns
    plus the value column ``col``. Each one is pivoted to a model-by-context
    grid (duplicate measurements are averaged, ``pivot_table``'s default), and
    the two grids are merged cell by cell.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Measurements shown on the right-hand side of every cell.
    df2 : pandas.DataFrame
        Measurements shown on the left-hand side of every cell.
    col : str
        Name of the column holding the timing to report. Values are scaled
        by 1000 for display (presumably seconds to milliseconds, going by the
        report caption used elsewhere in the notebook).
    required_models : iterable of str, optional
        Row labels that must appear in the result. Any label missing from
        both inputs gets a row filled with ``"OOM / OOM"``. The default,
        ``("2.7B",)``, reproduces the previously hard-coded behaviour.

    Returns
    -------
    pandas.DataFrame
        Index ``model`` (sorted), one column per ``context`` value, string
        cells. A combination missing on one side shows ``OOM`` for that side.
    """
    # Guard against a bare string being iterated character by character.
    if isinstance(required_models, str):
        required_models = (required_models,)

    table1 = df1.pivot_table(index="model", columns="context", values=col, observed=False)
    table2 = df2.pivot_table(index="model", columns="context", values=col, observed=False)

    # DataFrame.combine aligns both grids and hands over matching columns;
    # Series.combine then formats each aligned pair of scalars.
    report = table2.combine(
        table1,
        lambda col2, col1: col2.combine(col1, _format_pair),
    )

    # pivot_table drops rows that are entirely NaN, so a model that never
    # produced a measurement would otherwise vanish from the report.
    for model in required_models:
        if model not in report.index:
            report.loc[model] = "OOM / OOM"

    return report.sort_index()


In [None]:
# Make "model" an ordered categorical so report rows sort by MODEL_ORDER
# rather than alphabetically.
fp["model"] = fp["model"].astype(pd.CategoricalDtype(categories=MODEL_ORDER, ordered=True))
mp["model"] = mp["model"].astype(pd.CategoricalDtype(categories=MODEL_ORDER, ordered=True))

# Forward timings from the forward-only runs.
forward_infer = combine(
    df1=fp.loc[fp["forward_only"]],
    df2=mp.loc[mp["forward_only"]],
    col="forward",
)
# Forward and backward timings from the runs that are not forward-only.
forward_train = combine(
    df1=fp.loc[~fp["forward_only"]],
    df2=mp.loc[~mp["forward_only"]],
    col="forward",
)
backward = combine(
    df1=fp.loc[~fp["forward_only"]],
    df2=mp.loc[~mp["forward_only"]],
    col="backward",
)

In [None]:
# Forward timings of the forward-only runs: rows are models, columns are context
# lengths, and each cell reads "mp / fp" (OOM where a measurement is missing).
forward_infer

In [None]:
# Forward timings of the runs that are not forward-only: rows are models, columns
# are context lengths, and each cell reads "mp / fp" (OOM where missing).
forward_train

In [None]:
# Backward timings of the runs that are not forward-only: rows are models, columns
# are context lengths, and each cell reads "mp / fp" (OOM where missing).
backward

In [None]:
def print_report(df, title):
    """Print ``df`` as a LaTeX table whose float placement is forced to ``[H]``.

    Parameters
    ----------
    df : object with a ``to_latex`` method (a pandas DataFrame in this notebook)
        Table to render; the index is included and cell text is not escaped.
    title : str
        Leading part of the caption; ``" (bf16 / fp32 ms)"`` is appended.

    Returns
    -------
    None
        The LaTeX source is written to standard output.
    """
    caption = f"{title} (bf16 / fp32 ms)"
    latex = df.to_latex(index=True, caption=caption, escape=False)
    # Add the [H] placement specifier to the table environment opener.
    print(latex.replace(r"\begin{table}", r"\begin{table}[H]"))

In [None]:
# Emit the LaTeX source for each of the three timing tables, in order.
print_report(df=forward_infer, title="Forward Inference")
print_report(df=forward_train, title="Forward Training")
print_report(df=backward, title="Backward")