# Utils

In [1]:
from pathlib import Path
from collections import defaultdict
import re

import pandas as pd


def _make_data(exp_name):
    p_lvl_header = r"DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST|  Elapsed time"
    subjects = Path("subject_ids.txt").read_text().splitlines()

    logs = Path(f"log/")

    dfs = list()
    for subject in subjects:
        for filename in logs.glob(f"{exp_name}-{subject}-*-*.out"):
            exp_id = filename.stem.split("-")[-2]
            txt = filename.read_text()


            for i, level in enumerate(re.split(p_lvl_header, txt)[1:-1]):
                table = defaultdict(list)
                for row in re.split(r"\n", level.strip("XX").strip()):
                    cols = row.split(",")
                    table["iterations"].append(cols[1].strip())
                    table["metric"].append(cols[2].strip())
                    table["relative_error"].append(cols[3].strip())
                    table["total_time"].append(cols[4].strip())
                    table["since_last"].append(cols[5].strip())

                dfs.append(
                    pd.DataFrame(
                        data={
                            "subject": subject,
                            "exp_id": exp_id,
                            "stage": len(exp_id),
                            "level": i + 1,
                            "iterations": table["iterations"],
                            "metric": table["metric"],
                            "relative_error": table["relative_error"],
                            "total_time": table["total_time"],
                            "since_last": table["since_last"],
                        }
                    )
                )

    df = pd.concat(dfs, ignore_index=True)
    df = df.astype(
        {
            "stage": int,
            "level": int,
            "iterations": int,
            "metric": float,
            "relative_error": float,
            "total_time": float,
            "since_last": float,
        }
    )
    return (
        df.groupby(["subject", "exp_id", "stage", "level", "iterations"])
        .agg({"metric": "mean", "relative_error": "mean", "total_time": "mean"})
        .reset_index()
    )


In [46]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# The eight hex colors of Bokeh's ``Colorblind8`` palette, copied here so that
# plotting does not require bokeh to be installed.
Colorblind8 = tuple(
    "#0072B2 #E69F00 #F0E442 #009E73 #56B4E9 #D55E00 #CC79A7 #000000".split()
)

def _make_figure(
    data,
    *,
    out_dir: Path,
    y_value: str = None,
    y_log: bool = False,
    show: bool = False,
    sigmas=(8, 4, 2, 1),
    stage_names=("Rigid", "Affine", "SyN"),
):
    """Write one subplot grid per subject: rows are stages, columns are levels.

    Each panel overlays one line per experiment id (x = ``iterations``,
    y = the *y_value* column). Each figure is saved as
    ``out_dir/<y_value>/<subject>.html`` and ``.png``.

    Args:
        data: Frame with ``subject, exp_id, stage, level, iterations`` and
            *y_value* columns (e.g. the frame returned by ``_make_data``).
        out_dir: Root output directory; created if missing.
        y_value: Column to plot. Required despite the ``None`` default.
        y_log: Use logarithmic y axes.
        show: Also display each figure with ``fig.show()``.
        sigmas: Column labels; ``sigmas[level - 1]`` is shown above a level's
            column (levels outside the tuple get no label).
        stage_names: Row labels; ``stage_names[stage - 1]`` is shown to the
            right of a stage's row (stages outside the tuple get no label).

    Raises:
        ValueError: if *y_value* is None.
    """
    if y_value is None:
        raise ValueError("y_value must name a column of `data`, e.g. 'metric'")

    for subject in data["subject"].unique():
        subject_data = data[data["subject"] == subject]

        # Sort so that rows/columns follow stage/level numbers rather than the
        # order in which values first appear in the frame.
        levels = sorted(subject_data["level"].unique())
        stages = sorted(subject_data["stage"].unique())
        exp_ids = subject_data["exp_id"].unique()
        # Group once instead of re-filtering the subject's frame per trace.
        panels = {
            key: group
            for key, group in subject_data.groupby(["level", "stage", "exp_id"])
        }

        # Subplot grid with independent x-axes.
        fig = make_subplots(
            rows=len(stages),
            cols=len(levels),
            shared_xaxes=False,
            shared_yaxes=True,
        )

        legend_shown = set()
        for col, level in enumerate(levels, 1):
            for row, stage in enumerate(stages, 1):
                for exp_id in exp_ids:
                    subset = panels.get((level, stage, exp_id))
                    if subset is None:
                        # No data for this id in this panel; an empty trace
                        # would only add clutter and a duplicate legend entry.
                        continue

                    # Bit i of the palette index is character i of exp_id, so
                    # ids perturbing the same stages (e.g. "1", "10", "100")
                    # share a color. Wrap around for ids longer than 3 chars.
                    color = Colorblind8[int(exp_id[::-1], 2) % len(Colorblind8)]

                    fig.add_trace(
                        go.Scatter(
                            x=subset["iterations"],
                            y=subset[y_value],
                            name=exp_id,
                            mode="lines",
                            line=dict(color=color),
                            legendgroup=exp_id,
                            # Exactly one legend entry per experiment id.
                            showlegend=exp_id not in legend_shown,
                        ),
                        row=row,
                        col=col,
                    )
                    legend_shown.add(exp_id)

        # Column headers above the first row.
        for col, level in enumerate(levels, 1):
            if 1 <= level <= len(sigmas):
                fig.add_annotation(
                    text=f"Sigma: {sigmas[level - 1]}",
                    xref="x domain",
                    yref="y domain",
                    x=0.5,
                    y=1.05,
                    xanchor="center",
                    yanchor="bottom",
                    row=1,
                    col=col,
                    showarrow=False,
                    font=dict(size=14, color="black"),
                )

        # Row labels to the right of the last column.
        for row, stage in enumerate(stages, 1):
            if 1 <= stage <= len(stage_names):
                fig.add_annotation(
                    text=stage_names[stage - 1],
                    xref="x domain",
                    yref="y domain",
                    x=1.05,
                    y=0.4,
                    xanchor="center",
                    yanchor="bottom",
                    textangle=90,
                    row=row,
                    col=len(levels),
                    showarrow=False,
                    font=dict(size=14, color="black"),
                )

        fig.update_xaxes(title_text="Iterations", row=len(stages))
        fig.update_yaxes(title_text=y_value.replace("_", " ").capitalize(), col=1)
        if y_log:
            fig.update_yaxes(type="log", exponentformat="e")
        fig.update_layout(
            title=f"Subject: {subject}",
            showlegend=True,
            legend_title="Experiment ID",
            width=1920,
            height=1080,
        )

        if show:
            fig.show()
        target_dir = out_dir / y_value
        target_dir.mkdir(parents=True, exist_ok=True)
        fig.write_html(target_dir / f"{subject}.html")
        fig.write_image(target_dir / f"{subject}.png")


In [74]:
def _make_agg_figure(
    data,
    *,
    out_dir: Path,
    y_value: str = None,
    y_log: bool = False,
    show: bool = False,
    exp_ids=("0", "1", "00", "11", "000", "111"),
    sigmas=(8, 4, 2, 1),
    stage_names=("Rigid", "Affine", "SyN"),
):
    """Write one subplot grid overlaying the curves of every subject.

    Rows are stages and columns are levels. For each experiment id in
    *exp_ids*, one line per subject is drawn in every panel that has data.
    Ids containing a ``"0"`` are drawn dotted in ``Colorblind8[0]`` (legend
    group ``"0"``); all others solid in ``Colorblind8[1]`` at 0.4 alpha
    (legend group ``"1"``). The figure is saved as
    ``out_dir/<y_value>/all_subjects.html`` and ``.png``.

    Args:
        data: Frame with ``subject, exp_id, stage, level, iterations`` and
            *y_value* columns (e.g. the frame returned by ``_make_data``).
        out_dir: Root output directory; created if missing.
        y_value: Column to plot. Required despite the ``None`` default.
        y_log: Use logarithmic y axes.
        show: Also display the figure with ``fig.show()``.
        exp_ids: Experiment ids to draw, in drawing order.
        sigmas: Column labels; ``sigmas[level - 1]`` is shown above a level's
            column (levels outside the tuple get no label).
        stage_names: Row labels; ``stage_names[stage - 1]`` is shown to the
            right of a stage's row (stages outside the tuple get no label).

    Raises:
        ValueError: if *y_value* is None.
    """
    if y_value is None:
        raise ValueError("y_value must name a column of `data`, e.g. 'metric'")

    # Sort so that rows/columns follow stage/level numbers.
    levels = sorted(data["level"].unique())
    stages = sorted(data["stage"].unique())
    subjects = data["subject"].unique()
    # Group once instead of re-filtering the full frame for every trace.
    panels = {
        key: group
        for key, group in data.groupby(["level", "stage", "exp_id", "subject"])
    }

    # (color, dash, legend group) for each of the two line styles.
    # NOTE(review): any id containing a "0" counts as unperturbed; this is
    # exact for the default ids, which are all-"0" or all-"1".
    unperturbed = (Colorblind8[0], "dot", "0")
    perturbed = (f"rgba{px.colors.hex_to_rgb(Colorblind8[1]) + (0.4,)}", "solid", "1")

    # Subplot grid with independent x-axes.
    fig = make_subplots(
        rows=len(stages),
        cols=len(levels),
        shared_xaxes=False,
        shared_yaxes=True,
    )

    legend_shown = set()
    for col, level in enumerate(levels, 1):
        for row, stage in enumerate(stages, 1):
            for exp_id in exp_ids:
                color, dash, group = unperturbed if "0" in exp_id else perturbed
                for subject in subjects:
                    subset = panels.get((level, stage, exp_id, subject))
                    if subset is None:
                        # No data here; skip instead of adding an empty trace.
                        continue

                    fig.add_trace(
                        go.Scatter(
                            x=subset["iterations"],
                            y=subset[y_value],
                            name=exp_id,
                            mode="lines",
                            line=dict(color=color, dash=dash),
                            legendgroup=group,
                            # Exactly one legend entry per experiment id.
                            showlegend=exp_id not in legend_shown,
                            hovertemplate=f"Subject: {subject}<br>Iteration: %{{x}}<br>{y_value}: %{{y}}",
                        ),
                        row=row,
                        col=col,
                    )
                    legend_shown.add(exp_id)

    # Column headers above the first row.
    for col, level in enumerate(levels, 1):
        if 1 <= level <= len(sigmas):
            fig.add_annotation(
                text=f"Sigma: {sigmas[level - 1]}",
                xref="x domain",
                yref="y domain",
                x=0.5,
                y=1.05,
                xanchor="center",
                yanchor="bottom",
                row=1,
                col=col,
                showarrow=False,
                font=dict(size=14, color="black"),
            )

    # Row labels to the right of the last column.
    for row, stage in enumerate(stages, 1):
        if 1 <= stage <= len(stage_names):
            fig.add_annotation(
                text=stage_names[stage - 1],
                xref="x domain",
                yref="y domain",
                x=1.05,
                y=0.4,
                xanchor="center",
                yanchor="bottom",
                textangle=90,
                row=row,
                col=len(levels),
                showarrow=False,
                font=dict(size=14, color="black"),
            )

    fig.update_xaxes(title_text="Iterations", row=len(stages))
    fig.update_yaxes(title_text=y_value.replace("_", " ").capitalize(), col=1)
    if y_log:
        fig.update_yaxes(type="log", exponentformat="e")
    fig.update_layout(
        title="All subjects",
        showlegend=True,
        legend_title="Experiment ID",
        width=1920,
        height=1080,
    )

    if show:
        fig.show()
    target_dir = out_dir / y_value
    target_dir.mkdir(parents=True, exist_ok=True)
    fig.write_html(target_dir / "all_subjects.html")
    fig.write_image(target_dir / "all_subjects.png")


# Figures

In [75]:
EXP_NAME = "flint-vprec-single"

# Output directory for this experiment's figures. parents=True also creates a
# missing "figures/" directory instead of failing on a fresh checkout.
figure_dir = Path("figures", EXP_NAME)
figure_dir.mkdir(parents=True, exist_ok=True)

# Parsed and averaged convergence curves for every subject of the experiment.
data = _make_data(EXP_NAME)
# Passed as `show=` to the aggregate-figure calls below.
show_plots = True

In [76]:
# One all-subjects figure per plotted column: (column, log-scale y axis?).
for agg_column, agg_log in (("metric", False), ("relative_error", True)):
    _make_agg_figure(
        data,
        out_dir=figure_dir,
        y_value=agg_column,
        y_log=agg_log,
        show=show_plots,
    )

## Metric value

In [9]:
# Per-subject grids of the metric value, on a linear y axis; saved, not shown.
_make_figure(data, y_value="metric", out_dir=figure_dir, y_log=False, show=False)

## Relative error

In [10]:
# Per-subject grids of the relative error, on a log y axis; saved, not shown.
_make_figure(data, y_value="relative_error", out_dir=figure_dir, show=False, y_log=True)