### Load pickle file

In [None]:
import pickle
from pathlib import Path

# Locate the pickled experiment results in ../data, relative to the notebook's working directory.
root_module = Path.cwd()
ext = ".pkl"
name = "deep_sea_ube_ablation"
load_dir = root_module.parent.joinpath("data/")
file_dir = load_dir.joinpath(name + ext)
# NOTE: pickle.load can execute arbitrary code; only load result files you produced or trust.
# The `with` block guarantees the file handle is closed once loading finishes.
with open(file_dir, "rb") as results_file:
    data = pickle.load(results_file)

### Get metrics from raw data

In [None]:
import numpy as np
from scipy.stats import sem

# Summarise the raw arrays into per-key statistics used by the plots below.
# Every key of data["total_regret"] must also be present in data["solve_time"].
total_regret = data["total_regret"]
# Mean and standard error of the mean over the last axis; note the SEM is stored under "ci_regret".
data["mean_regret"] = {k: np.mean(values, axis=-1) for k, values in total_regret.items()}
data["ci_regret"] = {k: sem(values, axis=-1) for k, values in total_regret.items()}
# Mean solve time over all entries, skipping NaN (a NaN marks a run that did not solve the task).
data["mean_solve_time"] = {k: np.nanmean(data["solve_time"][k]) for k in total_regret}

### Learning Time and Final Total Regret Plots

In [None]:
%matplotlib widget
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from collections import defaultdict
from itertools import product
from ube_mbrl.utils.plot import PARAMS

plt.rcParams.update(PARAMS)

# Keys of data["total_regret"] are (agent_type, deep_sea_size) pairs, where agent_type is an
# (agent_name, uq_method) tuple and deep_sea_size is a string such as "10".
regret_keys = list(data["total_regret"].keys())

# Unique agent types, sorted by uq_method in descending order. This fixes the plotting order
# (and therefore the colours) without assuming a particular number of agent types.
agent_types = sorted(
    {tuple(agent_type) for agent_type, _ in regret_keys},
    key=lambda agent_type: agent_type[1],
    reverse=True,
)

# Sizes are stored as strings, so sort them numerically: a plain string sort would put "100"
# before "20" and scramble the x order of the solve-time line plot.
deep_sea_sizes = sorted({size for _, size in regret_keys}, key=int)

# Left panel: solve time vs DeepSea size; right panel: final total regret as grouped bars.
fig_ds, axes = plt.subplots(nrows=1, ncols=2, figsize=(7.5, 1.5), gridspec_kw={'wspace': 0.25, 'hspace': 0.5})
wd = 2  # width of each agent's bar in the total-regret panel (x-axis units)
sizes_to_plot = [10, 20, 30, 40]
# Per-agent bar offsets that centre each group of len(agent_types) bars on its size.
# For four agents this gives [-3, -1, 1, 3] (with wd=2); for three it gives [-2, 0, 2].
n_agents = len(agent_types)
pos = [(i - (n_agents - 1) / 2) * wd for i in range(n_agents)]
ticks = set()  # sizes that receive bars in the right panel
patches = []  # one colour patch per agent

# One tab10 colour per agent type, shared by both panels.
cmap = plt.get_cmap("tab10")
colors = {agent_type: cmap(i) for i, agent_type in enumerate(agent_types)}

# Gather (DeepSea size as int, mean solve time as float) points for every agent,
# following the order of deep_sea_sizes.
solve_time = defaultdict(list)
for agent_type in agent_types:
    agent_key = tuple(agent_type)
    for ds_size in deep_sea_sizes:
        mean_time = data["mean_solve_time"][(agent_key, ds_size)]
        solve_time[agent_key].append((int(ds_size), float(mean_time)))

def _agent_label(agent_name, uq_method):
    """Return the LaTeX legend label for an (agent_name, uq_method) agent type.

    The labels are wrapped in LaTeX texttt, which presumes PARAMS enables text.usetex
    (TODO confirm). Underscores in the generic fallback label are escaped because LaTeX
    rejects a bare underscore in text mode.
    """
    if agent_name == "psrl":
        return r"\texttt{" + agent_name + "}"
    if uq_method == "exact_ube_3":
        return r"\texttt{ofu-exact-ube}" + " (ours)"
    if uq_method == "ensemble":
        return r"\texttt{ofu-ensemble-var}"
    fallback = f"{agent_name}-{uq_method}".replace("_", r"\_")
    return r"\texttt{" + fallback + "}"


for i, agent in enumerate(agent_types):
    label = _agent_label(*agent)
    # Left panel: mean solve time vs DeepSea size, one line per agent.
    sizes, times = zip(*solve_time[agent])
    axes[0].plot(sizes, times, marker='o', linewidth=2, label=label, c=colors[agent])

    patches.append(mpatches.Patch(color=colors[agent], label=label))
    # Right panel: last entry of the mean total regret per size, with the SEM (stored under
    # "ci_regret") as error bar; pos[i] shifts this agent's bar within each size group.
    for size in sizes_to_plot:
        ticks.add(size)
        idx = (agent, str(size))
        axes[1].bar(
            size + pos[i],
            data["mean_regret"][idx][-1],
            yerr=data["ci_regret"][idx][-1],
            width=wd,
            color=colors[agent],
        )

# Single shared legend built from the solve-time lines, placed below the axes (y=-0.7).
handles, labels = axes[0].get_legend_handles_labels()
axes[0].legend(handles, labels,
    loc='lower center', ncol=len(axes[0].lines), bbox_to_anchor=(1.1, -0.7), frameon=False
)
# "\\%" produces the same backslash-percent string for LaTeX as before, without relying on the
# invalid "\%" escape sequence (a SyntaxWarning on Python 3.12+).
axes[0].set_ylabel("Episodes until \n $< 90\\%$ failure (log)")
axes[0].set_xlabel("DeepSea size")
axes[0].set_xticks([10, 20, 30, 40])
axes[0].set_yscale('log')
axes[0].set_ylim(top=1000)
axes[0].minorticks_off()
axes[1].set_ylabel("Total regret")
axes[1].set_xlabel("DeepSea size")
# One tick under each bar group: the sizes collected while drawing the bars.
axes[1].set_xticks(sorted(ticks))
axes[1].ticklabel_format(axis='y', style='sci', scilimits=(0, 0))


### Save figures

In [None]:
# Save figures
import os
from pathlib import Path
# Write the combined figure to ../figures relative to the notebook's working directory.
root_module = Path.cwd()
fig_ds_dir = root_module.parent.joinpath("figures/deep_sea_ube_ablation_solve_time_and_total_regret.pdf")
# savefig does not create missing directories, so make sure the target folder exists first.
fig_ds_dir.parent.mkdir(parents=True, exist_ok=True)
fig_ds.savefig(fig_ds_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2023 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.