# Utils


In [177]:
import math
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import plotly.graph_objects as go


def _plot_metric(
    data: dict[int, dict[int, int]],
    metrics: dict[str, float],
    *,
    stage_metric_stats: dict[str, float],
    metric_func: Callable,
    std_delta_func: Callable,
    title: str,
    n_subjects: int,
    exp_metrics: Optional[dict[str, list[float]]] = None,
    stage_metrics: Optional[dict[str, list[float]]] = None,
    threshold_zones: bool = False,
    out_dir: Optional[Path] = None,
    save_name: Optional[str] = None,
    show: bool = True,
):
    """Draw a range x precision heatmap of a metric and optionally save it.

    Rows are the keys of ``data`` and columns are the union of its inner keys.
    The value of each cell is looked up in ``metrics`` under
    ``"r{row}-p{column}"`` and passed through ``metric_func``; cells without a
    value are drawn grey. The color scale is symmetric around the transformed
    ``"binary64"`` entry of ``stage_metric_stats``.

    Args:
        data: ``data[row][column]`` is a count that is compared with
            ``n_subjects`` to decide whether a warning is printed on the cell.
        metrics: Value per ``"r{row}-p{column}"`` key; also handed to
            ``metric_func`` as its second argument.
        stage_metric_stats: Mapping with a ``"binary64"`` entry used as the
            center of the color scale.
        metric_func: ``metric_func(values, reference_mapping)`` transformation
            applied to the grid, to the reference and to the pooled samples.
        std_delta_func: ``std_delta_func(x, y, exp_metrics)`` returning one
            spread value per cell; only called when ``exp_metrics`` is given.
        title: Figure title.
        n_subjects: Total count used in the per-cell warning.
        exp_metrics: When truthy, cell labels become ``value ± spread``.
        stage_metrics: Samples pooled to size the color scale. When missing or
            empty, the plotted grid itself is used instead.
        threshold_zones: Draw the hard-coded precision bands when True.
        out_dir: Output directory; used only together with ``save_name``.
        save_name: File stem for the ``.html`` and ``.png`` outputs.
        show: Display the figure when True.
    """
    # Axis values: union of the inner keys (columns) and the outer keys (rows).
    x = sorted({key for subdict in data.values() for key in subdict})
    y = sorted(data)

    # Populate the intensity grid; missing entries stay NaN.
    z = np.zeros((len(y), len(x)))
    for yi, y_key in enumerate(y):
        for xi, x_key in enumerate(x):
            z[yi, xi] = metrics.get(f"r{y_key}-p{x_key}", np.nan)

    # Transform the grid and the reference with the same function.
    delta = metric_func(z, metrics)
    fp64_value = metric_func(stage_metric_stats["binary64"], stage_metric_stats)

    # Pool every sample to size the color scale. Without any sample the range
    # is derived from the plotted grid, so omitting `stage_metrics` no longer
    # crashes.
    pooled = [
        value for values in (stage_metrics or {}).values() for value in values
    ]
    if pooled:
        all_delta = metric_func(np.array(pooled), stage_metric_stats)
    else:
        all_delta = delta
    max_abs_diff = np.nanmax(np.abs(all_delta - fp64_value))
    zmin = fp64_value - max_abs_diff
    zmax = fp64_value + max_abs_diff
    print(
        ",".join(
            [
                f"zmin: {zmin:.4f}",
                f"zmax: {zmax:.4f}",
                f"ref_value: {fp64_value:.4f}",
                f"max_abs_diff: {max_abs_diff:.4f}",
            ]
        )
    )

    # Cells without a value get a constant 1 in a separate layer so that they
    # can be painted with a flat grey colorscale.
    masked_failed = np.where(np.isnan(delta), 1, np.nan)

    if exp_metrics:
        std_delta = std_delta_func(x, y, exp_metrics)
        formatter = np.vectorize(lambda d, s: f"{d:.3f}<br>± {s:.3f}")
        formatted = formatter(delta, std_delta)
    else:
        formatter = np.vectorize(lambda d: f"{d:.3f}")
        formatted = formatter(delta)

    # NaN cells (and values above 1) get no label.
    text_labels = np.where(delta <= 1, formatted, "")

    # Diverging colorscale centered on the reference value.
    colorscale = [
        [0, "green"],
        [0.5, "white"],
        [1, "red"],
    ]

    normal_trace = go.Heatmap(
        z=delta,
        x=x,
        y=y,
        colorscale=colorscale,
        zmin=zmin,
        zmax=zmax,
        zmid=fp64_value,
        # colorbar=dict(title="Metric"),
        text=text_labels,
        texttemplate="%{text}",
    )
    failed_trace = go.Heatmap(
        z=masked_failed,
        x=x,
        y=y,
        colorscale=[[0, "grey"], [1, "grey"]],
        showscale=False,
    )
    fig = go.Figure(data=[failed_trace, normal_trace])

    # Warn on cells where some, but not all, of the subjects are counted.
    for i in y:
        for j in x:
            try:
                cell_value = data[i][j]
            except KeyError:
                continue

            if 0 < cell_value < n_subjects:
                fig.add_annotation(
                    x=j,
                    y=i - 0.4,  # 0.4 below the row coordinate
                    text=f"⚠ Err: {cell_value}/{n_subjects}",
                    showarrow=False,
                    font=dict(color="black"),
                    xanchor="center",
                    yanchor="bottom",  # Bottom of the text sits on the coordinate
                )

    # Annotate hardware data formats at fixed grid coordinates.
    ## bfloat16
    fig.add_annotation(
        x=7,
        y=8.25,
        text="bfloat16",
        showarrow=False,
        font=dict(color="black"),
        xanchor="center",
        yanchor="bottom",
    )
    ## float32
    fig.add_annotation(
        x=23,
        y=8.25,
        text="float32",
        showarrow=False,
        font=dict(color="black"),
        xanchor="center",
        yanchor="bottom",
    )

    # Threshold annotations over fixed precision bands.
    if threshold_zones:
        for threshold in [
            (6.5, 9.5),
            (9.5, 13.5),
            (13.5, 16.5),
            (16.5, 19.5),
            (19.5, 23.5),
        ]:
            threshold_exp = math.floor(math.log10(2) * math.floor(threshold[1])) - 1
            fig.add_vrect(
                x0=threshold[0],
                x1=threshold[1],
                line_width=3,
                label=dict(
                    # NOTE(review): label reads "1-e<n>"; confirm whether
                    # "1e-<n>" was intended.
                    text=f"Threshold: 1-e{threshold_exp}",
                    textposition="top center",
                    font=dict(size=20, family="Times New Roman"),
                ),
            )

    # Style & Layout
    fig.update_layout(
        title=title,
        xaxis=dict(title="Precision (# of bits)", tickvals=x, ticks="outside"),
        yaxis=dict(title="Range (# of bits)", tickvals=y, ticks="outside"),
        width=1600,
        height=400,
    )

    if show:
        fig.show()

    # Saving needs both a directory and a file stem.
    if out_dir and save_name:
        out_dir.mkdir(exist_ok=True, parents=True)
        fig.write_html(out_dir / f"{save_name}.html")
        fig.write_image(out_dir / f"{save_name}.png")

In [178]:
def plot_metric_value(
    data: dict[dict[int]],
    metrics: dict[float],
    *,
    stage_metric_stats: np.array,
    title: str,
    n_subjects: int,
    exp_metrics: dict[float] = None,
    stage_metrics: dict[float] = None,
    out_dir: Optional[Path] = None,
    save_name: Optional[str] = None,
    show: bool = True,
):
    """Plot the untransformed metric of every ``r{range}-p{precision}`` cell.

    Thin wrapper around ``_plot_metric``: the cell value is the metric itself
    and the per-cell spread is the NaN-aware standard deviation of the samples
    stored in ``exp_metrics`` under the cell key. All other arguments are
    forwarded unchanged.
    """

    def _identity(values, _reference):
        # The raw value is shown as-is; the reference mapping is not needed.
        return values

    def _sample_spread(x, y, exp_metrics):
        spread = np.zeros((len(y), len(x)))
        for row, col in np.ndindex(spread.shape):
            # A cell without samples contributes a lone NaN.
            samples = exp_metrics.get(f"r{y[row]}-p{x[col]}", np.nan)
            spread[row, col] = np.nanstd(np.asarray(samples))
        return spread

    forwarded = dict(
        stage_metric_stats=stage_metric_stats,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        stage_metrics=stage_metrics,
        out_dir=out_dir,
        save_name=save_name,
        show=show,
    )
    _plot_metric(
        data,
        metrics,
        metric_func=_identity,
        std_delta_func=_sample_spread,
        **forwarded,
    )

In [179]:
def plot_metric_diff(
    data: dict[dict[int]],
    metrics: dict[float],
    *,
    stage_metric_stats: np.array,
    title: str,
    n_subjects: int,
    exp_metrics: dict[float] = None,
    stage_metrics: dict[float] = None,
    out_dir: Optional[Path] = None,
    save_name: Optional[str] = None,
    show: bool = True,
):
    """Plot each cell's metric minus the ``"binary64"`` reference.

    Thin wrapper around ``_plot_metric``: the cell value is the difference to
    the ``"binary64"`` entry of the mapping it is given, and the per-cell
    spread is the NaN-aware standard deviation of the element-wise difference
    between the cell samples and the ``"binary64"`` samples of
    ``exp_metrics``. All other arguments are forwarded unchanged.
    """

    def _offset_from_reference(values, reference):
        return values - reference["binary64"]

    def _offset_spread(x, y, exp_metrics):
        spread = np.zeros((len(y), len(x)))
        for row, col in np.ndindex(spread.shape):
            # A cell without samples contributes a lone NaN.
            samples = np.asarray(exp_metrics.get(f"r{y[row]}-p{x[col]}", np.nan))
            baseline = np.asarray(exp_metrics["binary64"])
            spread[row, col] = np.nanstd(samples - baseline)
        return spread

    forwarded = dict(
        stage_metric_stats=stage_metric_stats,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        stage_metrics=stage_metrics,
        out_dir=out_dir,
        save_name=save_name,
        show=show,
    )
    _plot_metric(
        data,
        metrics,
        metric_func=_offset_from_reference,
        std_delta_func=_offset_spread,
        **forwarded,
    )

In [180]:
def plot_metric_rel(
    data: dict[dict[int]],
    metrics: dict[float],
    *,
    stage_metric_stats: np.array,
    title: str,
    n_subjects: int,
    exp_metrics: dict[float] = None,
    stage_metrics: dict[float] = None,
    out_dir: Optional[Path] = None,
    save_name: Optional[str] = None,
    show: bool = True,
):
    """Plot each cell's metric relative to the ``"binary64"`` reference.

    Thin wrapper around ``_plot_metric``: the cell value is
    ``(value - reference) / |reference|`` using the ``"binary64"`` entry of the
    mapping it is given. All other arguments are forwarded unchanged.

    NOTE(review): the cell value keeps its sign, while the per-cell spread is
    computed on the absolute relative difference; confirm this is intended.
    """

    def _relative_to_reference(values, reference):
        baseline = reference["binary64"]
        return (values - baseline) / np.abs(baseline)

    def _relative_spread(x, y, exp_metrics):
        spread = np.zeros((len(y), len(x)))
        for row, col in np.ndindex(spread.shape):
            # A cell without samples contributes a lone NaN.
            samples = np.asarray(exp_metrics.get(f"r{y[row]}-p{x[col]}", np.nan))
            distance = np.abs(samples - np.asarray(exp_metrics["binary64"]))
            spread[row, col] = np.nanstd(
                distance / np.abs(exp_metrics["binary64"])
            )
        return spread

    forwarded = dict(
        stage_metric_stats=stage_metric_stats,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        stage_metrics=stage_metrics,
        out_dir=out_dir,
        save_name=save_name,
        show=show,
    )
    _plot_metric(
        data,
        metrics,
        metric_func=_relative_to_reference,
        std_delta_func=_relative_spread,
        **forwarded,
    )

In [154]:
from pathlib import Path
from typing import Optional

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# The colorblind8 palette is spelled out here as eight hex colors so that
# bokeh does not have to be installed just to provide it.
Colorblind8 = (
    "#0072B2", "#E69F00", "#F0E442", "#009E73",
    "#56B4E9", "#D55E00", "#CC79A7", "#000000",
)


def _make_agg_figure(
    data,
    *,
    out_dir: Optional[Path] = None,
    y_value: str = None,
    y_log: bool = False,
    show: bool = False,
    data_formats: list[str] = None,
    exp_id: Optional[str] = None,
):
    """Plot ``y_value`` against iterations for every subject in a stage x level grid.

    One subplot is created per (stage, level) pair found in ``data``; each
    subject/data-format combination becomes one line. ``binary64`` is drawn as
    a dotted black line, every other format as a translucent solid line.

    Args:
        data: Table with ``level``, ``stage``, ``data_format``, ``subject``,
            ``iterations`` and ``y_value`` columns.
        out_dir: When given, ``<y_value>.html`` and ``<y_value>.png`` are
            written there (the directory is created if needed).
        y_value: Name of the column plotted on the y axis (required).
        y_log: Use a logarithmic y axis with a fixed ``[-7, 0]`` exponent range.
        show: Display the figure when True.
        data_formats: Formats to draw; defaults to every format in ``data``.
        exp_id: Optional experiment id prefixed to the title.

    Raises:
        ValueError: If ``y_value`` is not provided.
    """
    if y_value is None:
        raise ValueError("y_value must name the column of `data` to plot")

    # Unique values determine the grid size and the traces to draw.
    levels = data["level"].unique()
    stages = data["stage"].unique()
    subjects = data["subject"].unique()
    # Explicit emptiness test: `not array` is ambiguous for numpy arrays.
    if data_formats is None or len(data_formats) == 0:
        data_formats = data["data_format"].unique()

    # Subplot grid with independent x-axes and a shared y-axis per row.
    fig = make_subplots(
        rows=len(stages),
        cols=len(levels),
        shared_xaxes=False,
        shared_yaxes=True,
        horizontal_spacing=0.01,
        vertical_spacing=0.05,
    )

    # Line style per data format, computed once instead of once per trace.
    palette = px.colors.qualitative.Dark24
    styles = []
    for e, data_format in enumerate(data_formats):
        if data_format == "binary64":
            styles.append(("black", "dot"))
        else:
            rgb = px.colors.hex_to_rgb(palette[e % len(palette)])
            styles.append((f"rgba{rgb + (0.4,)}", "solid"))

    # Same trace order as before: level, then stage, then format, then subject.
    for l, level in enumerate(levels, 1):
        for s, stage in enumerate(stages, 1):
            panel = data[(data["level"] == level) & (data["stage"] == stage)]
            for (color, dash), data_format in zip(styles, data_formats):
                per_format = panel[panel["data_format"] == data_format]
                for i, subject in enumerate(subjects):
                    subset = per_format[per_format["subject"] == subject]

                    fig.add_trace(
                        go.Scatter(
                            x=subset["iterations"],
                            y=subset[y_value],
                            name=data_format,
                            mode="lines",
                            fillcolor=color,
                            line=dict(color=color, dash=dash),
                            legendgroup=data_format,
                            # One legend entry per format: first panel, first subject.
                            showlegend=(l == 1 and s == 1 and i == 0),
                            hovertemplate=f"Subject: {subject}<br>Iteration: %{{x}}<br>{y_value}: %{{y}}",
                        ),
                        row=s,
                        col=l,
                    )

    # Column headers: sigma counts down from len(levels) - 1 to 0.
    for i, sigma in enumerate(reversed(range(len(levels))), 1):
        fig.add_annotation(
            text=f"Sigma: {sigma}",
            xref="x domain",
            yref="y domain",
            x=0.5,
            y=1.05,
            xanchor="center",
            yanchor="bottom",
            row=1,
            col=i,
            showarrow=False,
            font=dict(size=14, color="black"),
        )

    # Row labels on the right edge; only rows that exist in the grid are
    # labelled, so data with fewer than three stages no longer fails.
    stage_labels = ["Rigid (MI)", "Affine (MI)", "SyN (CC)"]
    for i, stage in enumerate(stage_labels[: len(stages)], 1):
        fig.add_annotation(
            text=stage,
            xref="x domain",
            yref="y domain",
            x=1.05,
            y=0.4,
            xanchor="center",
            yanchor="bottom",
            textangle=90,
            row=i,
            col=len(levels),
            showarrow=False,
            font=dict(size=14, color="black"),
        )

    # The x-axis title goes on the last row, whatever the number of stages.
    fig.update_xaxes(title_text="Iterations", row=len(stages))
    fig.update_yaxes(title_text=y_value.replace("_", " ").capitalize(), col=1)
    if y_log:
        fig.update_yaxes(type="log", exponentformat="e", range=[-7, 0])
    title = f"[{exp_id}] All subjects" if exp_id else "All subjects"
    fig.update_layout(
        title=title,
        showlegend=True,
        legend_title="Experiment ID",
        width=1920,
        height=1080,
    )

    if show:
        fig.show()
    if out_dir:
        out_dir.mkdir(exist_ok=True, parents=True)
        fig.write_html(out_dir / f"{y_value}.html")
        fig.write_image(out_dir / f"{y_value}.png")

# Data processing


In [6]:
from pathlib import Path
import re

# When True, the failure-count cell also prints the path of each failed log
# (unless every log of the task failed).
DEBUG = True
# Number of subjects; also the upper bound of the array ids scanned when
# parsing the logs.
N_SUBJECT = 20

# Log file name layout:
#   antsRegistration-r<range>-p<precision>-<task_id>-<array_id>.out
# Each numeric field is exposed as a named group.
FILENAME_PATTERN = re.compile(
    r"antsRegistration-r(?P<range>\d+)-p(?P<precision>\d+)-(?P<task_id>\d+)-(?P<array_id>\d+)\.out"
)

# Root directory for generated figures; created as soon as this cell runs.
figures = Path("figures")
figures.mkdir(exist_ok=True, parents=True)

## Check failed execution


In [7]:
from typing import Iterable


def get_failed_exec(logs: Path | Iterable[Path]) -> int:
    if isinstance(logs, Path):
        logs = [logs]

    failed = list()
    for log in logs:
        if log.read_text().count("Elapsed time (stage 0):") != 4:
            failed.append(log)
    return failed

In [8]:
# Group logs by task_id
import re
from collections import defaultdict


# Every file under ./log (at any depth) named like a reduced-precision run.
logs = list(Path("log").rglob("antsRegistration-r*-p*-*.out"))

# task_id -> list of log paths belonging to that task.
# (The misspelled name is kept because later cells refer to it.)
experiements = defaultdict(list)
for log in logs:
    m = FILENAME_PATTERN.match(log.name)
    if m is None:
        # Report the offending path in the exception itself.
        raise ValueError(f"Invalid log file name: {log}")
    experiements[m.group("task_id")].append(log)

In [9]:
# Count failed executions for each task_id.
# n_failed[range][precision] holds the number of logs of that task that
# get_failed_exec flags as failed.
n_failed = defaultdict(dict)
for task_id, logs in experiements.items():
    failed = get_failed_exec(logs)

    # The range/precision pair is read from the first log of the task only;
    # presumably all logs of one task share it (TODO confirm).
    range_ = int(FILENAME_PATTERN.match(logs[0].name).group("range"))
    precision_ = int(FILENAME_PATTERN.match(logs[0].name).group("precision"))
    n_failed[range_][precision_] = len(failed)

    print(f"Task {task_id} (r{range_}-p{precision_}): {len(failed)} failed")
    # In debug mode, list the failed logs unless every log of the task failed.
    if DEBUG and len(failed) != len(logs):
        for log in failed:
            print(log, end="\n\n")

# Manually add failure for range 6: precisions 7 through 23 are all recorded
# with N_SUBJECT failures (the comprehension variable `r` is the precision).
n_failed[6] = {r: N_SUBJECT for r in range(7, 24)}

Task 42143358 (r8-p14): 0 failed
Task 42143377 (r8-p23): 0 failed
Task 42143349 (r7-p9): 0 failed
Task 42143346 (r8-p8): 0 failed
Task 42143366 (r7-p17): 0 failed
Task 42143351 (r7-p10): 0 failed
Task 42143344 (r8-p7): 0 failed
Task 42143362 (r7-p15): 0 failed
Task 42143342 (r8-p6): 20 failed
Task 42143374 (r7-p21): 0 failed
Task 42143368 (r7-p18): 0 failed
Task 42143354 (r8-p12): 0 failed
Task 42143372 (r7-p20): 0 failed
Task 42143378 (r7-p23): 0 failed
Task 42143353 (r7-p11): 0 failed
Task 42143369 (r8-p19): 0 failed
Task 42143356 (r8-p13): 0 failed
Task 42143361 (r8-p15): 0 failed
Task 42143375 (r8-p22): 0 failed
Task 42143371 (r8-p20): 0 failed
Task 42143350 (r8-p10): 0 failed
Task 42143364 (r7-p16): 0 failed
Task 42143357 (r7-p13): 0 failed
Task 42143359 (r7-p14): 0 failed
Task 42143343 (r7-p6): 20 failed
Task 42143348 (r8-p9): 0 failed
Task 42143345 (r7-p7): 0 failed
Task 42143347 (r7-p8): 0 failed
Task 42143365 (r8-p17): 0 failed
Task 42143363 (r8-p16): 0 failed
Task 42143367 (r

## QA using metric value


In [10]:
import pandas as pd


# Regex used to cut a log into chunks: it matches either the column header of
# a diagnostic table or the text "  Elapsed time".
p_lvl_header = r"DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST|  Elapsed time"

logs = Path("log")

# One DataFrame per (log file, stage, level); concatenated below.
dfs = list()
for array_id in range(1, N_SUBJECT + 1):
    for filename in logs.rglob(f"antsRegistration-*-{array_id}.out"):
        # The data format label is "binary64" for the reference runs,
        # otherwise the 2nd and 3rd dash-separated fields of the file name.
        if filename.name.startswith("antsRegistration-binary64"):
            data_format = "binary64"
        else:
            data_format = "-".join(filename.name.split("-")[1:3])

        txt = filename.read_text()

        # Extract subject_id from log
        # NOTE(review): re.search returns None when the log has no
        # "SUBJECT_ID: ..." line, which makes .group raise AttributeError.
        subject_id = re.search(r"SUBJECT_ID: (?P<subject_id>.*)", txt).group(
            "subject_id"
        )

        # Filter out the header from the stages: drop the text before the
        # first match and every chunk that begins with " (stage 0):".
        all_data = [
            x
            for x in re.split(p_lvl_header, txt)[1:]
            if not x.startswith(" (stage 0):")
        ]

        # 3 stages with 4 levels of resolution each
        for stage in range(1, 4):
            for i, level in enumerate(all_data[4 * (stage - 1) : 4 * stage]):
                # Column name -> list of raw string cells for this level.
                table = defaultdict(list)
                for row in re.split(r"\n", level.strip("XX").strip()):
                    # Skip invalid rows
                    # e.g. error messages or write volumes to disk
                    if not ("2DIAGNOSTIC" in row or "1DIAGNOSTIC" in row):
                        continue

                    # Raise exception if the row is not as expected
                    # (the split row is printed first to help debugging).
                    try:
                        cols = row.split(",")
                        table["iterations"].append(cols[1].strip())
                        table["metric"].append(cols[2].strip())
                        table["convergence_value"].append(cols[3].strip())
                        table["total_time"].append(cols[4].strip())
                        table["since_last"].append(cols[5].strip())
                    except Exception as e:
                        print(cols)
                        raise e

                # Scalar entries are repeated for every parsed row.
                dfs.append(
                    pd.DataFrame(
                        data={
                            "subject": subject_id,
                            "array_id": array_id,
                            "exp_id": filename.parent.name,
                            "data_format": data_format,
                            "stage": stage,
                            "level": i + 1,
                            "iterations": table["iterations"],
                            "metric": table["metric"],
                            "convergence_value": table["convergence_value"],
                            "total_time": table["total_time"],
                            "since_last": table["since_last"],
                        }
                    )
                )

df = pd.concat(dfs, ignore_index=True)
# Cells were collected as strings; convert them to numeric dtypes.
df = df.astype(
    {
        "stage": int,
        "level": int,
        "iterations": int,
        "metric": np.float64,
        "convergence_value": np.float64,
        "total_time": np.float64,
        "since_last": np.float64,
    }
)
# Average rows that share the same run/stage/level/iteration key.
# `since_last` is not carried over into `rv`.
rv = (
    df.groupby(
        ["subject", "array_id", "exp_id", "data_format", "stage", "level", "iterations"]
    )
    .agg({"metric": "mean", "convergence_value": "mean", "total_time": "mean"})
    .reset_index()
)

In [150]:
def get_stage_metric_stats(df, stage, level=4):
    """Summarize the final metric of one stage per data format.

    For every (exp_id, subject, data_format) run, the metric at the highest
    iteration of the requested stage and level is taken. The values are then
    grouped by data format, NaN entries are dropped, and each group is reduced.

    Args:
        df: Table with ``stage``, ``level``, ``exp_id``, ``subject``,
            ``data_format``, ``iterations`` and ``metric`` columns.
        stage: Stage to select.
        level: Level to select (defaults to 4, the value used so far).

    Returns:
        A dict with four entries, each keyed by data format:
        ``"worst"`` (maximum), ``"mean"``, ``"best"`` (minimum) and
        ``"raw_value"`` (the list of values). A format without any valid value
        gets NaN for the three reductions instead of raising.
    """
    selected = df[(df["stage"] == stage) & (df["level"] == level)]
    # Row label of the last iteration of every run.
    last_rows = selected.groupby(["exp_id", "subject", "data_format"])[
        "iterations"
    ].idxmax()
    final = selected.loc[
        last_rows, ["exp_id", "subject", "data_format", "metric"]
    ]
    # {data_format: {(exp_id, subject): metric}}; runs that lack a format are
    # filled with NaN by unstack.
    exp_metrics = (
        final.set_index(["exp_id", "subject", "data_format"])["metric"]
        .unstack()
        .to_dict()
    )

    # Set membership avoids a scan of the array for every value.
    subjects = set(df["subject"].unique())
    raw_values = {
        data_format: [
            metric
            for (_, subj_id), metric in per_run.items()
            if subj_id in subjects and not np.isnan(metric)
        ]
        for data_format, per_run in exp_metrics.items()
    }

    def _reduce(func):
        # np.nanmax / np.nanmin raise on an empty list, so a format whose
        # final metrics are all NaN is reported as NaN instead.
        return {
            data_format: (func(values) if values else np.nan)
            for data_format, values in raw_values.items()
        }

    return {
        "worst": _reduce(np.nanmax),
        "mean": _reduce(np.nanmean),
        "best": _reduce(np.nanmin),
        "raw_value": raw_values,
    }

# Figures


In [187]:
import warnings

# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

# (stage name, metric name) pairs; the 1-based position in this list is used
# as the stage id by the plotting loop.
STAGES = [("rigid", "MI"), ("affine", "MI"), ("syn", "CC")]
# Experiment ids to plot; commented-out entries are currently disabled.
EXP_IDS = [
    # "0001",
    # "0011",
    # "0100",
    "0110",
    "0111",
    # "1111",
]


def get_experiment(df, exp_id):
    """Return the rows of ``df`` for the given experiment id(s) plus ``"0000"``.

    ``exp_id`` may be a single id string or a list of ids. The ``"0000"``
    experiment is always added to the selection so that the binary64 data
    format stays available.
    """
    wanted = [exp_id] if isinstance(exp_id, str) else exp_id
    keep = df["exp_id"].isin(wanted + ["0000"])
    return df[keep]


# One pass per registration stage. The stats pooled over every experiment in
# EXP_IDS (plus "0000") are passed as `stage_metric_stats` / `stage_metrics`,
# so all experiments of a stage are plotted against the same pooled samples.
for stage_id, (stage_name, metric_name) in enumerate(STAGES, 1):
    stage_metric_stats = get_stage_metric_stats(get_experiment(rv, EXP_IDS), stage_id)

    for exp_id in EXP_IDS:
        # Figures of one experiment are written under figures/<exp_id>/.
        out_dir = figures / exp_id
        exp_data = get_experiment(rv, exp_id)
        n_subjects = exp_data["subject"].nunique()
        # Stats restricted to this experiment (plus "0000").
        exp_metric_stats = get_stage_metric_stats(exp_data, stage_id)

        """Metric value"""
        # plot_metric_value(
        #     n_failed,
        #     exp_metric_stats["worst"],
        #     stage_metric_stats=stage_metric_stats["worst"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Metric value - Worst subject",
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        #     out_dir=out_dir,
        # )
        # plot_metric_value(
        #     n_failed,
        #     exp_metric_stats["mean"],
        #     stage_metric_stats=stage_metric_stats["mean"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Metric value - Subjects mean",
        #     exp_metrics=exp_metric_stats["raw_value"],
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        #     out_dir=out_dir,
        # )
        # plot_metric_value(
        #     n_failed,
        #     exp_metric_stats["best"],
        #     stage_metric_stats=stage_metric_stats["best"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Metric value - Subjects best",
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        # )

        """Metric difference with binary64"""
        # plot_metric_diff(
        #     n_failed,
        #     exp_metric_stats["worst"],
        #     stage_metric_stats=stage_metric_stats["worst"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Metric difference with Binary64 - Worst subject",
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        #     out_dir=out_dir,
        # )
        # The only plot currently enabled: mean difference to binary64 with
        # per-cell spread, saved as <stage_name>-metric_diff-mean.{html,png}.
        plot_metric_diff(
            n_failed,
            exp_metric_stats["mean"],
            stage_metric_stats=stage_metric_stats["mean"],
            title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Metric difference with Binary64 - Subjects mean",
            exp_metrics=exp_metric_stats["raw_value"],
            stage_metrics=stage_metric_stats["raw_value"],
            n_subjects=n_subjects,
            out_dir=out_dir,
            save_name=f"{stage_name}-metric_diff-mean",
        )

        """Metric relative difference with binary64"""
        # plot_metric_rel(
        #     n_failed,
        #     exp_metric_stats["worst"],
        #     stage_metric_stats=stage_metric_stats["worst"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Relative difference with Binary64 - Worst subject",
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        #     out_dir=out_dir,
        # )
        # plot_metric_rel(
        #     n_failed,
        #     exp_metric_stats["mean"],
        #     stage_metric_stats=stage_metric_stats["mean"],
        #     title=f"[{exp_id}] {stage_name.capitalize()} ({metric_name}): Relative difference with Binary64 - Subjects mean",
        #     exp_metrics=exp_metric_stats["raw_value"],
        #     stage_metrics=stage_metric_stats["raw_value"],
        #     n_subjects=n_subjects,
        #     out_dir=out_dir,
        # )

    # # Plot metric value and convergence value through time
    # data_formats = [
    #     "r7-p8",
    #     "r8-p23",
    #     "binary64",
    # ]
    # data = rv
    # _make_agg_figure(
    #     data,
    #     y_value="metric",
    #     show=True,
    #     data_formats=data_formats,
    #     out_dir=out_dir,
    #     exp_id=exp_id,
    # )
    # _make_agg_figure(
    #     data,
    #     y_value="convergence_value",
    #     y_log=True,
    #     show=True,
    #     data_formats=data_formats,
    #     out_dir=out_dir,
    #     exp_id=exp_id,
    # )

zmin: -0.1051,zmax: 0.1051,ref_value: 0.0000,max_abs_diff: 0.1051


zmin: -0.1051,zmax: 0.1051,ref_value: 0.0000,max_abs_diff: 0.1051


zmin: -0.2047,zmax: 0.2047,ref_value: 0.0000,max_abs_diff: 0.2047


zmin: -0.2047,zmax: 0.2047,ref_value: 0.0000,max_abs_diff: 0.2047


zmin: -0.1122,zmax: 0.1122,ref_value: 0.0000,max_abs_diff: 0.1122


zmin: -0.1122,zmax: 0.1122,ref_value: 0.0000,max_abs_diff: 0.1122


## SyN w/ SSIM


In [10]:
from collections import defaultdict
from pathlib import Path

import nibabel as nib
import numpy as np
import skimage as ski
from tqdm import tqdm

# Every file named Warped.nii.gz, at any depth, under the antsRegistration
# derivatives folder (path is relative to the working directory).
moving_images = list(
    Path("dataset/ds004513/derivatives/flint/antsRegistration").rglob("Warped.nii.gz")
)

In [None]:
# Fixed image, loaded as a floating-point array from an absolute,
# user-specific templateflow cache path.
fixed = nib.load(
    "/home/mathdugre/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz"
).get_fdata()
# data_format -> list of SSIM values, one per warped image.
ssim = defaultdict(list)
for img in tqdm(moving_images):
    # The name of the grandparent directory is used as the data format label.
    data_format = img.parent.parent.name
    img2 = nib.load(img).get_fdata()
    # NOTE(review): data_range=1 presumes image intensities span a unit
    # range; confirm against the actual data.
    ssim[data_format].append(
        ski.metrics.structural_similarity(fixed, img2, data_range=1)
    )

In [30]:
# Negated NaN-aware mean of the SSIM values of each data format.
ssim_mean = {fmt: -np.nanmean(values) for fmt, values in ssim.items()}

In [None]:
# `ssim_mean` stores negated means, so the raw samples are negated as well
# before they are used to size the color scale.
ssim_samples = {fmt: [-value for value in values] for fmt, values in ssim.items()}

# `stage_metric_stats` is a required keyword of the plot helpers and
# `stage_metrics` is read unconditionally by `_plot_metric`; both were missing
# here. Assumes `ssim_mean` has a "binary64" entry (TODO confirm the directory
# naming of the warped images).
plot_metric_value(
    n_failed,
    ssim_mean,
    stage_metric_stats=ssim_mean,
    stage_metrics=ssim_samples,
    title="SyN (SSIM): Metric value - Subject mean",
    n_subjects=n_subjects,
)
plot_metric_diff(
    n_failed,
    ssim_mean,
    stage_metric_stats=ssim_mean,
    stage_metrics=ssim_samples,
    title="SyN (SSIM): Metric diff w/ Binary64 - Subject mean",
    n_subjects=n_subjects,
)

## SyN w/ MI


In [None]:
import itk

# Fixed image, read with double-precision pixels from an absolute,
# user-specific templateflow cache path.
fixed = itk.imread(
    "/home/mathdugre/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz",
    itk.D,
)

# NOTE(review): this section and the `mi` variable are labelled "MI", but the
# metric class instantiated here is the ANTS neighborhood *correlation*
# metric; confirm which metric was intended.
MetricType = itk.ANTSNeighborhoodCorrelationImageToImageMetricv4[
    type(fixed), type(fixed)
]

# data_format -> list of metric values, one per warped image.
mi = defaultdict(list)
for img in tqdm(moving_images):
    # The name of the grandparent directory is used as the data format label.
    data_format = img.parent.parent.name
    img2 = itk.imread(img, itk.D)

    # A fresh metric object is built for every image pair.
    metric = MetricType.New()
    metric.SetFixedImage(fixed)
    metric.SetMovingImage(img2)

    # Initialize the metric
    metric.Initialize()

    # Compute the similarity value
    mi[data_format].append(metric.GetValue())

In [24]:
# NaN-aware mean of the metric values of each data format (not negated).
mi_mean = {fmt: np.nanmean(values) for fmt, values in mi.items()}

In [None]:
subjects = rv["subject"].unique()
n_subjects = rv["subject"].nunique()

# `stage_metric_stats` is a required keyword of the plot helpers and
# `stage_metrics` is read unconditionally by `_plot_metric`; both were missing
# here. Assumes `mi_mean` has a "binary64" entry (TODO confirm the directory
# naming of the warped images).
plot_metric_value(
    n_failed,
    mi_mean,
    stage_metric_stats=mi_mean,
    stage_metrics=mi,
    title="SyN (MI): Metric value - Subject mean",
    n_subjects=n_subjects,
)
plot_metric_diff(
    n_failed,
    mi_mean,
    stage_metric_stats=mi_mean,
    stage_metrics=mi,
    title="SyN (MI): Metric diff w/ Binary64 - Subject mean",
    n_subjects=n_subjects,
)

## SyN MI (Python)


In [1]:
from collections import defaultdict
from pathlib import Path

import nibabel as nib
import numpy as np
import skimage as ski
from tqdm import tqdm

# Every file named Warped.nii.gz, at any depth, under the antsRegistration
# derivatives folder (path is relative to the working directory).
moving_images = list(
    Path("dataset/ds004513/derivatives/flint/antsRegistration").rglob("Warped.nii.gz")
)


def mutual_information(hgram):
    """Return the mutual information (natural log) of a 2-D joint histogram.

    ``hgram`` holds bin counts; it is normalized to a joint probability table,
    and only cells with non-zero probability contribute to the sum.
    """
    joint = hgram / float(np.sum(hgram))
    # Marginal distributions along each axis of the joint table.
    row_marginal = joint.sum(axis=1)
    col_marginal = joint.sum(axis=0)
    # Joint table expected if both variables were independent.
    independent = row_marginal[:, None] * col_marginal[None, :]
    occupied = joint > 0
    observed = joint[occupied]
    return np.sum(observed * np.log(observed / independent[occupied]))

In [None]:
# Fixed image, loaded as a floating-point array from an absolute,
# user-specific templateflow cache path.
fixed = nib.load(
    "/home/mathdugre/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz"
).get_fdata()
# data_format -> list of mutual information values, one per warped image.
py_mi = defaultdict(list)
for img in tqdm(moving_images):
    # The name of the grandparent directory is used as the data format label.
    data_format = img.parent.parent.name
    img2 = nib.load(img).get_fdata()

    # Joint intensity histogram of the two flattened volumes (32 x 32 bins);
    # the returned bin edges are not used.
    hist_2d, x_edges, y_edges = np.histogram2d(fixed.ravel(), img2.ravel(), bins=32)

    py_mi[data_format].append(mutual_information(hist_2d))

In [9]:
# Negated NaN-aware mean of the mutual information of each data format.
py_mi_mean = {fmt: -np.nanmean(values) for fmt, values in py_mi.items()}

In [None]:
subjects = rv["subject"].unique()
n_subjects = rv["subject"].nunique()

# `py_mi_mean` stores negated means, so the raw samples are negated as well
# before they are used to size the color scale.
py_mi_samples = {fmt: [-value for value in values] for fmt, values in py_mi.items()}

# `stage_metric_stats` is a required keyword of the plot helpers and
# `stage_metrics` is read unconditionally by `_plot_metric`; both were missing
# here. Assumes `py_mi_mean` has a "binary64" entry (TODO confirm the
# directory naming of the warped images).
plot_metric_value(
    n_failed,
    py_mi_mean,
    stage_metric_stats=py_mi_mean,
    stage_metrics=py_mi_samples,
    title="SyN (Python MI): Metric value - Subject mean",
    n_subjects=n_subjects,
)
plot_metric_diff(
    n_failed,
    py_mi_mean,
    stage_metric_stats=py_mi_mean,
    stage_metrics=py_mi_samples,
    title="SyN (Python MI): Metric diff w/ Binary64 - Subject mean",
    n_subjects=n_subjects,
)

# Thesis proposal figures


In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# The colorblind8 palette is spelled out here as eight hex colors so that
# bokeh does not have to be installed just to provide it.
Colorblind8 = (
    "#0072B2", "#E69F00", "#F0E442", "#009E73",
    "#56B4E9", "#D55E00", "#CC79A7", "#000000",
)

# 3 x 4 grid of bar charts (rows: stages, columns: levels) with independent
# x-axes and a y-axis shared along each row.
fig = make_subplots(
    rows=3,
    cols=4,
    shared_xaxes=False,
    shared_yaxes=True,
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
)

# Data format label used in `rv` -> name shown on the x-axis.
data_format = {"binary64": "float64", "r8-p23": "float32", "r7-p8": "bfloat16"}
# Highest iteration reached by every (data_format, stage, level, subject).
max_iterations = rv.groupby(["data_format", "stage", "level", "subject"])[
    "iterations"
].max()

# One bar trace per panel: mean of the per-subject maximum iteration for each
# data format, with the standard deviation as error bar.
for stage in range(1, 4):
    for level in range(1, 5):
        x_axis = list(data_format.values())
        per_format = [
            max_iterations[fmt][stage][level].values for fmt in data_format
        ]
        y_values = [samples.mean() for samples in per_format]
        y_std = [samples.std() for samples in per_format]

        bars = go.Bar(
            x=x_axis,
            y=y_values,
            error_y=dict(type="data", array=y_std),
            text=[f"{s:.1f}" for s in y_values],
            textposition="auto",
            marker=dict(color=Colorblind8),
            showlegend=False,
        )
        fig.add_trace(bars, row=stage, col=level)

# Column headers above the first row: sigma counts down from 3 to 0.
for i, sigma in enumerate(reversed(range(4)), 1):
    fig.add_annotation(
        text=f"Sigma: {sigma}",
        xref="x domain",
        yref="y domain",
        x=0.5,
        y=1.05,
        xanchor="center",
        yanchor="bottom",
        row=1,
        col=i,
        showarrow=False,
        font=dict(size=14, color="black"),
    )

# Rotated row labels to the right of the last column.
for i, stage in enumerate(["Rigid", "Affine", "SyN"], 1):
    fig.add_annotation(
        text=stage,
        xref="x domain",
        yref="y domain",
        x=1.05,
        y=0.4,
        xanchor="center",
        yanchor="bottom",
        textangle=90,
        row=i,
        col=4,
        showarrow=False,
        font=dict(size=14, color="black"),
    )

# fig.update_xaxes(title_text="data format", row=3)
# fig.update_yaxes(title_text="Max iterations", col=1)
fig.update_layout(height=800)
fig.show()