In [1]:
from pathlib import Path
from enum import StrEnum


class STAGES(StrEnum):
    """Registration stages; the values are used as stage labels in tables and subplot titles.

    Member order matters: parse_log iterates this enum to assign consecutive
    groups of N_LEVELS log chunks to stages, and the plots use it for row order.
    """
    RIGID = "Rigid"
    AFFINE = "Affine"
    SYN = "SyN"


class EXPERIMENTS(StrEnum):
    """Experiment identifiers.

    The values double as the sub-directory name under ``logs`` that holds the
    experiment's log files and as the ``exp_id`` column value in parsed tables.
    """
    AMP = "amp"
    FP64 = "fp64"


# Number of levels per stage: parse_log reads this many chunks for each stage
# and the figures use one subplot column per level.
N_LEVELS = 4


# Eight hex colours; the trailing comments give each entry's index.
# NOTE(review): presumably mirrors Bokeh's "Colorblind" palette — confirm.
BOKEH_COLORBLIND = (
    "#0072B2",  # 0
    "#E69F00",  # 1
    "#F0E442",  # 2
    "#009E73",  # 3
    "#56B4E9",  # 4
    "#D55E00",  # 5
    "#CC79A7",  # 6
    "#000000",  # 7
)

# MODIFY IF NEEDED
# Directory that receives the saved figures; created (with parents) when this
# cell runs, and left untouched if it already exists.
FIG_DIR = Path("paper", "figures")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Data ingest

In [2]:
from collections import defaultdict
import re
from typing import Any, Callable

import pandas as pd


def cleanup_level(levels: list[str]) -> list[str]:
    """Keep only the chunks that hold per-iteration diagnostic rows.

    Each chunk is stripped of surrounding newlines and whitespace; a chunk is
    kept (in its stripped form) only if it then starts with ``2DIAGNOSTIC`` or
    ``1DIAGNOSTIC``. Input order is preserved.

    Args:
        levels: Raw text chunks, e.g. the pieces produced by splitting a log.

    Returns:
        The stripped chunks that start with one of the diagnostic prefixes.
    """
    wanted_prefixes = ("2DIAGNOSTIC", "1DIAGNOSTIC")
    stripped_chunks = (chunk.strip("\n").strip() for chunk in levels)
    return [chunk for chunk in stripped_chunks if chunk.startswith(wanted_prefixes)]


def data2table_amp(data: str) -> pd.DataFrame:
    """Turn the diagnostic rows of one AMP level into a DataFrame.

    Every line of ``data`` other than the empty string and the literal ``XX``
    is split on commas and the fields listed below are converted and collected.

    Args:
        data: Text of one level, one comma-separated record per line.

    Returns:
        A frame with columns ``iteration``, ``metricValue``,
        ``vprec_precision``, ``pmin_estimate``, ``pmin_estimate_avg`` and a
        constant ``exp_id`` column set to ``EXPERIMENTS.AMP``.

    Raises:
        IndexError: If a record has fewer fields than the positions read.
        ValueError: If a field cannot be converted to its numeric type.
    """
    # (output column, position in the comma-separated record, converter)
    field_specs = (
        ("iteration", 1, int),
        ("metricValue", 2, float),
        ("vprec_precision", 6, int),
        ("pmin_estimate", 7, float),
        ("pmin_estimate_avg", 8, float),
    )
    columns: defaultdict[str, Any] = defaultdict(list)
    for record in data.splitlines():
        if record not in ("", "XX"):
            fields = record.split(",")
            for column_name, position, convert in field_specs:
                columns[column_name].append(convert(fields[position]))

    return pd.DataFrame(columns | {"exp_id": EXPERIMENTS.AMP})


def data2table_fp64(data: str) -> pd.DataFrame:
    """Turn the diagnostic rows of one FP64 level into a DataFrame.

    Every line of ``data`` other than the empty string and the literal ``XX``
    is split on commas; field 1 is read as the iteration number and field 2 as
    the metric value.

    Args:
        data: Text of one level, one comma-separated record per line.

    Returns:
        A frame with columns ``iteration`` and ``metricValue`` plus a constant
        ``exp_id`` column set to ``EXPERIMENTS.FP64``.

    Raises:
        IndexError: If a record has fewer than three fields.
        ValueError: If a field cannot be converted to its numeric type.
    """
    # (output column, position in the comma-separated record, converter)
    field_specs = (
        ("iteration", 1, int),
        ("metricValue", 2, float),
    )
    columns: defaultdict[str, Any] = defaultdict(list)
    for record in data.splitlines():
        if record not in ("", "XX"):
            fields = record.split(",")
            for column_name, position, convert in field_specs:
                columns[column_name].append(convert(fields[position]))

    return pd.DataFrame(columns | {"exp_id": EXPERIMENTS.FP64})


def parse_log(
    txt: str, *, data2table_fn: Callable[[str], pd.DataFrame]
) -> pd.DataFrame:
    """Parse one log into a long-format table with one row per iteration.

    The text is split on the per-level CSV header line (or the
    ``  Elapsed time`` marker) and only chunks starting with a diagnostic
    prefix are kept (see ``cleanup_level``). The kept chunks are assumed to be
    ordered stage by stage in ``STAGES`` order, with ``N_LEVELS`` consecutive
    chunks per stage.

    Args:
        txt: Full content of one log file.
        data2table_fn: Converts the text of one level into a DataFrame.

    Returns:
        The per-level tables concatenated, with ``stage`` and ``level`` columns
        added. Row indices restart at 0 for every level.

    Raises:
        IndexError: If the log holds fewer diagnostic chunks than
            ``len(STAGES) * N_LEVELS`` (e.g. a truncated run).
    """
    P_LEVEL_HEADER = r"DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST.*|  Elapsed time"
    levels = re.split(P_LEVEL_HEADER, txt)
    raw_data = cleanup_level(levels)

    # Fail with an actionable message instead of a bare "list index out of
    # range" raised from the middle of the loop below.
    expected_chunks = len(STAGES) * N_LEVELS
    if len(raw_data) < expected_chunks:
        raise IndexError(
            f"log contains {len(raw_data)} diagnostic level chunk(s); expected at "
            f"least {expected_chunks} ({len(STAGES)} stages x {N_LEVELS} levels)"
        )

    parsed_levels: list[pd.DataFrame] = []
    for i, stage in enumerate(STAGES):
        for level in range(N_LEVELS):
            # Chunks are laid out stage-major: all levels of stage 0, then stage 1, ...
            level_data = raw_data[level + i * N_LEVELS]
            table = data2table_fn(level_data)
            table["stage"] = stage
            table["level"] = level
            parsed_levels.append(table)

    return pd.concat(parsed_levels)


def get_data(
    log_dir: Path, *, data2table_fn: Callable[[str], pd.DataFrame]
) -> pd.DataFrame:
    """Parse every ``*.log`` file in a directory into one table.

    Files are processed in sorted path order so the row order of the result
    does not depend on the order in which the file system lists them.

    Args:
        log_dir: Directory that directly contains the ``*.log`` files.
        data2table_fn: Per-level parser forwarded to ``parse_log``.

    Returns:
        The per-file tables concatenated, with a ``subject_id`` column holding
        each file's stem.

    Raises:
        ValueError: If ``log_dir`` contains no ``*.log`` file.
    """
    log_files = sorted(log_dir.glob("*.log"))
    if not log_files:
        # pd.concat([]) would also raise ValueError, but without naming the
        # directory that turned out to be empty or missing.
        raise ValueError(f"No '*.log' files found in {log_dir}")

    data: list[pd.DataFrame] = []
    for log in log_files:
        subject_df = parse_log(
            log.read_text(),
            data2table_fn=data2table_fn,
        )
        subject_df["subject_id"] = log.stem
        data.append(subject_df)

    return pd.concat(data)


# Each experiment's logs live in logs/<experiment value>/ and need the parser
# matching that experiment's record layout.
amp_df = get_data(Path("logs", EXPERIMENTS.AMP), data2table_fn=data2table_amp)
fp64_df = get_data(Path("logs", EXPERIMENTS.FP64), data2table_fn=data2table_fp64)

# Columns produced only by the AMP parser (vprec_precision, pmin_estimate,
# pmin_estimate_avg) are NaN in the fp64 rows after concatenation.
all_df = pd.concat([amp_df, fp64_df])
all_df

Unnamed: 0,iteration,metricValue,vprec_precision,pmin_estimate,pmin_estimate_avg,exp_id,stage,level,subject_id
0,1,-0.567814,8.0,2.982600,2.982600,amp,Rigid,0,0003059
1,2,-0.568658,8.0,3.617058,3.299829,amp,Rigid,0,0003059
2,3,-0.570341,8.0,3.949972,3.516543,amp,Rigid,0,0003059
3,4,-0.572271,8.0,4.260256,3.702472,amp,Rigid,0,0003059
4,5,-0.575280,8.0,3.620636,3.686104,amp,Rigid,0,0003059
...,...,...,...,...,...,...,...,...,...
15,16,-0.896145,,,,fp64,SyN,3,0003040
16,17,-0.896713,,,,fp64,SyN,3,0003040
17,18,-0.897232,,,,fp64,SyN,3,0003040
18,19,-0.897713,,,,fp64,SyN,3,0003040


# Precision requirement per iteration


In [3]:
import plotly.graph_objects as go
from plotly.colors import hex_to_rgb
from plotly.subplots import make_subplots


HLINE_COLOR = "darkgray"


def hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Format a hex colour plus an alpha value as an ``rgba(r, g, b, a)`` string.

    Args:
        hex_color: Colour in the hex form accepted by ``hex_to_rgb``.
        alpha: Alpha value, inserted into the result as-is.

    Returns:
        The string ``"rgba(<r>, <g>, <b>, <alpha>)"``.
    """
    red, green, blue = hex_to_rgb(hex_color)
    return "rgba({}, {}, {}, {})".format(red, green, blue, alpha)


def add_hline(fig: go.Figure, y: int, text: str) -> None:
    """Draw a labelled horizontal reference line at ``y`` across the subplot rows.

    The line is drawn exactly once in every column of a row. Only the line in
    the last column carries the text annotation, placed at its right edge, so
    each row shows one label.

    Args:
        fig: Figure with one row per stage and ``N_LEVELS`` columns.
        y: Height of the line in y-axis units.
        text: Label shown next to the line in the last column.
    """
    # Lines at or above this height are not drawn on the SyN row.
    # NOTE(review): presumably to keep the SyN row's y-range tight — confirm.
    syn_y_cutoff = 24
    line = dict(dash="solid", width=1, color=HLINE_COLOR)

    for row, stage in enumerate(STAGES, 1):
        if stage == STAGES.SYN and y >= syn_y_cutoff:
            # `continue`, not `return`: skipping SyN must not depend on it
            # being the last member of STAGES.
            continue

        # Unlabelled line in every column but the last ...
        for col in range(1, N_LEVELS):
            fig.add_hline(
                y,
                row=row,
                col=col,
                line=line,
            )
        # ... and the labelled line in the last column. Drawing the last
        # column separately avoids stacking two identical shapes there.
        fig.add_hline(
            y,
            row=row,
            col=N_LEVELS,
            line=line,
            annotation_text=text,
            annotation_position="right",
            annotation_xref="paper",
            annotation_font_size=12,
            annotation_showarrow=False,
            annotation_font_color="black",
        )


def display_hardware_support(fig: go.Figure, hw_support: dict[str, int]):
    """Add one labelled horizontal line per entry of ``hw_support``.

    Args:
        fig: Figure the lines are added to (via ``add_hline``).
        hw_support: Maps a label to the y value at which its line is drawn.
    """
    for label in hw_support:
        add_hline(fig, y=hw_support[label], text=label)


def adjust_layout(fig: go.Figure, title: str) -> None:
    """Apply the shared figure styling: size, fonts, legend, axes and x-axis title.

    Args:
        fig: Figure to style in place.
        title: Figure title.
    """
    fig.update_layout(
        height=len(STAGES) * 400,
        width=1200,
        font=dict(
            size=24,
            color="black",
        ),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.025,  # position above the subplots
            xanchor="left",
            x=0,
            traceorder="normal",
        ),
        title=title,
        # The bottom margin leaves room for the shared x-axis title added below.
        margin=dict(l=10, r=40, t=150, b=80),
        plot_bgcolor="white",
    )

    # Style every axis that exists in the figure instead of enumerating
    # "xaxis1".."xaxisN" by hand, which would create phantom axes (or miss real
    # ones) whenever the subplot grid differs from len(STAGES) x N_LEVELS.
    axis_line = dict(showline=True, linewidth=2, linecolor="black")
    fig.update_xaxes(**axis_line)
    fig.update_yaxes(rangemode="tozero", **axis_line)

    # x-axis title, shared by all subplots and placed below the grid
    fig.add_annotation(
        text="nth Iteration",
        xref="paper",
        yref="paper",
        x=0.5,
        y=-0.075,
        showarrow=False,
        font=dict(size=24),
    )


def make_legend(fig: go.Figure):
    """Add the legend entries via single-point proxy traces in subplot (1, 1).

    Each proxy trace exists only to contribute a legend item (name, dash style
    and colour); hover is disabled on all of them.

    Args:
        fig: Figure that receives the proxy traces.
    """
    # (legend label, dash style, line colour), in legend order
    legend_entries = (
        ("fp64", "solid", "black"),
        ("amp", "dash", "black"),
        ("hardware support", "solid", HLINE_COLOR),
    )
    for label, dash_style, line_color in legend_entries:
        proxy = go.Scatter(
            x=[1],
            y=[0],
            name=label,
            mode="lines",
            line=dict(dash=dash_style, color=line_color),
            showlegend=True,
            hoverinfo="skip",
        )
        fig.add_trace(proxy, row=1, col=1)


def _trace_band(fig: go.Figure, x, upper, lower, center, row: int, col: int) -> None:
    """Add a shaded band between ``upper`` and ``lower`` plus a dotted centre line."""
    # Invisible upper edge; it only serves as the fill target for the next trace.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=upper,
            line=dict(width=0),
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Invisible lower edge; "tonexty" shades the area up to the previous trace.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=lower,
            line=dict(width=0),
            fill="tonexty",
            fillcolor="rgba(68, 68, 68, 0.2)",  # Gray with 20% opacity
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Centre line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=center,
            line=dict(
                color="black",
                width=3,
                dash="dot",
            ),
            showlegend=False,
        ),
        row=row,
        col=col,
    )


def trace_minmax(fig: go.Figure, stats: pd.DataFrame, row: int, col: int) -> None:
    """Plot the mean with a band spanning the minimum to the maximum.

    Args:
        fig: Figure that receives the traces.
        stats: Frame with ``mean``, ``min`` and ``max`` columns; its index is
            used for the x values.
        row: 1-based subplot row.
        col: 1-based subplot column.
    """
    _trace_band(
        fig,
        x=stats.index,
        upper=stats["max"],
        lower=stats["min"],
        center=stats["mean"],
        row=row,
        col=col,
    )


def trace_std(fig: go.Figure, stats: pd.DataFrame, row: int, col: int) -> None:
    """Plot the mean with a band of one standard deviation on either side.

    Args:
        fig: Figure that receives the traces.
        stats: Frame with ``mean`` and ``std`` columns; its index is used for
            the x values.
        row: 1-based subplot row.
        col: 1-based subplot column.
    """
    _trace_band(
        fig,
        x=stats.index,
        upper=stats["mean"] + stats["std"],
        lower=stats["mean"] - stats["std"],
        center=stats["mean"],
        row=row,
        col=col,
    )


def show_and_save(fig: go.Figure, out_file: Path | None, show: bool):
    """Optionally display the figure, then optionally write it to an image file.

    Args:
        fig: Figure to display and/or export.
        out_file: Destination image path; nothing is written when ``None``.
        show: Whether to call ``fig.show()``.
    """
    if show:
        fig.show()

    if out_file is None:
        return
    # scale=2 is passed through to the image export call.
    fig.write_image(str(out_file), scale=2)


def plot_pmin(
    df: pd.DataFrame,
    *,
    trace_fn,
    title: str,
    out_file: Path | None = None,
    show: bool = True,
) -> None:
    """Plot per-iteration statistics of ``pmin_estimate_avg`` on a stage x level grid.

    One subplot is drawn per (stage, level) pair: rows follow ``STAGES`` order
    and columns are levels ``0 .. N_LEVELS - 1``. Within each subplot the values
    are aggregated per iteration (mean, min, max, std) and handed to ``trace_fn``.

    Args:
        df: Table with ``stage``, ``level``, ``iteration`` and
            ``pmin_estimate_avg`` columns.
        trace_fn: Callable ``(fig, stats, row=..., col=...)`` that draws one
            subplot; ``stats`` is indexed by iteration number.
        title: Figure title.
        out_file: Where to save the figure; not saved when ``None``.
        show: Whether to display the figure.
    """
    fig = make_subplots(
        rows=len(STAGES),
        cols=N_LEVELS,
        subplot_titles=[
            f"{stage}: Level {level}"
            for stage in STAGES
            for level in range(N_LEVELS)
        ],
        shared_xaxes=False,
        shared_yaxes=True,
        horizontal_spacing=0.025,
        vertical_spacing=0.075,
    )
    # Update subplot titles font size
    for annotation in fig.layout.annotations:
        annotation.font = dict(size=18)

    for row, stage in enumerate(STAGES, 1):
        for level in range(N_LEVELS):
            stage_level_df = df[(df["stage"] == stage) & (df["level"] == level)]
            # Keep "iteration" as the index: the trace functions use stats.index
            # for x, so resetting it would plot positions 0..n-1 instead of the
            # real iteration numbers.
            stats = stage_level_df.groupby("iteration")["pmin_estimate_avg"].agg(
                ["mean", "min", "max", "std"]
            )
            trace_fn(fig, stats, row=row, col=level + 1)

    make_legend(fig)
    # Label -> y value of the reference line drawn for that format.
    hw_support = {"FP64": 53, "FP32": 24, "FP24": 16, "TF32": 11, "BF16": 8}
    display_hardware_support(fig, hw_support=hw_support)
    adjust_layout(fig, title=title)

    show_and_save(fig, out_file=out_file, show=show)

In [4]:
# The pmin columns are only populated for the AMP experiment.
plot_data = all_df[all_df["exp_id"] == EXPERIMENTS.AMP]

# Same figure twice, differing only in how the band around the mean is drawn.
for band_fn, band_title, band_file in (
    (trace_minmax, "Pmin: MinMax bounds", "pmin_minmax_bounds.png"),
    (trace_std, "Pmin: Std bounds", "pmin_std_bounds.png"),
):
    plot_pmin(
        plot_data,
        trace_fn=band_fn,
        title=band_title,
        out_file=FIG_DIR / band_file,
    )

# Relative difference (AMP vs. FP64)
