# Utils


In [90]:
import math
from typing import Callable

import numpy as np
import plotly.graph_objects as go


def _plot_metric(
    failures: dict[int, dict[int, int]],
    metrics: dict[str, float],
    *,
    metric_func: Callable,
    std_delta_func: Callable,
    title: str,
    n_subjects: int,
    exp_metrics: dict[str, list[float]] | None = None,
    threshold_zones: bool = False,
):
    """Draw a (range bits x precision bits) heatmap of a registration metric.

    Parameters
    ----------
    failures:
        ``failures[range_bits][precision_bits]`` is the number of failed
        executions of that configuration. Its keys define the grid axes.
    metrics:
        Aggregated metric per configuration, keyed ``"r{range}-p{precision}"``,
        plus any extra key (such as ``"binary64"``) that ``metric_func`` reads.
    metric_func:
        ``metric_func(z, metrics)`` maps the raw ``(len(y), len(x))`` grid to
        the values that are displayed.
    std_delta_func:
        ``std_delta_func(x, y, exp_metrics)`` returns a ``(len(y), len(x))``
        array of spreads, printed as ``value ± std``. Only called when
        ``exp_metrics`` is given.
    title:
        Figure title.
    n_subjects:
        Total number of subjects, used in the partial-failure annotations.
    exp_metrics:
        Optional per-subject metric lists, keyed like ``metrics``.
    threshold_zones:
        When True, outline precision bands and label each with a threshold.
    """
    # Grid axes: precision bits on x, range bits on y (both from ``failures``).
    x = sorted({key for subdict in failures.values() for key in subdict})
    y = sorted(failures)

    # Raw metric grid; configurations without an entry in ``metrics`` stay NaN.
    z = np.zeros((len(y), len(x)))
    for yi, y_key in enumerate(y):
        for xi, x_key in enumerate(x):
            z[yi, xi] = metrics.get(f"r{y_key}-p{x_key}", np.nan)

    # Values actually displayed (identity, difference, relative difference...).
    delta = np.asarray(metric_func(z, metrics), dtype=float)
    m = np.nanmin(delta)
    M = np.nanmax(delta)

    # Overlay that is 1 on cells without a value and NaN elsewhere, so that only
    # those cells get painted by the grey ``failed_trace`` below.
    masked_failed = np.where(np.isnan(delta), 1, np.nan)

    if exp_metrics:
        std_delta = std_delta_func(x, y, exp_metrics)
        formatter = np.vectorize(lambda d, s: f"{d:.3f} ± {s:.3f}")
        formatted = formatter(delta, std_delta)
    else:
        formatter = np.vectorize(lambda d: f"{d:.3f}")
        formatted = formatter(delta)

    # NaN compares False, so missing cells stay unlabelled; values > 1 as well.
    text_labels = np.where(delta <= 1, formatted, "")

    normal_trace = go.Heatmap(
        z=delta,
        x=x,
        y=y,
        colorscale=[[0, "green"], [1, "red"]],
        zmin=m,
        zmax=M,
        text=text_labels,
        texttemplate="%{text}",
    )
    failed_trace = go.Heatmap(
        z=masked_failed,
        x=x,
        y=y,
        colorscale=[[0, "grey"], [1, "grey"]],
        showscale=False,
    )
    fig = go.Figure(data=[failed_trace, normal_trace])

    # Flag partially failed configurations (some, but not all, subjects failed).
    for range_bits in y:
        for precision_bits in x:
            n_errors = failures[range_bits].get(precision_bits, 0)
            if 0 < n_errors < n_subjects:
                fig.add_annotation(
                    x=precision_bits,
                    # The label's bottom sits 0.4 below the row centre, i.e.
                    # close to the lower edge of a unit-height cell.
                    y=range_bits - 0.4,
                    text=f"⚠ Err: {n_errors}/{n_subjects}",
                    showarrow=False,
                    font=dict(color="black"),
                    xanchor="center",
                    yanchor="bottom",
                )

    # Label the cells of hardware formats: name -> (range bits, precision bits).
    for format_name, (range_bits, precision_bits) in {
        "bfloat16": (8, 7),
        "float32": (8, 23),
    }.items():
        fig.add_annotation(
            x=precision_bits,
            y=range_bits + 0.1,
            text=format_name,
            showarrow=False,
            font=dict(color="black"),
            xanchor="center",
            yanchor="bottom",
        )

    # Threshold annotations
    if threshold_zones:
        for x0, x1 in [
            (6.5, 9.5),
            (9.5, 13.5),
            (13.5, 16.5),
            (16.5, 19.5),
            (19.5, 23.5),
        ]:
            # Decimal digits carried by floor(x1) precision bits, minus one.
            threshold_exp = math.floor(math.log10(2) * math.floor(x1)) - 1
            fig.add_vrect(
                x0=x0,
                x1=x1,
                line_width=3,
                label=dict(
                    text=f"Threshold: 1e-{threshold_exp}",
                    textposition="top center",
                    font=dict(size=20, family="Times New Roman"),
                ),
            )

    # Style & Layout
    fig.update_layout(
        title=title,
        xaxis=dict(title="Precision (# of bits)", tickvals=x, ticks="outside"),
        yaxis=dict(title="Range (# of bits)", tickvals=y, ticks="outside"),
    )
    fig.show()

In [65]:
def plot_metric_value(
    failures: dict[int, dict[int, int]],
    metrics: dict[str, float],
    *,
    title: str,
    n_subjects: int,
    exp_metrics: dict[str, list[float]] | None = None,
    threshold_zones: bool = False,
):
    """Plot the raw metric value of every (range, precision) configuration.

    Args:
        failures: ``failures[range][precision]`` -> number of failed runs.
        metrics: Aggregated metric keyed ``"r{range}-p{precision}"``.
        title: Figure title.
        n_subjects: Total number of subjects.
        exp_metrics: Optional per-subject metric lists (same keys as
            ``metrics``); when given, each cell also shows ``± std``.
        threshold_zones: Forwarded to ``_plot_metric``.
    """
    import warnings

    def _metric_func(z, metrics):
        # Display the metric as is.
        return z

    def _std_delta_func(x, y, exp_metrics):
        std_delta = np.zeros((len(y), len(x)))
        # Missing configurations, and those where every subject failed (all NaN),
        # give NaN; silence numpy's "Degrees of freedom <= 0" warning for them.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for yi, y_key in enumerate(y):
                for xi, x_key in enumerate(x):
                    std_delta[yi, xi] = np.nanstd(
                        np.asarray(exp_metrics.get(f"r{y_key}-p{x_key}", np.nan))
                    )
        return std_delta

    _plot_metric(
        failures,
        metrics,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        metric_func=_metric_func,
        std_delta_func=_std_delta_func,
        threshold_zones=threshold_zones,
    )

In [66]:
def plot_metric_diff(
    failures: dict[int, dict[int, int]],
    metrics: dict[str, float],
    *,
    title: str,
    n_subjects: int,
    exp_metrics: dict[str, list[float]] | None = None,
    threshold_zones: bool = False,
):
    """Plot ``metric - metric["binary64"]`` for every (range, precision) cell.

    Args:
        failures: ``failures[range][precision]`` -> number of failed runs.
        metrics: Aggregated metric keyed ``"r{range}-p{precision}"``; must
            contain ``"binary64"`` (KeyError otherwise).
        title: Figure title.
        n_subjects: Total number of subjects.
        exp_metrics: Optional per-subject metric lists (same keys as
            ``metrics``, including ``"binary64"``); the std is taken over the
            per-subject differences.
        threshold_zones: Forwarded to ``_plot_metric``.
    """
    import warnings

    def _metric_func(z, metrics):
        return z - metrics["binary64"]

    def _std_delta_func(x, y, exp_metrics):
        std_delta = np.zeros((len(y), len(x)))
        reference = np.asarray(exp_metrics["binary64"])
        # Missing / all-NaN configurations give NaN; silence numpy's
        # "Degrees of freedom <= 0" warning for them.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for yi, y_key in enumerate(y):
                for xi, x_key in enumerate(x):
                    std_delta[yi, xi] = np.nanstd(
                        np.asarray(exp_metrics.get(f"r{y_key}-p{x_key}", np.nan))
                        - reference
                    )
        return std_delta

    _plot_metric(
        failures,
        metrics,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        metric_func=_metric_func,
        std_delta_func=_std_delta_func,
        threshold_zones=threshold_zones,
    )

In [67]:
def plot_metric_rel(
    failures: dict[int, dict[int, int]],
    metrics: dict[str, float],
    *,
    title: str,
    n_subjects: int,
    exp_metrics: dict[str, list[float]] | None = None,
    threshold_zones: bool = False,
):
    """Plot the signed relative difference to binary64 for every cell.

    Each cell shows ``(metric - metric["binary64"]) / |metric["binary64"]|``.
    When ``exp_metrics`` is given, the ``± std`` is the spread of the same
    signed quantity computed per subject. It previously used the absolute
    difference, which did not match the displayed value.

    Args:
        failures: ``failures[range][precision]`` -> number of failed runs.
        metrics: Aggregated metric keyed ``"r{range}-p{precision}"``; must
            contain ``"binary64"`` (KeyError otherwise).
        title: Figure title.
        n_subjects: Total number of subjects.
        exp_metrics: Optional per-subject metric lists (same keys as
            ``metrics``, including ``"binary64"``).
        threshold_zones: Forwarded to ``_plot_metric``.
    """
    import warnings

    def _metric_func(z, metrics):
        return (z - metrics["binary64"]) / np.abs(metrics["binary64"])

    def _std_delta_func(x, y, exp_metrics):
        std_delta = np.zeros((len(y), len(x)))
        reference = np.asarray(exp_metrics["binary64"])
        # Missing / all-NaN configurations give NaN; silence numpy's
        # "Degrees of freedom <= 0" warning for them.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for yi, y_key in enumerate(y):
                for xi, x_key in enumerate(x):
                    values = np.asarray(
                        exp_metrics.get(f"r{y_key}-p{x_key}", np.nan)
                    )
                    # Same signed formula as _metric_func, applied per subject.
                    std_delta[yi, xi] = np.nanstd(
                        (values - reference) / np.abs(reference)
                    )
        return std_delta

    _plot_metric(
        failures,
        metrics,
        title=title,
        n_subjects=n_subjects,
        exp_metrics=exp_metrics,
        metric_func=_metric_func,
        std_delta_func=_std_delta_func,
        threshold_zones=threshold_zones,
    )

In [177]:
from pathlib import Path
from typing import Optional

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Bokeh's "Colorblind8" palette, copied here so bokeh is not a dependency.
Colorblind8 = (
    "#0072B2",  # blue
    "#E69F00",  # orange
    "#F0E442",  # yellow
    "#009E73",  # bluish green
    "#56B4E9",  # sky blue
    "#D55E00",  # vermillion
    "#CC79A7",  # reddish purple
    "#000000",  # black
)


def _make_agg_figure(
    data,
    *,
    out_dir: Optional[Path] = None,
    y_value: str = None,
    y_log: bool = False,
    show: bool = False,
    exp_ids: list[str] = None,
):
    """Plot one line per (subject, experiment) in a stage x level grid.

    The grid has one row per distinct ``stage`` and one column per distinct
    ``level`` of ``data``. Each trace plots ``data[y_value]`` against
    ``data["iterations"]`` for one subject and one experiment id. The
    ``"binary64"`` reference is drawn as a dotted black line.

    Args:
        data: Long-format DataFrame with at least the columns ``level``,
            ``stage``, ``exp_id``, ``subject``, ``iterations`` and ``y_value``.
        out_dir: If given, write ``<out_dir>/<y_value>/all_subjects.html`` and
            ``.png``.
        y_value: Column to plot on the y axis (required).
        y_log: Use a log y axis with range 1e-7 .. 1.
        show: Display the figure.
        exp_ids: Experiment ids to plot. Uses every id in ``data`` when
            ``None`` or empty.

    Raises:
        ValueError: If ``y_value`` is not given.
    """
    if y_value is None:
        raise ValueError("y_value is required")

    levels = data["level"].unique()
    stages = data["stage"].unique()
    # ``not exp_ids`` would raise on a numpy array, so test explicitly.
    if exp_ids is None or len(exp_ids) == 0:
        exp_ids = data["exp_id"].unique()
    subjects = data["subject"].unique()
    palette = px.colors.qualitative.Dark24
    # Row labels, assuming stages appear in this order in ``data``.
    stage_names = ["Rigid (MI)", "Affine (MI)", "SyN (CC)"]

    # Creating the subplot grid with independent x-axes
    fig = make_subplots(
        rows=len(stages),
        cols=len(levels),
        shared_xaxes=False,
        shared_yaxes=True,
        horizontal_spacing=0.01,
        vertical_spacing=0.05,
    )

    for l, level in enumerate(levels, 1):
        for s, stage in enumerate(stages, 1):
            for e, exp_id in enumerate(exp_ids):
                # Line style depends only on the experiment, not on the subject.
                if exp_id == "binary64":
                    color, dash = "black", "dot"
                else:
                    rgb = px.colors.hex_to_rgb(palette[e % len(palette)])
                    color = f"rgba{rgb + (0.4,)}"
                    dash = "solid"

                for i, subject in enumerate(subjects):
                    subset = data[
                        (data["level"] == level)
                        & (data["stage"] == stage)
                        & (data["exp_id"] == exp_id)
                        & (data["subject"] == subject)
                    ]
                    fig.add_trace(
                        go.Scatter(
                            x=subset["iterations"],
                            y=subset[y_value],
                            name=exp_id,
                            mode="lines",
                            fillcolor=color,
                            line=dict(color=color, dash=dash),
                            legendgroup=exp_id,
                            # One legend entry per experiment (first cell, first subject).
                            showlegend=(l == 1 and s == 1 and i == 0),
                            hovertemplate=f"Subject: {subject}<br>Iteration: %{{x}}<br>{y_value}: %{{y}}",
                        ),
                        row=s,
                        col=l,
                    )

    # Column titles: level 1 is labelled with the largest sigma, the last level with 0.
    for i, sigma in enumerate(list(range(len(levels)))[::-1], 1):
        fig.add_annotation(
            text=f"Sigma: {sigma}",
            xref="x domain",
            yref="y domain",
            x=0.5,
            y=1.05,
            xanchor="center",
            yanchor="bottom",
            row=1,
            col=i,
            showarrow=False,
            font=dict(size=14, color="black"),
        )

    # Row titles, only for rows that exist.
    for i, stage_name in enumerate(stage_names[: len(stages)], 1):
        fig.add_annotation(
            text=stage_name,
            xref="x domain",
            yref="y domain",
            x=1.05,
            y=0.4,
            xanchor="center",
            yanchor="bottom",
            textangle=90,
            row=i,
            col=len(levels),
            showarrow=False,
            font=dict(size=14, color="black"),
        )

    fig.update_xaxes(title_text="Iterations", row=len(stages))
    fig.update_yaxes(title_text=y_value.replace("_", " ").capitalize(), col=1)
    if y_log:
        fig.update_yaxes(type="log", exponentformat="e", range=[-7, 0])
    fig.update_layout(
        title="All subjects",
        showlegend=True,
        legend_title="Experiment ID",
        width=1920,
        height=1080,
    )

    if show:
        fig.show()
    if out_dir:
        target = out_dir / y_value
        target.mkdir(parents=True, exist_ok=True)
        fig.write_html(target / "all_subjects.html")
        fig.write_image(target / "all_subjects.png")

# Data processing


In [6]:
from pathlib import Path
import re

# When True, print the failed logs of tasks where only some runs failed.
DEBUG = False
# Number of subjects; array ids run from 1 to N_SUBJECT.
N_SUBJECT = 20

# VPREC log names: "r<range>-p<precision>-<task_id>-<array_id>.out",
# e.g. "r8-p23-163579-1.out".
FILENAME_PATTERN = re.compile(
    r"r(?P<range>\d+)-p(?P<precision>\d+)-(?P<task_id>\d+)-(?P<array_id>\d+)\.out"
)

## Check failed execution


In [7]:
from typing import Iterable


def get_failed_exec(logs: Path | Iterable[Path]) -> int:
    if isinstance(logs, Path):
        logs = [logs]

    failed = list()
    for log in logs:
        with open(log) as f:
            text = f.read()
            if ("ITK ERROR" in text) and ("Elapsed time (stage 2):" not in text):
                failed.append(log)
    return failed

In [8]:
# Group logs by task_id
import re
from collections import defaultdict


# Sorted so that task order (and logs[0] of each group) is reproducible.
logs = sorted(Path("log").glob("r*-p*-*.out"))

# task_id -> list of its log files (name kept as-is: later cells use it).
experiements = defaultdict(list)
for log in logs:
    m = FILENAME_PATTERN.match(log.name)
    if not m:
        raise ValueError(f"Invalid log file name: {log}")
    experiements[m.group("task_id")].append(log)

In [9]:
# Count failed executions per task; n_failed[range][precision] -> count.
n_failed = defaultdict(dict)
for task_id, logs in experiements.items():
    failed = get_failed_exec(logs)

    # All logs of a task share range/precision, so the first name is enough.
    fields = FILENAME_PATTERN.match(logs[0].name)
    range_, precision_ = int(fields["range"]), int(fields["precision"])
    n_failed[range_][precision_] = len(failed)

    print(f"Task {task_id}: {len(failed)} failed")
    if not DEBUG or len(failed) == len(logs):
        continue
    for log in failed:
        print(log, end="\n\n")

# Range 6 is marked as failed for every subject at precisions 7..23
# (presumably no logs are produced for it - confirm).
n_failed[6] = {precision: N_SUBJECT for precision in range(7, 24)}

Task 163579: 0 failed
Task 163580: 0 failed
Task 163577: 0 failed
Task 163578: 0 failed
Task 163581: 0 failed
Task 163582: 0 failed
Task 163583: 0 failed
Task 163603: 0 failed
Task 163604: 0 failed
Task 163605: 0 failed
Task 163606: 0 failed
Task 163607: 0 failed
Task 163608: 20 failed
Task 163609: 20 failed
Task 163610: 20 failed
Task 163611: 20 failed
Task 163612: 20 failed
Task 163613: 0 failed
Task 163614: 0 failed
Task 163615: 0 failed
Task 163616: 0 failed
Task 163617: 0 failed
Task 163618: 0 failed
Task 163619: 0 failed
Task 163620: 0 failed
Task 163621: 0 failed
Task 163622: 0 failed
Task 163623: 0 failed
Task 163624: 0 failed
Task 163625: 20 failed
Task 163626: 20 failed
Task 163627: 20 failed
Task 163628: 20 failed
Task 163629: 20 failed


## QA using metric value


In [10]:
import pandas as pd


# Split points in a log: the CSV header printed before each level's diagnostic
# rows, OR (note the regex alternation "|") the "  Elapsed time" line.
p_lvl_header = r"DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST|  Elapsed time"

logs = Path("log")

dfs = list()
for array_id in range(1, N_SUBJECT + 1):
    # VPREC runs ("r*-p*") and the binary64 reference ("antsRegistration-*").
    filenames = list(logs.glob(f"r*-p*-*-{array_id}.out")) + list(
        logs.glob(f"antsRegistration-*-{array_id}.out")
    )
    for filename in filenames:
        if filename.name.startswith("antsRegistration-"):
            exp_id = "binary64"
        else:
            # "r8-p23-<task>-<array>.out" -> "r8-p23"
            exp_id = "-".join(filename.name.split("-")[:2])

        txt = filename.read_text()

        # Extract subject_id from log; name the file if the line is missing.
        subject_match = re.search(r"SUBJECT_ID: (?P<subject_id>.*)", txt)
        if subject_match is None:
            raise ValueError(f"No SUBJECT_ID line found in {filename}")
        subject_id = subject_match.group("subject_id")

        # Drop chunks 0, 5, 10, ...: presumably the text before the first
        # level header of each stage, leaving 4 level chunks per stage.
        all_data = [x for i, x in enumerate(re.split(p_lvl_header, txt)) if i % 5 != 0]

        # 3 stages with 4 levels of resolution each
        for stage in range(1, 4):
            for i, level in enumerate(all_data[4 * (stage - 1) : 4 * stage]):
                table = defaultdict(list)
                # strip("XX") removes any leading/trailing "X" characters.
                for row in level.strip("XX").strip().split("\n"):
                    # Skip invalid rows
                    # e.g. error messages or write volumes to disk
                    if not ("2DIAGNOSTIC" in row or "1DIAGNOSTIC" in row):
                        continue

                    cols = row.split(",")
                    try:
                        table["iterations"].append(cols[1].strip())
                        table["metric"].append(cols[2].strip())
                        table["convergence_value"].append(cols[3].strip())
                        table["total_time"].append(cols[4].strip())
                        table["since_last"].append(cols[5].strip())
                    except IndexError:
                        # Row is not as expected: report it and keep the traceback.
                        print(f"Malformed diagnostic row in {filename}: {cols}")
                        raise

                dfs.append(
                    pd.DataFrame(
                        data={
                            "subject": subject_id,
                            "array_id": array_id,
                            "exp_id": exp_id,
                            "stage": stage,
                            "level": i + 1,
                            "iterations": table["iterations"],
                            "metric": table["metric"],
                            "convergence_value": table["convergence_value"],
                            "total_time": table["total_time"],
                            "since_last": table["since_last"],
                        }
                    )
                )

df = pd.concat(dfs, ignore_index=True)
df = df.astype(
    {
        "stage": int,
        "level": int,
        "iterations": int,
        "metric": float,
        "convergence_value": float,
        "total_time": float,
        "since_last": float,
    }
)
# Average duplicated (subject, run, stage, level, iteration) rows.
rv = (
    df.groupby(["subject", "array_id", "exp_id", "stage", "level", "iterations"])
    .agg({"metric": "mean", "convergence_value": "mean", "total_time": "mean"})
    .reset_index()
)

In [178]:
figure_dir = Path("figures", "VPREC-exploration")
# parents=True: also create "figures/" on a fresh checkout.
figure_dir.mkdir(parents=True, exist_ok=True)

# Full sweep alternative (ranges 7-8, precisions 12-23, plus the reference):
# exp_ids = [f"r{r}-p{p}" for r in (7, 8) for p in range(12, 24)] + ["binary64"]
exp_ids = [
    "r7-p12",
    "r8-p23",
    "binary64",
]


data = rv
_make_agg_figure(data, y_value="metric", show=True, exp_ids=exp_ids)
_make_agg_figure(
    data,
    y_value="convergence_value",
    y_log=True,
    show=True,
    exp_ids=exp_ids,
)

## Refined region plot (Rigid)


In [12]:
import numpy as np

STAGE = 1  # Rigid

# Rows of the last resolution level (4) of this stage, filtered once.
_final_level = rv[(rv["stage"] == STAGE) & (rv["level"] == 4)]
# Last iteration reached at that level, per (subject, exp_id).
max_iterations = (
    _final_level.groupby(["subject", "exp_id"])["iterations"].max().to_dict()
)
subjects = rv["subject"].unique()
exp_ids = rv["exp_id"].unique()
# exp_metrics_rigid[exp_id][subject] -> metric at the last iteration, or NaN
# when that run has no level-4 rows for this stage.
exp_metrics_rigid = defaultdict(dict)
for subject in subjects:
    for exp_id in exp_ids:
        _last_iteration = max_iterations.get((subject, exp_id))
        if _last_iteration is None:
            _metric = np.nan
        else:
            _metric = _final_level[
                (_final_level["subject"] == subject)
                & (_final_level["exp_id"] == exp_id)
                & (_final_level["iterations"] == _last_iteration)
            ]["metric"].values[0]
        exp_metrics_rigid[exp_id][subject] = _metric

In [179]:
subjects = rv["subject"].unique()
n_subjects = rv["subject"].nunique()

# subjects = ["s003"]
# Per-experiment metric lists restricted to ``subjects``, in insertion order
# (the format the plot_metric_* helpers take as ``exp_metrics``).
_selected = set(subjects)
_exp_metrics = {
    exp: [value for subj, value in per_subject.items() if subj in _selected]
    for exp, per_subject in exp_metrics_rigid.items()
}
worst_metrics = {exp: np.nanmax(values) for exp, values in _exp_metrics.items()}
mean_metrics = {exp: np.nanmean(values) for exp, values in _exp_metrics.items()}

In [180]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_value(
        n_failed,
        _values,
        title=f"Rigid (MI): Metric value - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )


Degrees of freedom <= 0 for slice.



In [182]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_diff(
        n_failed,
        _values,
        title=f"Rigid (MI): Metric difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

In [181]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_rel(
        n_failed,
        _values,
        title=f"Rigid (MI): Relative difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

## Refined region plot (Affine)


In [17]:
import numpy as np

STAGE = 2  # Affine

# Rows of the last resolution level (4) of this stage, filtered once.
_final_level = rv[(rv["stage"] == STAGE) & (rv["level"] == 4)]
# Last iteration reached at that level, per (subject, exp_id).
max_iterations = (
    _final_level.groupby(["subject", "exp_id"])["iterations"].max().to_dict()
)
subjects = rv["subject"].unique()
exp_ids = rv["exp_id"].unique()
# exp_metrics_affine[exp_id][subject] -> metric at the last iteration, or NaN
# when that run has no level-4 rows for this stage.
exp_metrics_affine = defaultdict(dict)
for subject in subjects:
    for exp_id in exp_ids:
        _last_iteration = max_iterations.get((subject, exp_id))
        if _last_iteration is None:
            _metric = np.nan
        else:
            _metric = _final_level[
                (_final_level["subject"] == subject)
                & (_final_level["exp_id"] == exp_id)
                & (_final_level["iterations"] == _last_iteration)
            ]["metric"].values[0]
        exp_metrics_affine[exp_id][subject] = _metric

In [183]:
subjects = rv["subject"].unique()
n_subjects = rv["subject"].nunique()

# subjects = ["s003"]
# Per-experiment metric lists restricted to ``subjects``, in insertion order
# (the format the plot_metric_* helpers take as ``exp_metrics``).
_selected = set(subjects)
_exp_metrics = {
    exp: [value for subj, value in per_subject.items() if subj in _selected]
    for exp, per_subject in exp_metrics_affine.items()
}
worst_metrics = {exp: np.nanmax(values) for exp, values in _exp_metrics.items()}
mean_metrics = {exp: np.nanmean(values) for exp, values in _exp_metrics.items()}

In [184]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_value(
        n_failed,
        _values,
        title=f"Affine (MI): Metric value - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

In [186]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_diff(
        n_failed,
        _values,
        title=f"Affine (MI): Metric difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

In [185]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_rel(
        n_failed,
        _values,
        title=f"Affine (MI): Relative difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

## Refined region plot (SyN)


In [22]:
import numpy as np

STAGE = 3  # SyN

# Rows of the last resolution level (4) of this stage, filtered once.
_final_level = rv[(rv["stage"] == STAGE) & (rv["level"] == 4)]
# Last iteration reached at that level, per (subject, exp_id).
max_iterations = (
    _final_level.groupby(["subject", "exp_id"])["iterations"].max().to_dict()
)
subjects = rv["subject"].unique()
exp_ids = rv["exp_id"].unique()
# exp_metrics_syn[exp_id][subject] -> metric at the last iteration, or NaN
# when that run has no level-4 rows for this stage.
exp_metrics_syn = defaultdict(dict)
for subject in subjects:
    for exp_id in exp_ids:
        _last_iteration = max_iterations.get((subject, exp_id))
        if _last_iteration is None:
            _metric = np.nan
        else:
            _metric = _final_level[
                (_final_level["subject"] == subject)
                & (_final_level["exp_id"] == exp_id)
                & (_final_level["iterations"] == _last_iteration)
            ]["metric"].values[0]
        exp_metrics_syn[exp_id][subject] = _metric

In [188]:
subjects = rv["subject"].unique()
n_subjects = rv["subject"].nunique()

# subjects = ["s003"]
# Per-experiment metric lists restricted to ``subjects``, in insertion order
# (the format the plot_metric_* helpers take as ``exp_metrics``).
_selected = set(subjects)
_exp_metrics = {
    exp: [value for subj, value in per_subject.items() if subj in _selected]
    for exp, per_subject in exp_metrics_syn.items()
}
worst_metrics = {exp: np.nanmax(values) for exp, values in _exp_metrics.items()}
mean_metrics = {exp: np.nanmean(values) for exp, values in _exp_metrics.items()}


All-NaN axis encountered


Mean of empty slice



In [189]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_value(
        n_failed,
        _values,
        title=f"SyN (CC): Metric value - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

In [190]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_diff(
        n_failed,
        _values,
        title=f"SyN (CC): Metric difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

In [191]:
# Worst subject without error bars, then the subject mean with ± std.
for _values, _label, _spread in (
    (worst_metrics, "Worst subject", None),
    (mean_metrics, "Subjects mean", _exp_metrics),
):
    plot_metric_rel(
        n_failed,
        _values,
        title=f"SyN (CC): Relative difference with Binary64 - {_label}",
        n_subjects=n_subjects,
        exp_metrics=_spread,
    )

# Thesis proposal figures

In [170]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Bokeh's "Colorblind8" palette, copied here so bokeh is not a dependency.
Colorblind8 = (
    "#0072B2",
    "#E69F00",
    "#F0E442",
    "#009E73",
    "#56B4E9",
    "#D55E00",
    "#CC79A7",
    "#000000",
)

# 3 x 4 grid: one row per registration stage, one column per level.
fig = make_subplots(
    rows=3,
    cols=4,
    shared_xaxes=False,
    shared_yaxes=True,
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
)

# Experiment id -> label displayed on the bar chart.
exp_id = {"binary64": "float64", "r8-p23": "float32"}
# Last iteration reached, indexed by (exp_id, stage, level, subject).
max_iterations = rv.groupby(["exp_id", "stage", "level", "subject"])["iterations"].max()

# One bar chart per (stage, level): mean +/- std over subjects of the
# iteration count, for each data format.
for _stage, _level in [(s, l) for s in range(1, 4) for l in range(1, 5)]:
    _per_format = [max_iterations[fmt][_stage][_level].values for fmt in exp_id]
    _means = [counts.mean() for counts in _per_format]

    fig.add_trace(
        go.Bar(
            x=list(exp_id.values()),
            y=_means,
            error_y=dict(type="data", array=[counts.std() for counts in _per_format]),
            text=[f"{mean:.1f}" for mean in _means],
            textposition="auto",
            marker=dict(color=Colorblind8),
            showlegend=False,
        ),
        row=_stage,
        col=_level,
    )

# Column titles: column 1 -> Sigma 3, ..., column 4 -> Sigma 0.
for _col, _sigma in zip(range(1, 5), (3, 2, 1, 0)):
    fig.add_annotation(
        text=f"Sigma: {_sigma}",
        xref="x domain",
        yref="y domain",
        x=0.5,
        y=1.05,
        xanchor="center",
        yanchor="bottom",
        row=1,
        col=_col,
        showarrow=False,
        font=dict(size=14, color="black"),
    )

# Row titles on the right-most column.
for _row, _stage_name in enumerate(("Rigid", "Affine", "SyN"), start=1):
    fig.add_annotation(
        text=_stage_name,
        xref="x domain",
        yref="y domain",
        x=1.05,
        y=0.4,
        xanchor="center",
        yanchor="bottom",
        textangle=90,
        row=_row,
        col=4,
        showarrow=False,
        font=dict(size=14, color="black"),
    )

fig.update_layout(height=800)
fig.show()
