In [None]:
import os
import pickle
import warnings
from typing import List, Dict, Union, Tuple, Optional, Callable

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from rliable import library as rly
from rliable import metrics
from scipy.stats import mannwhitneyu, ttest_ind
from statsmodels.stats.multitest import multipletests

# NOTE(review): this suppresses every warning for the rest of the session,
# including deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Layer counts of interest.
# NOTE(review): not referenced by any cell visible in this section — confirm it is still needed.
layer_of_interest = [1, 2, 3, 5, 10, 15, 18]

# Fill colour per layer count, stored as an (R, G, B) tuple. The plotting helpers
# format these tuples into "rgb(...)" / "rgba(...)" strings.
COLORS = {
    1: (255, 127, 14), # orange #FF7F0E
    3: (44, 160, 44), # green #2CA02C
    18: (214, 39, 40), # red #D62728
    2: (148, 103, 189), # purple #9467BD
    5: (31, 119, 180), # blue #1F77B4
    10: (188, 189, 34), # lime green #BCBD22
    15: (23, 190, 207), # light blue #17BECF
    4: (140, 86, 75), # brown #8C564B
    6: (227, 119, 194), # pink #E377C2
    7: (127, 127, 127), # grey #7F7F7F
}

# Marker outline colour per layer count (used as `line_color` in `plot_perf_v_time`).
# Only the entries for 1, 3 and 18 differ from COLORS; all others are identical to it.
COLORS_DARKER = {
    1: (229, 118, 20), # darker orange #E57614
    3: (31, 138, 44), # darker green #1F8A2C
    18: (176, 32, 32), # darker red #B02020
    2: (148, 103, 189), # purple #9467BD (same as COLORS)
    5: (31, 119, 180), # blue #1F77B4 (same as COLORS)
    10: (188, 189, 34), # lime green #BCBD22 (same as COLORS)
    15: (23, 190, 207), # light blue #17BECF (same as COLORS)
    4: (140, 86, 75), # brown #8C564B (same as COLORS)
    6: (227, 119, 194), # pink #E377C2 (same as COLORS)
    7: (127, 127, 127), # grey #7F7F7F (same as COLORS)
}

# 18 ready-to-use "rgb(...)" colour strings.
# NOTE(review): not referenced by any cell visible in this section.
RAINBOW_COLORS = [
    "rgb(244, 67, 54)",
    "rgb(232, 30, 99)",
    "rgb(192, 35, 120)",
    "rgb(169, 37, 150)",
    "rgb(156, 39, 176)",
    "rgb(103, 58, 183)",
    "rgb(63, 81, 181)",
    "rgb(33, 150, 243)",
    "rgb(3, 169, 244)",
    "rgb(0, 188, 212)",
    "rgb(0, 150, 136)",
    "rgb(76, 175, 80)",
    "rgb(139, 195, 74)",
    "rgb(205, 220, 57)",
    "rgb(255, 235, 59)",
    "rgb(255, 193, 7)",
    "rgb(255, 152, 0)",
    "rgb(255, 87, 34)"
]

# Font sizes shared by the layout of every plotting helper below.
FONT_SIZE_TITLE = 24
FONT_SIZE_AXIS = 22
FONT_SIZE_TICKS = 20
FONT_SIZE_LEGEND = 20

# Load Data
The training data of this work can be viewed directly in [wandb](https://wandb.ai/elba/Master-Thesis/table?nw=zglvzydk8v9). To run this notebook please first download the data with the script in `data/download_results.sh`.

In [None]:
# Location of the exported training-run table.
train_run_data = os.path.join("data", "training-runs", "run_df.csv")

# Fail early with a hint on how to obtain the file when it is missing.
if not os.path.exists(train_run_data):
    raise FileNotFoundError(f"Ensure the training data is downloaded with the script in `data/download_results.sh`")

run_df = pd.read_csv(train_run_data)

In [None]:
# Preview the first rows to check that the run table loaded as expected.
run_df.head()

# Transform Data
The data for the `rliable` library needs to be in a specific shape to calculate the IQM and CI. The other analyses also require a specific format or some additional calculations.

In [None]:
# Environment id as stored in the run table -> name used for display and as key of `run_dict`.
canonical_names = dict([
    ("MiniGrid-MemoryS11-v0", "Memory"),
    ("psychlab_continuous_recognition", "Continuous Recognition"),
])

# Reverse lookup: display name -> environment id.
tech_names = dict(zip(canonical_names.values(), canonical_names.keys()))

In [None]:
# Reward curves per environment and setting, used for the IQM training curves.
def iqm(scores):
    """Return the inter-quartile mean of `scores[:, i]` for every index `i` of the last axis."""
    return np.array([metrics.aggregate_iqm(scores[:, col]) for col in range(scores.shape[-1])])


run_dict = {}
for env, display_name in canonical_names.items():
    env_runs = run_df.query("`env`==@env")
    curves_per_setting = {}
    for setting in env_runs["setting"].unique():
        # one column per value of `name`, one row per `global_step`
        reward_table = (
            run_df
            .query("`setting`== @setting & `env` == @env")
            .pivot(columns="name", values=["rollout/ep_rew_mean"], index="global_step")
        )
        # transposed, so each row holds the curve of one `name`
        curves_per_setting[setting] = reward_table.values.T
    run_dict[display_name] = curves_per_setting

In [None]:
# data for inspecting runtime vs. performance
#
# runtime_test[env] holds, per setting, the final reward of every run as a column
# vector, plus the bootstrapped IQM ("iqm_scores") and its confidence interval
# ("iqm_cis"). perf_vs_time[env][setting] condenses this into plain floats and adds
# the mean / standard deviation of the runtime measured on host 'celantur-delta'.
runtime_test = {}
perf_vs_time = {}
for env in run_df["env"].unique():
    env_curves = run_dict[canonical_names[env]]

    # last logged reward of each run, reshaped to (n_runs, 1) for `rliable`
    final_rewards = {}
    for n1 in sorted(env_curves):
        final_rewards[n1] = np.expand_dims(np.array(env_curves[n1][:, -1]), axis=1)

    iqm_scores, iqm_cis = rly.get_interval_estimates(
        final_rewards,
        iqm,
        reps=50
    )

    runtime_test[env] = {**final_rewards, "iqm_scores": iqm_scores, "iqm_cis": iqm_cis}

    perf_vs_time[env] = {}
    for se in iqm_scores:
        # query once and reuse the result for both runtime statistics
        runtimes = run_df.query("`setting` == @se & `host` == 'celantur-delta' & `env` == @env")["runtime"]
        perf_vs_time[env][se] = {
            "iqm": iqm_scores[se].item(),
            "ci_low": iqm_cis[se][0].item(),
            "ci_high": iqm_cis[se][1].item(),
            "runtime": runtimes.mean(),
            "std": runtimes.std(),
        }

In [None]:
# data for comparing update and rollout times
rollout_data = os.path.join("data", "rollout-times", "rollout_times.pkl")
update_data = os.path.join("data", "update-times", "update_times.pkl")

# NOTE(security): `pickle.load` can execute arbitrary code contained in the file.
# Only load the files fetched by `data/download_results.sh` from a trusted source.
if os.path.exists(rollout_data):
    with open(rollout_data, "rb") as f:
        rollout_times = pickle.load(f)
else:
    raise FileNotFoundError(f"Ensure the data is downloaded with the script in `data/download_results.sh`")

# A missing update-times file must fail here as well; otherwise `update_times`
# stays undefined and the notebook only breaks later with a NameError.
if os.path.exists(update_data):
    with open(update_data, "rb") as f:
        update_times = pickle.load(f)
else:
    raise FileNotFoundError(f"Ensure the data is downloaded with the script in `data/download_results.sh`")


In [None]:
# `iqm`, `runtime_test` and `perf_vs_time` are built once in the cell that starts
# with "# data for inspecting runtime vs. performance". Repeating that code here
# would re-run the bootstrap and replace the earlier results with a fresh random
# draw, so this cell only verifies that the summary is complete.
missing_envs = [
    env_id for env_id in run_df["env"].unique()
    if env_id not in perf_vs_time or env_id not in runtime_test
]
assert not missing_envs, f"No runtime/performance summary for: {missing_envs}"

# Evaluate Results
How do the different SHELM variants compare? In this section we visualize the results after training each model variant on 10 different random seeds. The results of our experiments can be summarized by the IQM+CI of the collected reward during training as well as a comparison between the performance and runtime of each model variant.

## IQM + CI
The IQM + CI is calculated with the `rliable` package and should give a robust estimate of how well an RL model performs.

In [None]:
def plot_iqm(
        x: List[float],
        y: Dict[str, Union[List[float], np.ndarray]],
        y_error: Dict[str, Tuple[List[float], List[float]]],
        title: str,
        key: Optional[Callable[[str], Tuple[int, float]]] = None
) -> go.Figure:
    """
    Plots the IQM (Inter Quartile Mean) + CI (Confidence Interval) over time.

    Every experiment contributes three traces: the IQM line (shown in the legend)
    and two invisible lines for the upper and lower CI bound; the area between
    the two bounds is filled with a transparent version of the line colour.

    Args:
        x (List[float]): The x-axis values.
        y (Dict[str, Union[List[float], np.ndarray]]): The y-axis values for each experiment:
            {
                name: y_values
            }
            `name` has to start with the number of layers followed by " | ",
            e.g. "<n_layer> | <lr>" or "<n_layer> | <lr> | Tanh". The layer
            number selects the colour from `COLORS` and is shown in the legend.
        y_error (Dict[str, Tuple[List[float], List[float]]]): The confidence bounds for each experiment:
            {
                name: (lower_values, upper_values)
            }
            The keys have to match the keys of `y`.
        title (str): The title of the plot.
        key (Optional[Callable[[str], Tuple[int, float]]], optional): Sort key applied to the
            experiment names, e.g. a function returning `(n_layer, lr)`. Defaults to None.

    Returns:
        go.Figure: The plot as a Plotly figure object.

    Raises:
        ValueError: If the leading field of a name is not an integer.
        KeyError: If the layer number has no entry in `COLORS` or a name is missing in `y_error`.

    Notes:
        - The x-axis title is "Number of Interaction Steps", the y-axis title "Accumulated Reward".
        - The names are sorted with `key` if provided, otherwise alphabetically.
    """

    fig = go.Figure()

    for name in sorted(y, key=key):
        # Only the leading field is needed; reading it by position (instead of
        # unpacking a fixed number of fields) accepts names with any number of
        # " | "-separated suffixes.
        n_layer = name.split(" | ")[0]
        rgb = str(COLORS[int(n_layer)])  # e.g. "(255, 127, 14)"

        fig.add_trace(
            go.Scatter(
                name=f"n layer: {n_layer}",
                x=x,
                y=y[name],
                mode="lines",
                line=dict(
                    color=f"rgb{rgb}"
                )
            )
        )

        # The upper bound has to be added directly before the lower bound:
        # fill="tonexty" on the lower trace fills towards the previous trace.
        fig.add_trace(
            go.Scatter(
                name=f"upper {name}",
                x=x,
                y=y_error[name][1],
                mode="lines",
                line=dict(
                    width=0
                ),
                showlegend=False
            )
        )

        fig.add_trace(
            go.Scatter(
                name=f"lower {name}",
                x=x,
                y=y_error[name][0],
                mode="lines",
                # drop the closing ")" of the tuple repr to append the alpha value
                fillcolor=f"rgba{rgb[:-1]},0.1)",
                fill="tonexty",
                line=dict(
                    width=0
                ),
                showlegend=False
            )
        )

    fig.update_layout(
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            title="Number of Interaction Steps",
            tickangle=-35,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            title="Accumulated Reward",
            tickangle=-35,
            ticksuffix="  ",
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            yanchor="bottom",
            y=0.05,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(0,0,0,0)",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=30,
            t=40,
            pad=10
        ),
    )

    return fig

### Memory Environment

In [None]:
# x-axis values: the `global_step` entries logged for the run named "spring-sweep-8"
mem_idx = run_df.loc[run_df["name"] == "spring-sweep-8", "global_step"].values

# bootstrapped IQM and confidence interval for every Memory setting
mem_iqm_scores, mem_iqm_cis = rly.get_interval_estimates(run_dict["Memory"], iqm, reps=50)

In [None]:
# Settings are ordered by their first two characters (as int) and then by their
# last six characters (as float).
world_combined_mem = plot_iqm(
    title="MiniGrid - Memory",
    x=mem_idx,
    y=mem_iqm_scores,
    y_error=mem_iqm_cis,
    key=lambda setting: (int(setting[:2]), float(setting[-6:])),
)
world_combined_mem.show()

### Continuous Recognition Environment

In [None]:
# x-axis values: the `global_step` entries logged for the run named "dandy-sweep-5"
cr_idx = run_df.loc[run_df["name"] == "dandy-sweep-5", "global_step"].values

# bootstrapped IQM and confidence interval for every Continuous Recognition setting
cr_iqm_scores, cr_iqm_cis = rly.get_interval_estimates(run_dict["Continuous Recognition"], iqm, reps=50)

In [None]:
# Settings are ordered by their first two characters (as int) and then by their
# last six characters (as float).
world_combined_psy = plot_iqm(
    title="PsychLab - Continuous Recognition",
    x=cr_idx,
    y=cr_iqm_scores,
    y_error=cr_iqm_cis,
    key=lambda setting: (int(setting[:2]), float(setting[-6:])),
)
world_combined_psy.show()

## Performance vs. Runtime
Because some models perform nearly identically, it is important to also highlight the differences in runtime.

In [None]:
def plot_perf_v_time(
        data: Dict[str, Dict[str, Dict[str, float]]],
        title: str
    ) -> go.Figure:
    """
    Plots performance vs time using the given data.

    Each experiment becomes one diamond marker at (runtime in hours, IQM) with a
    horizontal error bar of one runtime standard deviation and a vertical error
    bar spanning the confidence interval of the IQM.

    Args:
        data (Dict[str, Dict[str, Dict[str, float]]]): A dictionary containing the data to be plotted.
            The dictionary should have the following structure:
            {
                env_name: {
                    name: {
                        "iqm": float,
                        "runtime": float,
                        "ci_high": float,
                        "ci_low": float,
                        "std": float
                    }
                }
            }
            "runtime" and "std" are divided by 3600 before plotting.
        title (str): The title of the plot. It has to look like "<prefix> - <display name>";
            the part after " - " is translated with `tech_names` to select the
            environment entry of `data`.

    Returns:
        go.Figure: The plot as a Plotly figure object.

    Raises:
        IndexError: If `title` does not contain " - ".
        KeyError: If the display name is unknown to `tech_names` or the environment is missing in `data`.

    Notes:
        - The function sorts the names of the experiments based on the first two characters and the last six characters.
        - The function sets the y-axis range to [0, 1] if the title contains "Memory", and to [0, 47] otherwise.
    """

    env_name = title.split(" - ")[1]
    env = tech_names[env_name]
    env_data = data[env]

    fig = go.Figure()

    sorted_names = sorted(env_data, key=lambda x: (int(x[:2]), float(x[-6:])))

    # Layer number and both colours are needed for the legend entries as well as
    # for the data trace, so they are derived once per experiment.
    layers = [int(name.split(" | ")[0]) for name in sorted_names]
    fill_colors = [f"rgb{COLORS[layer]}" for layer in layers]
    edge_colors = [f"rgb{COLORS_DARKER[layer]}" for layer in layers]

    y = [env_data[name]["iqm"] for name in sorted_names]
    x = [env_data[name]["runtime"] / 3600 for name in sorted_names]
    array = [abs(env_data[name]["iqm"] - env_data[name]["ci_high"]) for name in sorted_names]
    arrayminus = [abs(env_data[name]["iqm"] - env_data[name]["ci_low"]) for name in sorted_names]
    array_x = [env_data[name]["std"] / 3600 for name in sorted_names]

    # Legend entries: one empty trace per experiment. The actual data is drawn
    # as a single trace below, which cannot provide per-point legend items.
    for n_layer, fill_color, edge_color in zip(layers, fill_colors, edge_colors):
        fig.add_trace(
            go.Scatter(
                name=f"n_layer: {n_layer}",
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(
                    size=11,
                    symbol="diamond",
                    line_width=2,
                    line_color=edge_color,
                    color=fill_color
                )
            )
        )

    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="markers",
            error_x=dict(
                array=array_x,
                symmetric=True,
                color="black",
                thickness=0.7
            ),
            error_y=dict(
                arrayminus=arrayminus,
                array=array,
                symmetric=False,
                color="black",
                thickness=0.7
            ),
            marker=dict(
                size=13,
                color=fill_colors,
                symbol="diamond",
                line_width=2,
                line_color=edge_colors,
            ),
            showlegend=False
        )
    )

    fig.update_layout(
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            title="Runtime [h]",
            tickangle=-35,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            title="Accumulated Reward",
            tickangle=-35,
            ticksuffix="  ",
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
            range=[0., 1.] if "Memory" in title else [0., 47.]
        ),
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(0,0,0,0)",
            traceorder="reversed",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=30,
            t=40,
            pad=10
        ),
    )

    return fig

In [None]:
def plot_times_box(
        data: Dict[str, List[float]],
        title: str,
        xaxis_title: str,
) -> go.Figure:
    """
    Draws one box per layer for the measured rollout or update times.

    Args:
        data (Dict[str, List[float]]): The measured times for each layer:
            {
                layer: times
            }
            `layer` is converted with `int()` to pick the colour from `COLORS`;
            `times` are placed on the x-axis.
        title (str): The title of the plot.
        xaxis_title (str): The title of the x-axis.

    Returns:
        go.Figure: The plot as a Plotly figure object.
    """

    fig = go.Figure()

    for layer in data:
        box = go.Box(
            name=f"n layer: {layer}",
            x=data[layer],
            marker_color=f"rgb{COLORS[int(layer)]}",
        )
        fig.add_trace(box)

    title_style = dict(
        text=title,
        font_size=FONT_SIZE_TITLE,
        xanchor="center",
        x=0.5,
        yanchor="top",
        y=0.98,
    )
    xaxis_style = dict(
        title=xaxis_title,
        tickangle=-35,
        tickfont_size=FONT_SIZE_TICKS,
        title_font_size=FONT_SIZE_AXIS,
    )
    yaxis_style = dict(
        ticksuffix="  ",
        tickfont_size=FONT_SIZE_TICKS,
        title_font_size=FONT_SIZE_AXIS,
    )
    margin_style = dict(l=10, r=20, b=30, t=50, pad=10)

    fig.update_layout(
        title=title_style,
        xaxis=xaxis_style,
        yaxis=yaxis_style,
        showlegend=False,
        height=700,
        width=900,
        margin=margin_style,
    )

    return fig

In [None]:
def plot_bar_and_error(
        data: Dict[str, Dict[str, float]],
        title: str,
        yaxis_title: str
) -> go.Figure:
    """
    Generates a bar plot with error bars, one bar per layer.

    The plotted quantity depends on `yaxis_title`: if it contains "Reward", the
    bars show the IQM with asymmetric error bars reaching to the confidence
    bounds; otherwise they show the runtime divided by 3600 with a symmetric
    error bar of one standard deviation (also divided by 3600).

    Parameters:
    - data (Dict[str, Dict[str, float]]): Maps an experiment name starting with the layer
      number (e.g. "<n_layer> | ...") to a dictionary with the following keys:
        - iqm (float): The value for the y-axis.
        - ci_low (float): The lower bound of the confidence interval.
        - ci_high (float): The upper bound of the confidence interval.
        - runtime (float): The runtime value.
        - std (float): The standard deviation value.
    - title (str): The title of the plot.
    - yaxis_title (str): The title of the y-axis.

    Returns:
    - fig (go.Figure): The generated plot as a Plotly figure object.
    """
    show_reward = "Reward" in yaxis_title

    x, y, y_error_neg, y_error_pos, colors = [], [], [], [], []

    # bars are ordered by ascending layer number
    for experiment in sorted(data, key=lambda label: int(label.split(" ")[0])):
        values = data[experiment]
        layer = int(experiment.split(' | ')[0])

        if show_reward:
            bar_height = values["iqm"]
            error_down = values["iqm"] - values["ci_low"]
            error_up = values["ci_high"] - values["iqm"]
        else:
            bar_height = values["runtime"] / 3600
            error_down = values["std"] / 3600
            error_up = values["std"] / 3600

        x.append(f"Layer {layer}")
        y.append(bar_height)
        y_error_pos.append(error_up)
        y_error_neg.append(error_down)
        colors.append(f"rgb{COLORS[layer]}")

    error_bars = dict(
        array=y_error_pos,
        arrayminus=y_error_neg if show_reward else None,
        symmetric=not show_reward
    )

    fig = go.Figure()
    fig.add_trace(go.Bar(x=x, y=y, error_y=error_bars, marker_color=colors))

    fig.update_layout(
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.98,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            title=yaxis_title,
            ticksuffix="  ",
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        showlegend=False,
        height=700,
        width=900,
        margin=dict(l=10, r=20, b=30, t=50, pad=10),
    )

    return fig

### Memory Environment

In [None]:
# IQM of the final reward against the runtime of every Memory setting
perf_v_runtime_memory = plot_perf_v_time(
    data=perf_vs_time,
    title="MiniGrid - Memory",
)
perf_v_runtime_memory.show()

In [None]:
# "Reward" in the y-axis title selects the IQM bars with confidence-interval error bars
reward_bar_mem = plot_bar_and_error(
    data=perf_vs_time["MiniGrid-MemoryS11-v0"],
    title="MiniGrid - Memory",
    yaxis_title="Accumulated Reward",
)
reward_bar_mem.show()

In [None]:
# a y-axis title without "Reward" selects the runtime bars with standard-deviation error bars
runtime_bar_mem = plot_bar_and_error(
    data=perf_vs_time["MiniGrid-MemoryS11-v0"],
    title="MiniGrid - Memory",
    yaxis_title="Runtime [h]",
)
runtime_bar_mem.show()

In [None]:
# rollout times per layer for the Memory environment
rollout_times_mem = plot_times_box(
    data=rollout_times[tech_names["Memory"]],
    title="MiniGrid - Memory",
    xaxis_title="Rollout time [sec]",
)
rollout_times_mem.show()

In [None]:
# update times per layer for the Memory environment
update_times_mem = plot_times_box(
    data=update_times[tech_names["Memory"]],
    title="MiniGrid - Memory",
    xaxis_title="Update time [sec]",
)
update_times_mem.show()

### Continuous Recognition Environment

In [None]:
# IQM of the final reward against the runtime of every Continuous Recognition setting
perf_v_runtime_psych = plot_perf_v_time(
    data=perf_vs_time,
    title="PsychLab - Continuous Recognition",
)
perf_v_runtime_psych.show()

In [None]:
# "Reward" in the y-axis title selects the IQM bars with confidence-interval error bars
reward_bar_psy = plot_bar_and_error(
    data=perf_vs_time["psychlab_continuous_recognition"],
    title="PsychLab - Continuous Recognition",
    yaxis_title="Accumulated Reward",
)
reward_bar_psy.show()

In [None]:
# a y-axis title without "Reward" selects the runtime bars with standard-deviation error bars
runtime_bar_psy = plot_bar_and_error(
    data=perf_vs_time["psychlab_continuous_recognition"],
    title="PsychLab - Continuous Recognition",
    yaxis_title="Runtime [h]",
)
runtime_bar_psy.show()

In [None]:
# rollout times per layer for the Continuous Recognition environment
rollout_times_psy = plot_times_box(
    data=rollout_times[tech_names["Continuous Recognition"]],
    title="PsychLab - Continuous Recognition",
    xaxis_title="Rollout time [sec]",
)
rollout_times_psy.show()

In [None]:
# update times per layer for the Continuous Recognition environment
update_times_psy = plot_times_box(
    data=update_times[tech_names["Continuous Recognition"]],
    title="PsychLab - Continuous Recognition",
    xaxis_title="Update time [sec]",
)
update_times_psy.show()

## Significance Test
A statistical test is used to find any significant differences in the model variant performances.

In [None]:
ALPHA = 0.05    # significance level used for highlighting
CORRECT = True  # apply the multiple-testing correction before displaying p-values

def highlight_sign(cell):
    """
    Returns the CSS style for one cell of a p-value table.

    Args:
        cell: The cell content, either a number or a string.

    Returns:
        str: A light-green highlight if `cell` is a number below `ALPHA`,
        otherwise an empty string. Strings (including `str` subclasses) and
        NaN are never highlighted.
    """
    # isinstance also covers str subclasses, which cannot be compared to a float
    if isinstance(cell, str):
        return ''
    if cell < ALPHA:
        return 'background: lightgreen; color: black'
    return ''

### Memory Environment

In [None]:
# One-sided Welch t-test on the final rewards of every ordered pair of Memory
# settings: is the reward of n1 greater than the reward of n2?
models = sorted(run_dict["Memory"], key=lambda x: int(x.split(" ")[0]))

sign = {}
for n1 in models:
    sign[n1] = []
    for n2 in models:
        if n1 == n2:
            # a setting is not compared with itself
            sign[n1].append(np.nan)
            continue
        result = ttest_ind(
            a=run_dict["Memory"][n1][:,-1],
            b=run_dict["Memory"][n2][:,-1],
            equal_var=False,
            alternative="greater"
        )
        sign[n1].append(result.pvalue)

# rows: n2, columns: n1 (transposed for display)
result_df = pd.DataFrame(data=sign, index=models)

if CORRECT:
    p_vals = result_df.values.flatten()
    # Correct only the comparisons that were actually made. The NaN placeholders
    # would otherwise be counted as tests and inflate every adjusted p-value.
    tested = ~np.isnan(p_vals)
    pps = np.full(p_vals.shape, np.nan)
    if tested.any():
        pps[tested] = multipletests(p_vals[tested], alpha=ALPHA, method='holm')[1]
    result_df = pd.DataFrame(
        pps.reshape(result_df.shape),  # table size follows the number of settings
        index=models,
        columns=models
    )

display(result_df.T.style.applymap(highlight_sign))

In [None]:
# One-sided Mann-Whitney U test on the final rewards of every ordered pair of
# Memory settings: is the reward of n1 greater than the reward of n2?
models = sorted(run_dict["Memory"], key=lambda x: int(x.split(" ")[0]))

sign = {}
for n1 in models:
    sign[n1] = []
    for n2 in models:
        if n1 == n2:
            # a setting is not compared with itself
            sign[n1].append(np.nan)
            continue
        result = mannwhitneyu(
            run_dict["Memory"][n1][:,-1],
            run_dict["Memory"][n2][:,-1],
            alternative="greater"
        )
        sign[n1].append(result.pvalue)

# rows: n2, columns: n1 (transposed for display)
result_df = pd.DataFrame(data=sign, index=models)

if CORRECT:
    p_vals = result_df.values.flatten()
    # Correct only the comparisons that were actually made. The NaN placeholders
    # would otherwise be counted as tests and inflate every adjusted p-value.
    tested = ~np.isnan(p_vals)
    pps = np.full(p_vals.shape, np.nan)
    if tested.any():
        pps[tested] = multipletests(p_vals[tested], alpha=ALPHA, method='holm')[1]
    result_df = pd.DataFrame(
        pps.reshape(result_df.shape),  # table size follows the number of settings
        index=models,
        columns=models
    )

display(result_df.T.style.applymap(highlight_sign))

### Continuous Recognition Environment

In [None]:
# One-sided Welch t-test on the final rewards of every ordered pair of
# Continuous Recognition settings: is the reward of n1 greater than the reward of n2?
models = sorted(run_dict["Continuous Recognition"], key=lambda x: int(x.split(" ")[0]))

sign = {}
for n1 in models:
    sign[n1] = []
    for n2 in models:
        if n1 == n2:
            # a setting is not compared with itself
            sign[n1].append(np.nan)
            continue
        result = ttest_ind(
            a=run_dict["Continuous Recognition"][n1][:,-1],
            b=run_dict["Continuous Recognition"][n2][:,-1],
            equal_var=False,
            alternative="greater"
        )
        sign[n1].append(result.pvalue)

# rows: n2, columns: n1 (transposed for display)
result_df = pd.DataFrame(data=sign, index=models)

if CORRECT:
    p_vals = result_df.values.flatten()
    # Correct only the comparisons that were actually made. The NaN placeholders
    # would otherwise be counted as tests and inflate every adjusted p-value.
    tested = ~np.isnan(p_vals)
    pps = np.full(p_vals.shape, np.nan)
    if tested.any():
        pps[tested] = multipletests(p_vals[tested], alpha=ALPHA, method='holm')[1]
    result_df = pd.DataFrame(
        pps.reshape(result_df.shape),  # table size follows the number of settings
        index=models,
        columns=models
    )

display(result_df.T.style.applymap(highlight_sign))

In [None]:
# One-sided Mann-Whitney U test on the final rewards of every ordered pair of
# Continuous Recognition settings: is the reward of n1 greater than the reward of n2?
models = sorted(run_dict["Continuous Recognition"], key=lambda x: int(x.split(" ")[0]))

sign = {}
for n1 in models:
    sign[n1] = []
    for n2 in models:
        if n1 == n2:
            # a setting is not compared with itself
            sign[n1].append(np.nan)
            continue
        result = mannwhitneyu(
            run_dict["Continuous Recognition"][n1][:,-1],
            run_dict["Continuous Recognition"][n2][:,-1],
            alternative="greater"
        )
        sign[n1].append(result.pvalue)

# rows: n2, columns: n1 (transposed for display)
result_df = pd.DataFrame(data=sign, index=models)

if CORRECT:
    p_vals = result_df.values.flatten()
    # Correct only the comparisons that were actually made. The NaN placeholders
    # would otherwise be counted as tests and inflate every adjusted p-value.
    tested = ~np.isnan(p_vals)
    pps = np.full(p_vals.shape, np.nan)
    if tested.any():
        pps[tested] = multipletests(p_vals[tested], alpha=ALPHA, method='holm')[1]
    result_df = pd.DataFrame(
        pps.reshape(result_df.shape),  # table size follows the number of settings
        index=models,
        columns=models
    )

display(result_df.T.style.applymap(highlight_sign))