In [None]:
import glob
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import simplejson
from plotly.subplots import make_subplots

In [None]:
# Root directory of the Ray Tune results; every *_trial glob pattern below is
# joined onto it.
RAY_RESULTS_DIR = "ray_results/"
base_path = RAY_RESULTS_DIR

In [None]:
# Training histories of the four fully-connected model variants.
# NOTE: pickle.load can execute arbitrary code — only load files produced by
# this project's own training runs.
_ann_histories = []
for model_idx in range(1, 5):
    with open(f"training_results/history_ann_model_{model_idx}", "rb") as file_pi:
        _ann_histories.append(pickle.load(file_pi))
ann_history1, ann_history2, ann_history3, ann_history4 = _ann_histories

In [None]:
# Training histories of the four CNN model variants.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
_cnn_histories = []
for model_idx in range(1, 5):
    with open(f"training_results/history_cnn_model_{model_idx}", "rb") as file_pi:
        _cnn_histories.append(pickle.load(file_pi))
cnn_history1, cnn_history2, cnn_history3, cnn_history4 = _cnn_histories

In [None]:
# Training histories of the four final models (one per tuning scheduler).
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
_final_histories = []
for model_idx in range(1, 5):
    with open(f"training_results/history_final_model_{model_idx}", "rb") as file_pi:
        _final_histories.append(pickle.load(file_pi))
final_history1, final_history2, final_history3, final_history4 = _final_histories

In [None]:
# R² scores of the final model of each scheduler.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
_r2_scores = []
for scheduler in ("hpb", "ahb", "pbt", "fifo"):
    with open(f"training_results/r2_scores_{scheduler}", "rb") as file_pi:
        _r2_scores.append(pickle.load(file_pi))
hpb_scores, ahb_scores, pbt_scores, fifo_scores = _r2_scores

# Visualize Tuning

In [None]:
# Glob patterns (relative to base_path) matching all trial folders of one
# tuning experiment. ADAPT THESE to the output of your own tuning run: the
# experiment folder is "RayTune_ARDS_<scheduler>" and the trial folders start
# with "train_function_<experiment id>". In our runs the ids were:
#   AHB 8862d, HyperBand 6eaf7, PBT 8e674, FIFO 4a159 (scheduler "None").
def _trial_pattern(scheduler, experiment_id):
    return f"RayTune_ARDS_{scheduler}/train_function_{experiment_id}*"


ahb_trial = _trial_pattern("ahb", "8862d")
hpb_trial = _trial_pattern("hpb", "6eaf7")
pbt_trial = _trial_pattern("pbt", "8e674")
# The FIFO pattern is called "fio_trial" throughout the rest of the notebook.
fio_trial = _trial_pattern("None", "4a159")

In [None]:
# AHB: per-trial training (mse) and validation (val_mse) curves; the last
# time_total_s value of every trial is collected in ahb_duration.
fig, ax = plt.subplots(1, 2, figsize=(30, 10))
ahb_duration = []
for trial in sorted(glob.glob(os.path.join(base_path, ahb_trial))):
    progress = pd.read_csv(os.path.join(trial, "progress.csv"))
    if progress.empty:
        # A trial without any reported row has neither a duration nor a curve.
        continue
    ahb_duration.append(progress.time_total_s.values[-1])
    ax[0].plot(progress.mse)
    ax[1].plot(progress.val_mse)
for panel, metric in zip(ax, ("mse", "val_mse")):
    panel.set_title(f"AHB trials: {metric}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel(metric)

In [None]:
def make_plot_with_subplots(
    base_path,
    experiment_path,
    metric_1,
    metric_2,
    plot_title,
    shared_yaxes,
    skipfooter,
    show_legend,
):
    """Plot two metrics of every trial of one Ray Tune experiment side by side.

    Parameters
    ----------
    base_path : str
        Directory containing the experiment folders.
    experiment_path : str
        Glob pattern, relative to ``base_path``, matching the trial directories.
        Each directory must contain a ``progress.csv``.
    metric_1, metric_2 : str
        ``progress.csv`` columns plotted in the left and right panel.
    plot_title : str
        Figure title.
    shared_yaxes : bool
        Whether the two panels share their y axis.
    skipfooter : int
        Number of trailing rows dropped from every ``progress.csv``.
    show_legend : bool
        Whether to display the trial-ID legend.

    Returns
    -------
    plotly.graph_objects.Figure
        One line per non-empty trial in each panel. Both lines of a trial share
        a colour and a single legend entry.
    """
    from plotly.colors import qualitative

    palette = qualitative.Plotly

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(metric_1, metric_2),
        shared_yaxes=shared_yaxes,
        x_title="Epoch",
        y_title="Error",
    )

    trial_paths = sorted(glob.glob(os.path.join(base_path, experiment_path)))
    for trial_index, trial_path in enumerate(trial_paths):
        progress = pd.read_csv(
            os.path.join(trial_path, "progress.csv"),
            skipfooter=skipfooter,
            engine="python",
        )
        if progress.empty:
            # Nothing to plot, and trial_id.unique()[0] would raise IndexError.
            continue
        trial_id = str(progress.trial_id.unique()[0])

        # Plotly colours traces by their index in the figure. Without an
        # explicit colour, a trial's two traces would get different colours
        # and the trial would appear twice in the legend.
        color = palette[trial_index % len(palette)]
        for col, metric in ((1, metric_1), (2, metric_2)):
            fig.add_trace(
                go.Scatter(
                    x=progress.index,
                    y=progress[metric],
                    mode="lines",
                    name=trial_id,
                    legendgroup=trial_id,
                    line=dict(color=color),
                    showlegend=(col == 1),
                ),
                row=1,
                col=col,
            )

    fig.update_layout(
        legend_title_text="Trial IDs", title=plot_title, showlegend=show_legend
    )
    return fig

In [None]:
# Interactive version of the AHB overview above.
fig = make_plot_with_subplots(
    base_path=base_path,
    experiment_path=ahb_trial,
    metric_1="mse",
    metric_2="val_mse",
    plot_title="AHB Trials",
    shared_yaxes=False,
    skipfooter=0,
    show_legend=False,
)
fig.show()
# fig.write_html("figures/ahb_trials.html")

In [None]:
# HyperBand: per-trial training (mse) and validation (val_mse) curves; the
# last time_total_s value of every trial is collected in hpb_duration.
fig, ax = plt.subplots(1, 2, figsize=(30, 10))
hpb_duration = []
for trial in sorted(glob.glob(os.path.join(base_path, hpb_trial))):
    progress = pd.read_csv(os.path.join(trial, "progress.csv"))
    if progress.empty:
        # A trial without any reported row has neither a duration nor a curve.
        continue
    hpb_duration.append(progress.time_total_s.values[-1])
    ax[0].plot(progress.mse)
    ax[1].plot(progress.val_mse)
for panel, metric in zip(ax, ("mse", "val_mse")):
    panel.set_title(f"HyperBand trials: {metric}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel(metric)

In [None]:
# Interactive version of the HyperBand overview above (with trial legend).
fig = make_plot_with_subplots(
    base_path=base_path,
    experiment_path=hpb_trial,
    metric_1="mse",
    metric_2="val_mse",
    plot_title="Hyperband Trials",
    shared_yaxes=False,
    skipfooter=0,
    show_legend=True,
)
fig.show()
# fig.write_html("figures/hpb_trials.html")

In [None]:
# PBT: per-trial training (mse) and validation (val_mse) curves; the last
# time_total_s value of every trial is collected in pbt_duration.
fig, ax = plt.subplots(1, 2, figsize=(30, 10))
pbt_duration = []
for trial in sorted(glob.glob(os.path.join(base_path, pbt_trial))):
    # The last CSV row is dropped for PBT. skipfooter is only supported by the
    # python parser, so request it explicitly instead of relying on pandas'
    # fallback (which emits a ParserWarning).
    progress = pd.read_csv(
        os.path.join(trial, "progress.csv"), skipfooter=1, engine="python"
    )
    if progress.empty:
        # A trial without any remaining row has neither a duration nor a curve.
        continue
    pbt_duration.append(progress.time_total_s.values[-1])
    ax[0].plot(progress.mse)
    ax[1].plot(progress.val_mse)
for panel, metric in zip(ax, ("mse", "val_mse")):
    panel.set_title(f"PBT trials: {metric}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel(metric)

In [None]:
# Interactive version of the PBT overview above (last CSV row dropped).
fig = make_plot_with_subplots(
    base_path=base_path,
    experiment_path=pbt_trial,
    metric_1="mse",
    metric_2="val_mse",
    plot_title="PBT Trials",
    shared_yaxes=False,
    skipfooter=1,
    show_legend=False,
)
fig.show()
# fig.write_html("figures/pbt_trials.html")

In [None]:
# FIFO: per-trial training (mse) and validation (val_mse) curves; the last
# time_total_s value of every trial is collected in fio_duration.
fig, ax = plt.subplots(1, 2, figsize=(30, 10))
fio_duration = []
for trial in sorted(glob.glob(os.path.join(base_path, fio_trial))):
    # The last CSV row is dropped for FIFO. skipfooter is only supported by the
    # python parser, so request it explicitly instead of relying on pandas'
    # fallback (which emits a ParserWarning).
    progress = pd.read_csv(
        os.path.join(trial, "progress.csv"), skipfooter=1, engine="python"
    )
    if progress.empty:
        # A trial without any remaining row has neither a duration nor a curve.
        continue
    fio_duration.append(progress.time_total_s.values[-1])
    ax[0].plot(progress.mse)
    ax[1].plot(progress.val_mse)
for panel, metric in zip(ax, ("mse", "val_mse")):
    panel.set_title(f"FIFO trials: {metric}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel(metric)

In [None]:
# Interactive version of the FIFO overview above (last CSV row dropped).
fig = make_plot_with_subplots(
    base_path=base_path,
    experiment_path=fio_trial,
    metric_1="mse",
    metric_2="val_mse",
    plot_title="FIFO Trials",
    shared_yaxes=False,
    skipfooter=1,
    show_legend=False,
)
fig.show()
# fig.write_html("figures/fifo_trials.html")

# Graph Tuning Results for Paper

In [None]:
# FrankenGraph: validation MSE of every trial, one panel per scheduler, laid
# out as the 2x2 tuning-performance figure for the paper.
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(12.5, 10), sharey=True)

# (panel position, trial glob, panel label, drop the last CSV row?)
# FIFO and PBT progress files are read with skipfooter=1 throughout this
# notebook; AHB and HyperBand files are read in full.
panel_specs = [
    ((0, 0), fio_trial, "(a)", True),
    ((0, 1), hpb_trial, "(b)", False),
    ((1, 0), ahb_trial, "(c)", False),
    ((1, 1), pbt_trial, "(d)", True),
]
for (row, col), pattern, label, drop_last_row in panel_specs:
    panel = ax[row, col]
    # skipfooter is only supported by the python parser; request it explicitly
    # instead of letting pandas fall back with a ParserWarning.
    read_kwargs = {"skipfooter": 1, "engine": "python"} if drop_last_row else {}
    for trial in sorted(glob.glob(os.path.join(base_path, pattern))):
        progress = pd.read_csv(os.path.join(trial, "progress.csv"), **read_kwargs)
        panel.plot(progress.val_mse)
    panel.set_title(label, fontsize=30)
    panel.grid(which="major", axis="y")
    panel.tick_params(axis="both", which="major", labelsize=20)
    if col == 0:
        panel.set_ylabel("Loss", fontsize=30)
    if row == 1:
        panel.set_xlabel("Epoch", fontsize=30)

plt.tight_layout()
# Save before plt.show(). show() may close the figure (e.g. with the inline
# backend), so a plt.savefig placed after it would write a blank image.
# fig.savefig("figures/tuning_performance.png")
plt.show()

In [None]:
def make_plot(base_path, experiment_path, metric, plot_title, skipfooter=0):
    """Plot one metric of every trial of a Ray Tune experiment in a single figure.

    Parameters
    ----------
    base_path : str
        Directory containing the experiment folders.
    experiment_path : str
        Glob pattern, relative to ``base_path``, matching the trial directories.
        Each directory must contain a ``progress.csv``.
    metric : str
        ``progress.csv`` column plotted against the row index (epoch).
    plot_title : str
        Figure title.
    skipfooter : int, optional
        Trailing rows dropped from each ``progress.csv``. The default of 0 reads
        the whole file. The FIFO and PBT cells in this notebook use 1.

    Returns
    -------
    plotly.graph_objects.Figure
        One line trace per non-empty trial, named by its trial ID. The legend
        is hidden.
    """
    fig = go.Figure()
    # skipfooter requires the python parser; otherwise keep pandas' default.
    engine = "python" if skipfooter else None
    for trial_path in sorted(glob.glob(os.path.join(base_path, experiment_path))):
        progress = pd.read_csv(
            os.path.join(trial_path, "progress.csv"),
            skipfooter=skipfooter,
            engine=engine,
        )
        if progress.empty:
            # Nothing to draw, and trial_id.unique()[0] would raise IndexError.
            continue
        fig.add_trace(
            go.Scatter(
                x=progress.index,
                y=progress[metric],
                mode="lines",
                name=str(progress.trial_id.unique()[0]),
            )
        )
    fig.update_layout(
        legend_title_text="Trial IDs",
        title=plot_title,
        xaxis_title="Epoch",
        yaxis_title=metric,
        showlegend=False,
    )
    return fig

In [None]:
def add_trace(fig, base_path, experiment_path, metric, row, col, skipfooter=1):
    """Add one line per trial of a Ray Tune experiment to a subplot of ``fig``.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure created with ``make_subplots``. It is modified in place.
    base_path : str
        Directory containing the experiment folders.
    experiment_path : str
        Glob pattern, relative to ``base_path``, matching the trial directories.
        Each directory must contain a ``progress.csv``.
    metric : str
        ``progress.csv`` column plotted against the row index (epoch).
    row, col : int
        1-based subplot position.
    skipfooter : int, optional
        Trailing rows dropped from every ``progress.csv``. The default of 1
        matches the previously hard-coded value. Elsewhere in this notebook
        only the FIFO and PBT files drop their last row. Pass ``skipfooter=0``
        for AHB / HyperBand to keep their final epoch.

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` itself, to allow ``fig = add_trace(fig, ...)`` chaining.
    """
    # skipfooter requires the python parser; otherwise keep pandas' default.
    engine = "python" if skipfooter else None
    for trial in sorted(glob.glob(os.path.join(base_path, experiment_path))):
        progress = pd.read_csv(
            os.path.join(trial, "progress.csv"), skipfooter=skipfooter, engine=engine
        )
        # Traces are unnamed: the figures built with this helper hide the legend.
        fig.add_trace(
            go.Scatter(x=progress.index, y=progress[metric], mode="lines"),
            row=row,
            col=col,
        )
    return fig

In [None]:
# Interactive 2x2 validation-MSE figure, one panel per scheduler.
metric = "val_mse"
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("(a)", "(b)", "(c)", "(d)"),
    shared_yaxes=True,
    x_title="Epoch",
    y_title="Error",
    horizontal_spacing=0.05,
    vertical_spacing=0.05,
)

# Same panel layout as the matplotlib FrankenGraph: FIFO, HyperBand / AHB, PBT.
# NOTE(review): add_trace drops the last progress.csv row for every scheduler
# here, whereas the other plots only do that for FIFO and PBT. Confirm this is
# intended.
panel_experiments = {
    (1, 1): fio_trial,
    (1, 2): hpb_trial,
    (2, 1): ahb_trial,
    (2, 2): pbt_trial,
}
for (row, col), experiment in panel_experiments.items():
    fig = add_trace(fig, base_path, experiment, metric, row, col)

fig.update_layout(showlegend=False, height=1200, width=1200, font=dict(size=18))
# shared_yaxes hides tick labels outside the first column; show them everywhere.
fig.update_yaxes(showticklabels=True)
fig.update_annotations(font_size=22)

fig.show()
# fig.write_html("figures/val_mse_all_schedulers.html")

# Visualize Initial Training Results

In [None]:
# Training vs. validation MAE of the four fully-connected (top row) and four
# CNN (bottom row) model variants.
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(20, 10), sharey=True)
# sharey=True: this limit applies to every panel.
ax[0, 0].set_ylim(0, 0.5)

row_histories = [
    ("Fully-Connected", [ann_history1, ann_history2, ann_history3, ann_history4]),
    ("CNN", [cnn_history1, cnn_history2, cnn_history3, cnn_history4]),
]
for row, (row_label, histories) in enumerate(row_histories):
    for col, history in enumerate(histories):
        panel = ax[row, col]
        panel.plot(history["mae"], label="mae")
        panel.plot(history["val_mae"], label="val_mae")
        panel.grid(which="major", axis="y")
        panel.tick_params(axis="both", which="major", labelsize=20)
        if col == 0:
            panel.set_ylabel(row_label, fontsize=30)
        if row == 1:
            panel.set_xlabel("Epoch", fontsize=24)

# Without a legend the two curves in each panel cannot be told apart.
ax[0, 0].legend(fontsize=16)

plt.tight_layout()
# Save before plt.show(). show() may close the figure (e.g. with the inline
# backend), so a plt.savefig placed after it would write a blank image.
# fig.savefig("figures/training_performance.png")
plt.show()

In [None]:
def add_traces(
    fig,
    data,
    metric_1,
    metric_2,
    show_legend,
    row,
    col,
    color_1="Blue",
    color_2="Orange",
):
    """Add two metric curves from a training history to one subplot of ``fig``.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure created with ``make_subplots``. It is modified in place.
    data : mapping
        Metric name -> per-epoch values (anything ``pd.DataFrame.from_dict``
        accepts), presumably a Keras ``History.history`` dict. TODO confirm.
    metric_1, metric_2 : str
        Keys of ``data`` to plot. Each becomes a trace named after the key.
    show_legend : bool
        Whether these traces get legend entries.
    row, col : int
        1-based subplot position.
    color_1, color_2 : str, optional
        Colours of the two traces. Defaults keep the original Blue / Orange,
        mirroring the colour parameters of ``add_scatter_traces_to_final``.

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` itself.
    """
    df = pd.DataFrame.from_dict(data)
    for metric, color in ((metric_1, color_1), (metric_2, color_2)):
        fig.add_trace(
            go.Scatter(
                mode="lines",
                x=df.index,
                y=df[metric],
                # With mode="lines", plotly derives the line colour from the
                # scalar marker colour.
                marker=dict(color=color),
                name=metric,
                showlegend=show_legend,
            ),
            row=row,
            col=col,
        )
    return fig

In [None]:
# Interactive version of the ANN vs. CNN training-performance grid.
metric_1 = "mae"
metric_2 = "val_mae"
fig = make_subplots(
    rows=2,
    cols=4,
    shared_yaxes=True,
    x_title="Epoch",
    y_title="Error",
    row_titles=["Fully-Connected", "CNN"],
)

# Row 1: fully-connected variants, row 2: CNN variants. Only the top-left
# panel contributes legend entries.
model_rows = {
    1: [ann_history1, ann_history2, ann_history3, ann_history4],
    2: [cnn_history1, cnn_history2, cnn_history3, cnn_history4],
}
for row, histories in model_rows.items():
    for col, history in enumerate(histories, start=1):
        add_traces(fig, history, metric_1, metric_2, (row, col) == (1, 1), row, col)

fig.update_layout(showlegend=False, height=700, width=1200, font=dict(size=18))
# shared_yaxes hides tick labels outside the first column; show them everywhere.
fig.update_yaxes(showticklabels=True)
# The first-column axis of each row carries the shared range for that row.
for row in (1, 2):
    fig.update_yaxes(range=[0, 0.5], row=row, col=1)

fig.update_annotations(font_size=22)
fig.show()
# fig.write_html("figures/ann_vs_cnn_mae_and_val_mae.html")

# Graph Final Models

In [None]:
# Final models: training vs. validation MAE (row 0) and MSE (row 1) per
# scheduler, plus the R² score per target variable (row 2).
fig, ax = plt.subplots(nrows=3, ncols=4, figsize=(20, 15), sharey="row")

# One column per scheduler, in the order given by the panel titles:
# final_history1..4 and the *_scores belong to FIFO, HyperBand,
# Async. HyperBand and PBT.
scheduler_names = ["FIFO", "HyperBand", "Async. HyperBand", "PBT"]
final_histories = [final_history1, final_history2, final_history3, final_history4]
r2_scores = [fifo_scores, hpb_scores, ahb_scores, pbt_scores]
target_labels = ["P$_a$O$_2$", "P$_a$CO$_2$", "HCO$_3$", "pH"]

# Rows 0-1 show error metrics, so each row is labelled with its metric
# (MAE / MSE), like the row titles of the plotly version of this figure.
curve_rows = [("mae", "val_mae", "MAE"), ("mse", "val_mse", "MSE")]
for row, (metric, val_metric, row_label) in enumerate(curve_rows):
    # sharey="row": this limit applies to all four panels of the row.
    ax[row, 0].set_ylim([0, 0.4])
    ax[row, 0].set_ylabel(row_label, fontsize=30)
    for col, (name, history) in enumerate(zip(scheduler_names, final_histories)):
        panel = ax[row, col]
        panel.plot(history[metric], label=metric)
        panel.plot(history[val_metric], label=val_metric)
        panel.set_title(name, fontsize=30)
        panel.grid(which="major", axis="y")
        panel.set_xlabel("Epoch", fontsize=24)
        panel.tick_params(axis="both", which="major", labelsize=20)
    # Without a legend the training and validation curves cannot be told apart.
    ax[row, 0].legend(fontsize=16)

for col, scores in enumerate(r2_scores):
    panel = ax[2, col]
    panel.bar(target_labels, scores)
    panel.grid(which="major", axis="y")
    panel.set_ylim([0.75, 1.0])
    panel.tick_params(axis="both", which="major", labelsize=20)
ax[2, 0].set_ylabel("R$^2$ score", fontsize=30)

plt.tight_layout()
# Save before plt.show(). show() may close the figure (e.g. with the inline
# backend), so a plt.savefig placed after it would write a blank image.
# fig.savefig("figures/final_performance.png")
plt.show()

In [None]:
def add_scatter_traces_to_final(
    fig, data, metric_1, metric_2, show_legend, row, col, color1, color2
):
    """Add two metric curves from a training history to one subplot of ``fig``.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure created with ``make_subplots``. It is modified in place.
    data : mapping
        Metric name -> per-epoch values (anything ``pd.DataFrame.from_dict``
        accepts).
    metric_1, metric_2 : str
        Keys of ``data`` to plot. Each becomes a trace named after the key.
    show_legend : bool
        Whether these traces get legend entries.
    row, col : int
        1-based subplot position.
    color1, color2 : str
        Colours of the ``metric_1`` and ``metric_2`` traces.

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` itself, like the other trace helpers in this notebook
        (``add_traces``, ``add_trace``). Previously this returned ``None``.
    """
    df = pd.DataFrame.from_dict(data)
    for metric, color in ((metric_1, color1), (metric_2, color2)):
        fig.add_trace(
            go.Scatter(
                mode="lines",
                x=df.index,
                y=df[metric],
                # With mode="lines", plotly derives the line colour from the
                # scalar marker colour.
                marker=dict(color=color),
                name=metric,
                showlegend=show_legend,
            ),
            row=row,
            col=col,
        )
    return fig

In [None]:
def add_bar_traces_to_final(fig, data, bar_names, row, col, color):
    """Add a single-coloured bar chart to one subplot of ``fig``.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure created with ``make_subplots``. It is modified in place.
    data : sequence of float
        Bar heights, one per entry of ``bar_names``.
    bar_names : sequence of str
        Category labels on the x axis.
    row, col : int
        1-based subplot position.
    color : str
        Bar colour.

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` itself, consistent with the other trace helpers in this
        notebook. Previously this returned ``None``.
    """
    bars = go.Bar(x=bar_names, y=data, marker=dict(color=color), showlegend=False)
    fig.add_trace(bars, row=row, col=col)
    return fig

In [None]:
# Interactive final-model figure: MAE and MSE curves (rows 1-2) and R² scores
# (row 3), one column per scheduler.
metric_1 = "mae"
metric_2 = "val_mae"
metric_3 = "mse"
metric_4 = "val_mse"

color1 = "cornflowerblue"
color2 = "orange"
color3 = "violet"
color4 = "mediumseagreen"

fig = make_subplots(
    rows=3,
    cols=4,
    shared_yaxes=True,
    x_title="Epoch",
    row_titles=["MAE", "MSE", "R<sup>2</sup> Score"],
    column_titles=["FIFO", "HyperBand", "Async HyperBand", "PBT"],
)

# final_history1..4 follow the column order above. Only the first column
# contributes legend entries.
final_histories = [final_history1, final_history2, final_history3, final_history4]
for col, history in enumerate(final_histories, start=1):
    show_legend = col == 1
    add_scatter_traces_to_final(
        fig, history, metric_1, metric_2, show_legend, 1, col, color1, color2
    )
    add_scatter_traces_to_final(
        fig, history, metric_3, metric_4, show_legend, 2, col, color1, color2
    )

bar_names = [
    "P<sub>a</sub>O<sub>2</sub>",
    "P<sub>a</sub>CO<sub>2</sub>",
    "HCO<sub>3</sub>",
    "pH",
]
r2_scores = [fifo_scores, hpb_scores, ahb_scores, pbt_scores]
for col, scores in enumerate(r2_scores, start=1):
    add_bar_traces_to_final(fig, scores, bar_names, 3, col, color1)

fig.update_layout(showlegend=False, height=800, width=1200, font=dict(size=18))
# shared_yaxes hides tick labels outside the first column; show them everywhere.
fig.update_yaxes(showticklabels=True)
# The R² bars only differ near the top of the scale.
fig.update_yaxes(range=[0.75, 1], row=3)

# The bar row has no epoch axis, so move the shared "Epoch" x-title up under
# the MSE row. The annotation is selected by its text rather than by a
# hard-coded index into fig.layout.annotations.
fig.update_annotations(selector=dict(text="Epoch"), y=0.39)

fig.show()
os.makedirs("figures", exist_ok=True)
fig.write_html("figures/schedulers_and_metrics.html")

In [None]:
# Display the layout of the figure built above (axes, annotations, sizes).
fig["layout"]

In [None]:
# metric_1 = "mae"
# metric_2 = "val_mae"
# fig = make_subplots(
#     rows=2,
#     cols=4,
#     shared_yaxes=True,
#     x_title="Epoch",
#     y_title="Error",
#     row_titles=["Accuracy", "Accuracy" "CNN"],
# )

# add_traces(fig, ann_history1, metric_1, metric_2, True, 1, 1)
# add_traces(fig, ann_history2, metric_1, metric_2, False, 1, 2)
# add_traces(fig, ann_history3, metric_1, metric_2, False, 1, 3)
# add_traces(fig, ann_history4, metric_1, metric_2, False, 1, 4)

# add_traces(fig, cnn_history1, metric_1, metric_2, False, 2, 1)
# add_traces(fig, cnn_history2, metric_1, metric_2, False, 2, 2)
# add_traces(fig, cnn_history3, metric_1, metric_2, False, 2, 3)
# add_traces(fig, cnn_history4, metric_1, metric_2, False, 2, 4)

# fig.update_layout(
#     showlegend=False,
#     height=700,
#     width=1200,
#     font=dict(size=18),
#     yaxis_showticklabels=True,
#     yaxis2_showticklabels=True,
#     yaxis3_showticklabels=True,
#     yaxis4_showticklabels=True,
#     yaxis5_showticklabels=True,
#     yaxis6_showticklabels=True,
#     yaxis7_showticklabels=True,
#     yaxis8_showticklabels=True,
# )

# # fig.update_layout(yaxis=dict(range=[0,0.6]))
# fig.update_yaxes(
#     #     title_text="Fully-Connected",
#     range=[0, 0.5],
#     row=1,
#     col=1,
# )
# fig.update_yaxes(
#     #     title_text="CNN",
#     range=[0, 0.5],
#     row=2,
#     col=1,
# )

# fig.update_annotations(font_size=22)
# fig.show()
# fig.write_html("figures/ann_vs_cnn_mae_and_val_mae.html")