### Load Pickle files from all environments

In [None]:
import pickle
from pathlib import Path

# Locate the result files relative to the current working directory: they are read from
# the `data/` folder that sits next to (not inside) the working directory.
root_module = Path.cwd()
ext = ".pkl"
env_names = ["HalfCheetah", "Walker2D", "Ant"]
load_dir = root_module.parent.joinpath("data/")

# Map each environment name to the object stored in `<load_dir>/<env_name>.pkl`.
env_data = {}
for env_name in env_names:
    file_dir = load_dir.joinpath(env_name + ext)
    # NOTE(review): unpickling can execute arbitrary code; only load result files that
    # come from a trusted source.
    # The context manager closes the file handle as soon as the pickle has been read.
    with open(file_dir, "rb") as pickle_file:
        env_data[env_name] = pickle.load(pickle_file)

### Smooth Returns and Compute Mean + Standard Error (used as the confidence band)

In [None]:
import numpy as np
from scipy.stats import sem

# Width (in data points) of the moving-average window passed to `rolling_average`.
WINDOW_SIZE = 10
# Keys used to look up each method's entries in the per-environment data dictionaries.
all_types = ["exact_ube_3", "pombu", "ensemble", "none", "sac"]

def rolling_average(w: int, arr: np.ndarray) -> np.ndarray:
    """
    Smooth every column of a 2-D array with a centered moving average of width ``w``.

    The array is laid out as (num_points, num_seeds); each seed's curve (one column) is
    averaged independently along the first axis. Both ends are padded by repeating the
    edge values, so the result has the same shape as the input. An empty array is
    returned unchanged.
    """
    if arr.size == 0:
        return arr

    # Uniform averaging kernel: w equal weights that sum to one.
    kernel = np.ones(w) / w

    # Split the w - 1 padding rows around the data; when w is even the extra row goes at the end.
    pad_before = (w - 1) // 2
    pad_after = int(np.ceil((w - 1) / 2))
    padded = np.pad(arr, [(pad_before, pad_after), (0, 0)], mode="edge")

    def smooth_column(column):
        # 'valid' keeps only the positions where the kernel fully overlaps the padded column.
        return np.convolve(column, kernel, mode="valid")

    return np.apply_along_axis(smooth_column, axis=0, arr=padded)

# For every environment and method: smooth the return curves, then aggregate over the
# last axis into a mean curve and a spread curve, stored back under new keys of `env_data`.
for env_name in env_names:
    env_data[env_name]["mean_returns"] = {}
    env_data[env_name]["ci_returns"] = {}
    # The loop variable is named `agent_type` rather than `type` so the builtin `type`
    # is not shadowed in the notebook's global namespace.
    for agent_type in all_types:
        returns = env_data[env_name]["returns"][agent_type]
        smoothened_returns = rolling_average(WINDOW_SIZE, returns)
        env_data[env_name]["mean_returns"][agent_type] = np.mean(smoothened_returns, axis=-1)
        # NOTE: despite the "ci" key, `sem` yields the standard error of the mean, not a
        # confidence interval at a specific level.
        env_data[env_name]["ci_returns"][agent_type] = sem(smoothened_returns, axis=-1)

### Plot results

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ube_mbrl.utils.plot import PARAMS

# Apply the project's plotting parameters and turn off offset notation on axis tick labels.
plt.rcParams.update(PARAMS)
plt.rcParams['axes.formatter.useoffset'] = False

# One subplot per environment, arranged in a single row.
fig, axes = plt.subplots(
    nrows=1,
    ncols=len(env_names),
    figsize=(7.5, 1.5),
    gridspec_kw={'wspace': 0.15, 'hspace': 0.2},
)

# Look up each environment's axes by environment name.
ax_dict = dict(zip(env_names, axes.flatten()))


# Assign each method the "tab10" color at its position in `all_types`, so a method keeps
# the same color in every subplot. A comprehension avoids a loop variable named `type`,
# which would shadow the builtin.
cmap = plt.get_cmap("tab10")
colors = {agent_type: cmap(idx) for idx, agent_type in enumerate(all_types)}

# Divisor that converts environment steps into episodes for the x-axis.
ep_length = 1000

# Legend labels for methods whose display name differs from their key in `all_types`;
# every other method is shown under its own key.
legend_labels = {
    "exact_ube_3": r"\texttt{exact-ube}" + " (ours)",
    "ensemble": r"\texttt{ensemble-var}",
    "none": r"\texttt{ensemble-mean}",
}

# Draw, for every environment, one mean curve per method with a shaded band of
# mean +/- the stored spread around it.
for env_name in env_names:
    env_ax = ax_dict[env_name]
    env_ax.set_title(env_name)
    # Methods are plotted in reverse order of `all_types`, so the first-listed method is
    # plotted last. `agent_type` is used instead of `type` to avoid shadowing the builtin.
    for agent_type in reversed(all_types):
        label = legend_labels.get(agent_type, fr"\texttt{{{agent_type}}}")

        steps = env_data[env_name]["steps"][agent_type]
        mean_returns = env_data[env_name]["mean_returns"][agent_type]
        ci_returns = env_data[env_name]["ci_returns"][agent_type]

        # Shared x-coordinates (in episodes) for both the curve and its band.
        episodes = steps // ep_length

        env_ax.plot(
            episodes,
            mean_returns,
            linestyle="-",
            linewidth=2,
            label=label,
            c=colors[agent_type],
        )

        env_ax.fill_between(
            episodes,
            mean_returns + ci_returns,
            mean_returns - ci_returns,
            alpha=0.2,
            color=colors[agent_type],
        )

# Only the first subplot carries the y-axis label; the second gets explicit y ticks.
first_ax = axes[0]
first_ax.set_ylabel(r"Return ($\times 10^3$)")
axes[1].set_yticks([0e3, 1e3])

# Build a single legend from the first subplot's entries. The entries are reversed
# because the curves were plotted in reverse order of `all_types`.
handles, labels = first_ax.get_legend_handles_labels()
first_ax.legend(
    list(reversed(handles)),
    list(reversed(labels)),
    loc='lower center',
    columnspacing=0.8,
    ncol=5,
    bbox_to_anchor=(1.65, -0.6),
    frameon=False,
)

# Common x-axis label and scientific notation for the y tick labels on every subplot.
for ax in axes:
    ax.set_xlabel("Episode")
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

In [None]:
# Save the figure as a PDF in the `figures/` folder next to the current working directory.
fig_dir = root_module.parent.joinpath("figures/pybullet_benchmark.pdf")
# Create the output folder if it is missing so `savefig` does not fail on a fresh checkout.
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2023 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.