In [None]:
import pandas as pd

In [None]:
# Forward timings are multiplied by this factor: the report cell labels forward
# tables in "ms" and every other pass in "s", so the raw CSV is presumably in
# seconds (TODO confirm against the benchmark script that wrote the CSV).
MS_PER_SECOND = 1000

# Load the raw benchmark results, give the columns report-friendly names and
# convert forward timings to milliseconds. Built as one chain (no inplace=True)
# so re-running this cell always yields the same frame.
df = (
    pd.read_csv("raw_data/benchmark_flash.csv")
    .rename(
        columns={
            "context_length": "Context",
            "fw": "forward",
            "bk": "backward",
            "fw_bk": "forward-backward",
            "d_model": "Dimension",
        }
    )
    .assign(forward=lambda frame: frame["forward"] * MS_PER_SECOND)
)

In [None]:
def prepare_table(combined, test_val, dtype_val,
                  implementations=("pytorch", "flash_torch_bwd", "flash_triton_bwd")):
    """Build a Dimension x Context table comparing attention implementations.

    Each cell is a string ``"a / b / c"`` holding the ``test_val`` timing of
    every implementation (in ``implementations`` order), formatted to one
    decimal place. A missing measurement is rendered as ``"OOM"``.

    Parameters
    ----------
    combined : pandas.DataFrame
        Long-format benchmark results with ``dtype``, ``attention``,
        ``Dimension``, ``Context`` and ``test_val`` columns.
    test_val : str
        Timing column to report (e.g. ``"forward"``).
    dtype_val : str
        Value of the ``dtype`` column to keep (e.g. ``"float32"``).
    implementations : sequence of str, optional
        Values of the ``attention`` column to compare, in display order.

    Returns
    -------
    pandas.DataFrame
        Object-dtype frame indexed by ``Dimension`` with ``Context`` columns,
        covering every Dimension/Context pair seen by any implementation.

    Raises
    ------
    ValueError
        If ``implementations`` is empty.
    """
    if not implementations:
        raise ValueError("implementations must name at least one attention variant")

    subset = combined[combined["dtype"] == dtype_val]
    # pivot_table averages duplicate (Dimension, Context) rows (default aggfunc="mean").
    pivots = [
        subset[subset["attention"] == impl].pivot_table(
            index="Dimension", columns="Context", values=test_val, observed=False,
        )
        for impl in implementations
    ]

    # Align every pivot on the union of rows/columns. A cell that is missing
    # from one implementation then reads "OOM" instead of "nan".
    index, columns = pivots[0].index, pivots[0].columns
    for pivot in pivots[1:]:
        index = index.union(pivot.index)
        columns = columns.union(pivot.columns)
    aligned = [pivot.reindex(index=index, columns=columns) for pivot in pivots]

    def _fmt(value):
        # A NaN here means no measurement was recorded; this report labels it OOM.
        return f"{value:.1f}" if pd.notna(value) else "OOM"

    rows = [
        [" / ".join(_fmt(pivot.at[row, col]) for pivot in aligned) for col in columns]
        for row in index
    ]
    return pd.DataFrame(rows, index=index, columns=columns, dtype=object)

In [None]:
# One formatted comparison table per (pass, dtype) pair, keyed by that pair;
# passes vary slowest so the report prints all dtypes of a pass together.
tables = {
    (test, dtype): prepare_table(df, test, dtype)
    for test in ("forward", "backward", "forward-backward")
    for dtype in ("float32", "bfloat16")
}

In [None]:
def print_report(df, title, unit,
                 legend="naive pytorch / flash torch back / flash all triton",
                 position="H"):
    r"""Print a timing table as a captioned LaTeX ``table`` environment.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to render, e.g. a column slice of a ``prepare_table`` result.
    title : str
        Caption prefix, e.g. ``"forward float32"``.
    unit : str
        Time unit shown in the caption (e.g. ``"ms"`` or ``"s"``).
    legend : str, optional
        Explains the ``a / b / c`` cell layout. The default names the three
        implementations in the order ``prepare_table`` emits them.
    position : str or None, optional
        LaTeX float placement passed to ``to_latex``. ``"H"`` pins the table
        in place and needs ``\usepackage{float}`` in the document preamble.
    """
    tex = df.to_latex(
        index=True,
        caption=f"{title} time in {unit} ({legend})",
        escape=False,  # emit cell and caption text verbatim
        position=position,  # documented option; no string patching of the output
    )
    print(tex)

In [None]:
# Widest table (in Context columns) that fits on a page.
MAX_COLS_PER_TABLE = 4


def column_chunks(n_cols, max_cols=MAX_COLS_PER_TABLE):
    """Yield ``(start, stop)`` positional slices that together cover ``range(n_cols)``.

    Uses the fewest chunks of at most ``max_cols`` columns, with sizes
    balanced to differ by at most one (10 columns -> (0, 4), (4, 7), (7, 10)).
    Yields nothing when ``n_cols`` is 0.

    Raises
    ------
    ValueError
        If ``max_cols`` is smaller than 1.
    """
    if max_cols < 1:
        raise ValueError("max_cols must be at least 1")
    if n_cols <= 0:
        return
    n_chunks = -(-n_cols // max_cols)  # ceiling division
    base, extra = divmod(n_cols, n_chunks)
    start = 0
    for i in range(n_chunks):
        # The first `extra` chunks take one leftover column each.
        stop = start + base + (1 if i < extra else 0)
        yield start, stop
        start = stop


for (test, dtype), table in tables.items():
    # Forward timings were scaled by 1000 at load time; the other passes were not.
    unit = "ms" if test == "forward" else "s"
    # Chunk every column instead of assuming exactly 10 Context values.
    for start, stop in column_chunks(table.shape[1]):
        print_report(table.iloc[:, start:stop], f"{test} {dtype}", unit)