In [61]:
from pathlib import Path
from enum import Enum


class STAGES(str, Enum):
    RIGID = "Rigid"
    AFFINE = "Affine"
    SYN = "SyN"


class EXPERIMENTS(str, Enum):
    AMP = "amp"
    FP64 = "fp64"


N_LEVELS = 4


BOKEH_COLORBLIND = (
    "#0072B2",  # 0
    "#E69F00",  # 1
    "#F0E442",  # 2
    "#009E73",  # 3
    "#56B4E9",  # 4
    "#D55E00",  # 5
    "#CC79A7",  # 6
    "#000000",  # 7
)

# MODIFY IF NEEDED
FIG_DIR = Path("paper", "figures")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Data ingest

In [62]:
from collections import defaultdict
import re
from typing import Any, Callable

import pandas as pd


def cleanup_level(levels: list[str]) -> list[str]:
    rv: list[str] = []
    for level in levels:
        clean_level = level.strip("\n").strip()
        if not (
            clean_level.startswith("2DIAGNOSTIC")
            or clean_level.startswith("1DIAGNOSTIC")
        ):
            continue
        rv.append(clean_level)
    return rv


def data2table_amp(data: str) -> pd.DataFrame:
    table: defaultdict[str, Any] = defaultdict(list)
    for line in data.splitlines():
        if line in ["", "XX"]:
            continue
        cols = line.split(",")
        table["iteration"].append(int(cols[1]))
        table["metricValue"].append(float(cols[2]))
        table["vprec_precision"].append(int(cols[6]))
        table["pmin_estimate"].append(float(cols[7]))
        table["pmin_estimate_avg"].append(float(cols[8]))

    return pd.DataFrame(table | {"exp_id": EXPERIMENTS.AMP.value})


def data2table_fp64(data: str) -> pd.DataFrame:
    table: defaultdict[str, Any] = defaultdict(list)
    for line in data.splitlines():
        if line in ["", "XX"]:
            continue
        cols = line.split(",")
        table["iteration"].append(int(cols[1]))
        table["metricValue"].append(float(cols[2]))

    return pd.DataFrame(table | {"exp_id": EXPERIMENTS.FP64.value})


def parse_log(
    txt: str, *, data2table_fn: Callable[[str], pd.DataFrame]
) -> pd.DataFrame:
    P_LEVEL_HEADER = r"DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST.*|  Elapsed time"
    levels = re.split(P_LEVEL_HEADER, txt)
    raw_data = cleanup_level(levels)

    parsed_levels: list[pd.DataFrame] = []
    for i, stage in enumerate(STAGES):
        for level in range(N_LEVELS):
            level_data = raw_data[level + i * N_LEVELS]
            table = data2table_fn(level_data)
            table["stage"] = stage.value
            table["level"] = level
            parsed_levels.append(table)

    return pd.concat(parsed_levels)


def get_data(
    log_dir: Path, *, data2table_fn: Callable[[str], pd.DataFrame]
) -> pd.DataFrame:
    data: list[pd.DataFrame] = list()

    for log in log_dir.glob("*.log"):
        subject_df = parse_log(
            log.read_text(),
            data2table_fn=data2table_fn,
        )
        subject_df["subject_id"] = log.stem
        data.append(subject_df)

    return pd.concat(data).reset_index(drop=True)


amp_df = get_data(Path("logs", EXPERIMENTS.AMP.value), data2table_fn=data2table_amp)
fp64_df = get_data(Path("logs", EXPERIMENTS.FP64.value), data2table_fn=data2table_fp64)

all_df: pd.DataFrame = pd.concat([amp_df, fp64_df]).reset_index(drop=True)
to_drop = (
    (all_df["stage"] == STAGES.SYN.value)
    & (all_df["level"] == 0)
    & (all_df["iteration"] == 1)
)
all_df = all_df.loc[~to_drop].reset_index(drop=True)
all_df

Unnamed: 0,iteration,metricValue,vprec_precision,pmin_estimate,pmin_estimate_avg,exp_id,stage,level,subject_id
0,1,-0.567814,8.0,2.982600,2.982600,amp,Rigid,0,0003059
1,2,-0.568658,8.0,3.617058,3.299829,amp,Rigid,0,0003059
2,3,-0.570341,8.0,3.949972,3.516543,amp,Rigid,0,0003059
3,4,-0.572271,8.0,4.260256,3.702472,amp,Rigid,0,0003059
4,5,-0.575280,8.0,3.620636,3.686104,amp,Rigid,0,0003059
...,...,...,...,...,...,...,...,...,...
42467,16,-0.896145,,,,fp64,SyN,3,0003040
42468,17,-0.896713,,,,fp64,SyN,3,0003040
42469,18,-0.897232,,,,fp64,SyN,3,0003040
42470,19,-0.897713,,,,fp64,SyN,3,0003040


In [63]:
# TMP fix due to bug in ITK implementation of our method
all_df["vprec_precision"] = all_df["vprec_precision"].replace(
    {8: 7, 11: 10, 24: 16, 32: 23, 64: 53}
)
all_df["vprec_precision"].unique()

array([ 7., 10., 16., 23., 53., nan])

In [64]:
import plotly.graph_objects as go
from plotly.colors import hex_to_rgb
from plotly.subplots import make_subplots


def hex_to_rgba(hex_color: str, alpha: float) -> str:
    r, g, b = hex_to_rgb(hex_color)
    return f"rgba({r}, {g}, {b}, {alpha})"


def adjust_layout(
    fig: go.Figure, title: str, width: int | None = None, height: int | None = None
) -> None:
    fig.update_layout(
        height=height or (len(STAGES) * 400),
        width=width or 1200,
        font=dict(
            size=24,
            color="black",
        ),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.025,  # position above the subplots
            xanchor="left",
            x=0,
            traceorder="normal",
        ),
        title=title,
        margin=dict(l=10, r=50, t=150, b=10),
        plot_bgcolor="white",
    )
    for i in range(1, len(STAGES) * N_LEVELS + 1):
        fig.update_layout(
            {
                f"xaxis{i}": dict(showline=True, linewidth=2, linecolor="black"),
                f"yaxis{i}": dict(showline=True, linewidth=2, linecolor="black"),
            }
        )
    fig.update_yaxes(rangemode="tozero")


def get_base_fig(title: str) -> go.Figure:
    fig = make_subplots(
        rows=len(STAGES),
        cols=N_LEVELS,
        subplot_titles=[
            f"{stage.value}: Level {level}"
            for _, stage in enumerate(STAGES)
            for level in range(N_LEVELS)
        ],
        shared_xaxes=False,
        shared_yaxes=True,
        horizontal_spacing=0.025,
        vertical_spacing=0.075,
    )
    # Update subplot titles font size
    for annotation in fig.layout.annotations:
        annotation.font = dict(size=18)

    adjust_layout(fig, title=title)

    return fig


def show_and_save(fig: go.Figure, out_file: Path | None, show: bool):
    if show:
        fig.show()
    if out_file is not None:
        fig.update_layout(
            title=None,
            margin=dict(t=10),
        )
        out_file.parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(str(out_file), scale=2)


# Precision requirement per iteration


In [65]:
HLINE_COLOR = "darkgray"


class COLOR(str, Enum):
    VPREC_PRECISION = BOKEH_COLORBLIND[0]
    PMIN_ESTIMATE = BOKEH_COLORBLIND[1]


def add_hline(fig: go.Figure, y: int, text: str, row: int) -> None:
    # Only show the horizontal line in the last column
    line = dict(dash="solid", width=1, color=HLINE_COLOR)
    fig.add_hline(
        y,
        row=row,
        line=line,
    )
    fig.add_hline(
        y,
        row=row,
        col=N_LEVELS,
        line=line,
        annotation_text=text,
        annotation_position="right",
        annotation_xref="paper",
        annotation_font_size=14,
        annotation_showarrow=False,
        annotation_font_color="black",
    )


def display_hardware_support(fig: go.Figure, hw_support: dict[str, int]):
    for row, stage in enumerate(STAGES, 1):
        for k, v in hw_support.items():
            if stage == STAGES.SYN and v >= 24:
                continue
            add_hline(fig, y=v, text=k, row=row)


def _trace_boudary(
    fig: go.Figure,
    *,
    x: pd.Series,
    upper: pd.Series,
    lower: pd.Series,
    mean: pd.Series,
    row: int,
    col: int,
    color: str,
    dash: str,
):
    # Add the Upper Bound (Max)
    fig.add_trace(
        go.Scatter(
            x=x,
            y=upper,
            line=dict(width=0),
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Add lower bound (Min)
    fig.add_trace(
        go.Scatter(
            x=x,
            y=lower,
            line=dict(width=0),
            fill="tonexty",  # This performs the shading
            fillcolor=hex_to_rgba(color, 0.2),  # Gray with 20% opacity
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Mean line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=mean,
            line=dict(
                color=color,
                width=3,
                dash=dash,
            ),
            showlegend=False,
        ),
        row=row,
        col=col,
    )


def trace_minmax(
    fig: go.Figure,
    stats: dict[str, pd.DataFrame],
    row: int,
    col: int,
    color: str,
    dash: str = "solid",
) -> None:
    _trace_boudary(
        fig,
        x=stats.index,
        upper=stats["max"],
        lower=stats["min"],
        mean=stats["mean"],
        row=row,
        col=col,
        color=color,
        dash=dash,
    )


def trace_std(
    fig: go.Figure,
    stats: dict[str, pd.DataFrame],
    row: int,
    col: int,
    color: str,
    dash: str = "solid",
) -> None:
    _trace_boudary(
        fig,
        x=stats.index,
        upper=stats["mean"] + stats["std"],
        lower=stats["mean"] - stats["std"],
        mean=stats["mean"],
        row=row,
        col=col,
        color=color,
        dash=dash,
    )


def make_legend(fig: go.Figure):
    fig.add_trace(
        go.Scatter(
            x=[1],
            y=[0],
            name="Virtual precision",
            line=dict(dash="solid", color=COLOR.VPREC_PRECISION.value),
            mode="lines",
            showlegend=True,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=[1],
            y=[0],
            name="Estimated pmin",
            line=dict(dash="dash", color=COLOR.PMIN_ESTIMATE.value),
            mode="lines",
            showlegend=True,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )
    # Only add this ONCE (e.g., at subplot [1,1] or outside the loop)
    fig.add_trace(
        go.Scatter(
            x=[1],
            y=[0],
            mode="lines",
            line=dict(dash="solid", color=HLINE_COLOR),
            name="Hardware support",
            showlegend=True,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )


def plot_pmin(
    df: pd.DataFrame,
    *,
    trace_fn,
    title: str,
    out_file: Path | None = None,
    show: bool = True,
) -> None:
    fig = get_base_fig(title)

    for i, stage in enumerate(STAGES):
        for level in range(N_LEVELS):
            stage_level_df = df[(df["stage"] == stage) & (df["level"] == level)]

            # Virtual precision
            stats = (
                stage_level_df.groupby("iteration")["vprec_precision"]
                .agg(["mean", "min", "max", "std"])
                .reset_index()
            )
            trace_fn(
                fig,
                stats,
                row=i + 1,
                col=level + 1,
                dash="solid",
                color=COLOR.VPREC_PRECISION.value,
            )
            # Estimated pmin
            stats = (
                stage_level_df.groupby("iteration")["pmin_estimate_avg"]
                .agg(["mean", "min", "max", "std"])
                .reset_index()
            )
            trace_fn(
                fig,
                stats,
                row=i + 1,
                col=level + 1,
                dash="dot",
                color=COLOR.PMIN_ESTIMATE.value,
            )

    make_legend(fig)
    hw_support = {"FP32": 24, "FP24": 16, "TF32": 11, "BF16": 8}
    hw_support = {"FP64": 53, "FP32": 24, "FP24": 16, "TF32": 11, "BF16": 8}
    display_hardware_support(fig, hw_support=hw_support)

    # x-axis title
    fig.add_annotation(
        text="nth Iteration",
        xref="paper",
        yref="paper",
        x=0.5,
        y=-0.075,
        showarrow=False,
        font=dict(size=24),
    )
    # y-axis title
    fig.add_annotation(
        text="Number of bits",
        xref="paper",
        yref="paper",
        x=-0.075,
        y=0.5,
        textangle=-90,
        showarrow=False,
        font=dict(size=24),
    )
    fig.update_layout(margin=dict(b=80, l=80))

    show_and_save(fig, out_file=out_file, show=show)

In [66]:
plot_data = all_df[all_df["exp_id"] == EXPERIMENTS.AMP.value]
plot_pmin(
    plot_data,
    trace_fn=trace_minmax,
    title="Pmin: MinMax bounds",
    out_file=FIG_DIR / "pmin_minmax_bounds.png",
)
plot_pmin(
    plot_data,
    trace_fn=trace_std,
    title="Pmin: Std bounds",
    out_file=FIG_DIR / "pmin_std_bounds.png",
)

# Relative difference (AMP vs. FP64)


In [67]:
df_last = all_df.sort_values("iteration", ascending=False).drop_duplicates(
    subset=["subject_id", "stage", "level", "exp_id"], keep="first"
)

df_pivot = df_last.pivot_table(
    index=["subject_id", "stage", "level"],
    columns="exp_id",
    values="metricValue",
).reset_index()

df_pivot["diff_fp64_amp"] = (df_pivot["amp"] - df_pivot["fp64"]) / df_pivot[
    "fp64"
].abs()
df_pivot


exp_id,subject_id,stage,level,amp,fp64,diff_fp64_amp
0,0003001,Affine,0,-0.871946,-0.871005,-1.080184e-03
1,0003001,Affine,1,-0.730982,-0.730976,-8.509114e-06
2,0003001,Affine,2,-0.620231,-0.620194,-5.956523e-05
3,0003001,Affine,3,-0.551107,-0.551397,5.243211e-04
4,0003001,Rigid,0,-0.665033,-0.665033,6.037296e-10
...,...,...,...,...,...,...
595,0003064,Rigid,3,-0.351596,-0.351596,0.000000e+00
596,0003064,SyN,0,-0.973531,-0.973531,4.917665e-09
597,0003064,SyN,1,-0.946442,-0.946442,-8.424920e-09
598,0003064,SyN,2,-0.932349,-0.932349,4.901598e-11


In [68]:
df_pivot[df_pivot["diff_fp64_amp"] < -0.001]


exp_id,subject_id,stage,level,amp,fp64,diff_fp64_amp
0,3001,Affine,0,-0.871946,-0.871005,-0.00108
88,3010,Rigid,0,-0.481422,-0.480912,-0.001061
242,3024,Affine,2,-0.633672,-0.632902,-0.001217
336,3038,Affine,0,-0.855163,-0.851151,-0.004713
346,3038,SyN,2,-0.932764,-0.931832,-0.001001
360,3040,Affine,0,-0.828309,-0.826714,-0.001929
363,3040,Affine,3,-0.548734,-0.547633,-0.002011
369,3040,SyN,1,-0.943624,-0.942394,-0.001306
372,3041,Affine,0,-0.878149,-0.876005,-0.002448
432,3049,Affine,0,-0.858337,-0.857361,-0.001139


In [69]:
THRESHOLDS = [1e-2, 5e-3, 1e-3, 1e-4]
THRESHOLDS = [5e-3, 1e-3, 1e-4]


def display_thresholds(fig: go.Figure, thresholds: list[float]):
    for row, stage in enumerate(STAGES, 1):
        for treshold in thresholds:
            if stage == STAGES.SYN and treshold >= 2e-3:
                continue
            if stage == STAGES.RIGID and treshold >= 4e-3:
                continue
            add_hline(fig, y=treshold, text=f"{treshold:.0e}", row=row)


def plot_rel_diff(
    df: pd.DataFrame,
    *,
    value_col: str,
    title: str,
    out_file: Path | None = None,
    show: bool = True,
) -> None:
    fig = go.Figure()
    adjust_layout(fig, title=title, width=1200, height=400)

    for i, stage in enumerate(STAGES):
        for level in range(N_LEVELS):
            stage_level_df = df[(df["stage"] == stage) & (df["level"] == level)]

            fig.add_trace(
                go.Box(
                    y=stage_level_df[value_col],
                    boxpoints="all",
                    pointpos=0,
                    marker=dict(
                        color=hex_to_rgba(BOKEH_COLORBLIND[i], 0.3),
                        line=dict(
                            width=1,
                            color=hex_to_rgba(BOKEH_COLORBLIND[i], 1),
                        ),
                        size=8,
                    ),
                    # line=dict(color="rgba(0,0,0,0)"),
                    # fillcolor="rgba(0,0,0,0)",
                    showlegend=False,
                    name=f"{stage}:{level}",
                ),
            )

    for threshold in THRESHOLDS:
        fig.add_hline(
            y=threshold,
            line=dict(dash="solid", width=1, color=HLINE_COLOR),
            annotation_text=f"{threshold:.0e}",
            annotation_position="right",
            annotation_font_size=14,
            annotation_showarrow=False,
            annotation_font_color="black",
        )

    # y-axis
    fig.update_yaxes(
        title="Rel. diff. in loss",
        title_font=dict(size=24),
    )
    fig.update_layout(
        yaxis=dict(
            tickformat=".0e"  # This is the D3 equivalent of "{x:.0e}"
        )
    )

    show_and_save(fig, out_file=out_file, show=show)

In [70]:
plot_rel_diff(
    df_pivot,
    value_col="diff_fp64_amp",
    title="Relative Difference in Metric Value (AMP vs. FP64)",
    out_file=FIG_DIR / "rel_diff_fp64_amp.png",
)

In [71]:
df_pivot[df_pivot["diff_fp64_amp"] < 0]

exp_id,subject_id,stage,level,amp,fp64,diff_fp64_amp
0,0003001,Affine,0,-0.871946,-0.871005,-1.080184e-03
1,0003001,Affine,1,-0.730982,-0.730976,-8.509114e-06
2,0003001,Affine,2,-0.620231,-0.620194,-5.956523e-05
5,0003001,Rigid,1,-0.528809,-0.528809,-1.001118e-09
6,0003001,Rigid,2,-0.456552,-0.456475,-1.682493e-04
...,...,...,...,...,...,...
588,0003064,Affine,0,-0.833238,-0.833238,-1.538576e-10
589,0003064,Affine,1,-0.712932,-0.712932,-3.146865e-09
590,0003064,Affine,2,-0.610534,-0.610534,-1.786468e-09
591,0003064,Affine,3,-0.543209,-0.543209,-1.027781e-09


In [72]:
df_pivot

exp_id,subject_id,stage,level,amp,fp64,diff_fp64_amp
0,0003001,Affine,0,-0.871946,-0.871005,-1.080184e-03
1,0003001,Affine,1,-0.730982,-0.730976,-8.509114e-06
2,0003001,Affine,2,-0.620231,-0.620194,-5.956523e-05
3,0003001,Affine,3,-0.551107,-0.551397,5.243211e-04
4,0003001,Rigid,0,-0.665033,-0.665033,6.037296e-10
...,...,...,...,...,...,...
595,0003064,Rigid,3,-0.351596,-0.351596,0.000000e+00
596,0003064,SyN,0,-0.973531,-0.973531,4.917665e-09
597,0003064,SyN,1,-0.946442,-0.946442,-8.424920e-09
598,0003064,SyN,2,-0.932349,-0.932349,4.901598e-11
