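# Plot functions

Helpers that parse the training logs under `logs/` and write loss and pmin figures (HTML and PDF) to `paper-miccai/figures`.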


In [1]:
# Static export through fig.write_image (used for the PDF figures) goes through
# kaleido; if it complains about a missing Chrome binary, run once:
#     import kaleido
#     kaleido.get_chrome_sync()

# Colour-blind-friendly palette, presumably identical (values and order) to
# Bokeh's `Colorblind8`. Index -> colour:
#   0 blue, 1 orange, 2 yellow, 3 bluish green,
#   4 sky blue, 5 vermillion, 6 reddish purple, 7 black
BOKEH_COLORBLIND = tuple(
    "#0072B2 #E69F00 #F0E442 #009E73 #56B4E9 #D55E00 #CC79A7 #000000".split()
)

In [2]:
def set_plot_style(fig):
    """Apply the shared paper style to a Plotly figure, in place.

    Both axes get a 1-px black frame line, mirrored onto the opposite side,
    with outward black ticks. The layout is fixed to 600x600 px with a size-20
    black font, a white plot background, a horizontal legend anchored just
    above the plotting area, and tight margins.

    Args:
        fig: Figure exposing ``update_xaxes``, ``update_yaxes`` and
            ``update_layout`` (a ``plotly.graph_objects.Figure``).

    Returns:
        None; ``fig`` is modified in place.
    """
    frame = {
        "showline": True,
        "linewidth": 1,
        "linecolor": "black",
        "mirror": True,
        "ticks": "outside",
        "tickcolor": "black",
    }
    fig.update_xaxes(**frame)
    fig.update_yaxes(**frame)

    layout = {
        "width": 600,
        "height": 600,
        "font": {"size": 20, "color": "black"},
        "plot_bgcolor": "white",
        # y > 1 places the horizontal legend above the plotting area.
        "legend": {
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "left",
            "x": 0,
        },
        "margin": {"l": 15, "r": 50, "t": 10, "b": 15},
    }
    fig.update_layout(**layout)

In [3]:
from enum import Enum
import re
from pathlib import Path
import plotly.graph_objects as go

# Destination of every exported figure. parents=False: the "paper-miccai"
# directory itself must already exist; only "figures" is created on demand.
figures = Path("paper-miccai") / "figures"
figures.mkdir(parents=False, exist_ok=True)


# Colour used for each training configuration in the loss plots; values are
# entries of the BOKEH_COLORBLIND palette (member order is preserved).
LOSS_COLOR = Enum(
    "LOSS_COLOR",
    [
        ("BASELINE", BOKEH_COLORBLIND[0]),
        ("DMP", BOKEH_COLORBLIND[1]),
        ("BF16", BOKEH_COLORBLIND[6]),
        ("AMP", BOKEH_COLORBLIND[3]),
    ],
)


def get_loss_values(log_file: Path) -> tuple[list[float], list[float], list[int]]:
    """Parse per-step and per-epoch loss values from a training log.

    A line containing ``"Step loss = <number>"`` contributes one step loss.
    A line containing ``"Loss: <number>"`` contributes one epoch loss. That
    epoch loss is positioned at the number of step losses read so far, so
    epoch points line up with the step-index axis.

    Numbers may be written in scientific notation (``1.2e-05``), and ``nan`` /
    ``inf`` are kept rather than skipped. A diverging step therefore still
    occupies its slot and does not shift the x positions of later epochs.

    Args:
        log_file: Path to the text log. Undecodable bytes are ignored.

    Returns:
        ``(step_losses, epoch_losses, epoch_x)`` where ``epoch_x[i]`` is the
        number of step losses seen when the ``i``-th epoch loss was logged.
    """
    STEP_LOSS = "Step loss ="
    EPOCH_LOSS = "Loss"

    # Signed decimal with optional exponent, or a non-finite value as printed by
    # Python/PyTorch. The previous pattern (\d*\.?\d+) stopped at the "e" of an
    # exponent and silently turned 1.2e-05 into 1.2.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?:nan|inf)\b)"
    step_pattern = re.compile(re.escape(STEP_LOSS) + r"\s*" + number)
    epoch_pattern = re.compile(re.escape(EPOCH_LOSS) + r":\s*" + number)

    step_loss_values: list[float] = []
    epoch_loss_values: list[float] = []
    epoch_xaxis: list[int] = []

    with log_file.open(encoding="utf-8", errors="ignore") as f:
        for line in f:
            step_match = step_pattern.search(line)
            if step_match:
                step_loss_values.append(float(step_match.group(1)))
            # Checked after the step match so an epoch logged on the same line
            # as a step is positioned after that step.
            epoch_match = epoch_pattern.search(line)
            if epoch_match:
                epoch_loss_values.append(float(epoch_match.group(1)))
                epoch_xaxis.append(len(step_loss_values))

    return step_loss_values, epoch_loss_values, epoch_xaxis


def plot_loss(
    data: list[dict[str, object]],
    *,
    fout_stem: Path,
    xaxis: list[int],
    avg_only: bool = False,
    yaxis_title: str = "Loss: MSE + 0.01 x Grad Loss",
) -> None:
    """Plot per-step loss clouds and per-epoch loss curves for several runs.

    Each entry of ``data`` describes one run:

    * ``"values"``: per-step losses, plotted at x = 0, 1, 2, ...
    * ``"avg"``: per-epoch losses (drawn as a line with markers)
    * ``"label"``: legend name
    * ``"color"``: colour for both traces of the run
    * ``"xaxis"`` (optional): x positions for this run's ``"avg"`` points. If
      absent, the shared ``xaxis`` argument is used.

    Args:
        data: One dict per run, as described above.
        fout_stem: Output path without suffix; ``.html`` and ``.pdf`` are
            written next to each other.
        xaxis: Default x positions of the per-epoch points (e.g. the third
            value returned by ``get_loss_values``).
        avg_only: If True, skip the faint per-step scatter and draw only the
            per-epoch curves.
        yaxis_title: Y-axis label. The default keeps the original label.

    Side effects:
        Writes ``<fout_stem>.html`` and ``<fout_stem>.pdf``, then displays the
        figure. PDF export goes through kaleido.
    """
    fig = go.Figure()

    for series in data:
        if not avg_only:
            # Raw per-step losses: tiny, nearly transparent markers so the
            # epoch curve stays readable. Sharing the legendgroup means the
            # curve's legend entry toggles this cloud as well.
            fig.add_trace(
                go.Scatter(
                    y=series["values"],
                    mode="markers",
                    marker=dict(size=3, color=series["color"], opacity=0.05),
                    legendgroup=series["label"],
                    showlegend=False,
                )
            )
        # Per-epoch losses; a run may supply its own x positions when its step
        # count per epoch differs from the shared axis.
        fig.add_trace(
            go.Scatter(
                x=series.get("xaxis", xaxis),
                y=series["avg"],
                mode="lines+markers",
                marker=dict(size=6, color=series["color"], opacity=0.5),
                name=series["label"],
                legendgroup=series["label"],
            )
        )

    fig.update_layout(
        xaxis_title="Batch Index",
        yaxis_title=yaxis_title,
    )
    set_plot_style(fig)

    fig.write_html(fout_stem.with_suffix(".html"))
    fig.write_image(fout_stem.with_suffix(".pdf"))
    fig.show()

In [4]:
from enum import Enum


# Colour assigned to each floating-point format in the pmin plots (member order
# is preserved: FP32, FP24, BF16, FP16).
# NOTE(review): BF16 (red) and FP16 (green) sit on adjacent thresholds; this
# pair is hard to tell apart for red-green colour-blind readers. Consider
# entries of BOKEH_COLORBLIND instead.
Color = Enum(
    "Color",
    [
        ("FP32", "blue"),
        ("FP24", "grey"),
        ("BF16", "red"),
        ("FP16", "green"),
    ],
)


def _read_pmin_values(pmin_type: str, log_file: Path) -> list[float]:
    """Return every number that follows ``pmin_type`` in ``log_file``, in file order."""
    # The marker is escaped so it matches literally even if it contains regex
    # metacharacters. The number must contain at least one digit, so a stray "."
    # (accepted by the old [\d.]+ pattern) is never handed to float().
    pattern = re.compile(re.escape(pmin_type) + r"\s+(\d+(?:\.\d*)?|\.\d+)")
    values: list[float] = []
    # Same decoding policy as get_loss_values: undecodable bytes are ignored.
    with log_file.open(encoding="utf-8", errors="ignore") as f:
        for line in f:
            match = pattern.search(line)
            if match:
                values.append(float(match.group(1)))
    return values


def _pmin_color(pmin: float) -> str:
    """Colour of the narrowest selectable format whose mantissa covers ``pmin`` bits."""
    if pmin <= 7:
        return Color.BF16.value  # bfloat16: 7 explicit mantissa bits
    if pmin <= 10:
        return Color.FP16.value  # IEEE half: 10 explicit mantissa bits
    # FP24 is not available (N/A), so anything wider falls back to FP32.
    return Color.FP32.value


def plot_pmin(pmin_type: str, /, *, log_file: Path, fout_stem: Path) -> None:
    """Scatter-plot the pmin estimates found in a training log.

    Every line containing ``pmin_type`` followed by whitespace and a number
    contributes one point; x is the order of appearance. The reference lines
    are the explicit mantissa widths of BF16 (7), FP16 (10), FP24 (16, marked
    N/A) and FP32 (23). pmin is therefore presumably a required-precision
    estimate in mantissa bits (TODO confirm against the logger). Each point is
    coloured by the narrowest of BF16 / FP16 whose width covers it, and FP32
    otherwise.

    Args:
        pmin_type: Literal marker preceding the value, e.g. ``"PMIN moving avg:"``.
        log_file: Text log to scan.
        fout_stem: Output path without suffix; ``.html`` and ``.pdf`` are written.

    Side effects:
        Writes the two files, displays the figure, and prints the number of
        points and (if any) their range.
    """
    pmin_values = _read_pmin_values(pmin_type, log_file)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            y=pmin_values,
            mode="markers",
            marker=dict(
                size=3,
                color=[_pmin_color(p) for p in pmin_values],
                opacity=0.3,
            ),
            name="pmin",
        )
    )

    # (explicit mantissa bits, colour, label). BF16 has 7 and FP16 has 10
    # mantissa bits, so these two must not be swapped.
    thresholds = [
        (7, Color.BF16.value, "BF16"),
        (10, Color.FP16.value, "FP16"),
        (16, Color.FP24.value, "FP24<br>(N/A)"),
        (23, Color.FP32.value, "FP32"),
    ]

    for y_val, color, label in thresholds:
        fig.add_hline(y=y_val, line_color=color, line_dash="dash")
        # Label placed just right of the plotting area (x in paper coordinates)
        # instead of add_hline's built-in annotation, for precise positioning.
        fig.add_annotation(
            y=y_val,
            text=label,
            xref="paper",
            x=1.01,
            xanchor="left",
            yanchor="middle",
            showarrow=False,
            font_size=14,
        )

    fig.update_layout(
        xaxis_title="Batch Index",
        yaxis_title="Estimated pmin",
    )
    set_plot_style(fig)

    fig.write_html(fout_stem.with_suffix(".html"))
    fig.write_image(fout_stem.with_suffix(".pdf"))
    fig.show()

    print(f"Total data points: {len(pmin_values)}")
    if pmin_values:
        print(f"pmin range: [{min(pmin_values):.2f}, {max(pmin_values):.2f}]")


# VoxelMorph


In [5]:
# VoxelMorph 2D: baseline vs. DMP loss curves, then the DMP pmin trace.
# The bf16 / amp runs of this experiment (voxelmorph_train_2d_bf16.log,
# voxelmorph_train_2d_amp.log) are intentionally left out of this figure, so
# their logs are not parsed here.
vm2d_logs = Path("logs", "voxelmorph")
y1, y1_avg, xaxis = get_loss_values(vm2d_logs / "voxelmorph_train_2d_default.log")
y2, y2_avg, _ = get_loss_values(vm2d_logs / "voxelmorph_train_2d_dmp.log")

# NOTE(review): the per-epoch x positions come from the baseline log only; this
# assumes the DMP run logs the same number of steps per epoch.
vm2d_series = [
    {"values": values, "avg": avg, "label": label, "color": color.value}
    for values, avg, label, color in (
        (y1, y1_avg, "baseline", LOSS_COLOR.BASELINE),
        (y2, y2_avg, "dmp", LOSS_COLOR.DMP),
    )
]
plot_loss(vm2d_series, fout_stem=figures / "vm2d_loss", xaxis=xaxis)

plot_pmin(
    "PMIN moving avg:",
    log_file=vm2d_logs / "voxelmorph_train_2d_dmp.log",
    fout_stem=figures / "vm2d_pmin_moving_avg",
)


Total data points: 9999
pmin range: [9.72, 16.29]


In [6]:
# VoxelMorph 3D: baseline vs. DMP loss curves, then the DMP pmin trace.
vm3d_logs = Path("logs", "voxelmorph")
y1, y1_avg, xaxis = get_loss_values(vm3d_logs / "voxelmorph_train_3d_default.log")
y2, y2_avg, _ = get_loss_values(vm3d_logs / "voxelmorph_train_3d_dmp.log")

# Epoch x positions are taken from the baseline run and shared by both curves.
vm3d_series = [
    {"values": values, "avg": avg, "label": label, "color": color.value}
    for values, avg, label, color in (
        (y1, y1_avg, "baseline", LOSS_COLOR.BASELINE),
        (y2, y2_avg, "dmp", LOSS_COLOR.DMP),
    )
]
plot_loss(vm3d_series, fout_stem=figures / "vm3d_loss", xaxis=xaxis)

plot_pmin(
    "PMIN moving avg:",
    log_file=vm3d_logs / "voxelmorph_train_3d_dmp.log",
    fout_stem=figures / "vm3d_pmin_moving_avg",
)

Total data points: 9999
pmin range: [9.61, 15.82]


# MNIST


In [7]:
# MNIST: all four configurations, per-epoch curves only (avg_only=True), then
# the DMP pmin trace.
# NOTE(review): plot_loss labels the y-axis "Loss: MSE + 0.01 x Grad Loss",
# which reads like the VoxelMorph objective; confirm it matches the MNIST loss.
mnist_logs = Path("logs", "mnist")
y1, y1_avg, xaxis = get_loss_values(mnist_logs / "mnist_train_default.log")
y2, y2_avg, _ = get_loss_values(mnist_logs / "mnist_train_dmp.log")
y3, y3_avg, _ = get_loss_values(mnist_logs / "mnist_train_bf16.log")
y4, y4_avg, _ = get_loss_values(mnist_logs / "mnist_train_amp.log")

mnist_series = [
    {"values": values, "avg": avg, "label": label, "color": color.value}
    for values, avg, label, color in (
        (y1, y1_avg, "baseline", LOSS_COLOR.BASELINE),
        (y2, y2_avg, "dmp", LOSS_COLOR.DMP),
        (y3, y3_avg, "bf16", LOSS_COLOR.BF16),
        (y4, y4_avg, "amp", LOSS_COLOR.AMP),
    )
]
plot_loss(
    mnist_series,
    fout_stem=figures / "mnist_loss",
    xaxis=xaxis,
    avg_only=True,
)

plot_pmin(
    "PMIN moving avg:",
    log_file=mnist_logs / "mnist_train_dmp.log",
    fout_stem=figures / "mnist_pmin_moving_avg",
)


Total data points: 13131
pmin range: [6.11, 22.53]
