### Read data from file

In [None]:
%load_ext autoreload
%autoreload 2
import pickle
from pathlib import Path

# Resolve paths relative to the notebook's working directory: the raw results
# live in a ``data/`` folder inside the parent of the current directory.
root_module = Path.cwd()
ext = ".pkl"
file_name = "dmc_num_samples"
load_dir = root_module.parent.joinpath("data/")
file_dir = load_dir.joinpath(file_name + ext)
# Use a context manager so the file handle is closed once unpickling is done
# (the bare ``open(...)`` previously leaked it).
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open(file_dir, "rb") as data_file:
    raw_data = pickle.load(data_file)

### Process raw results

In [None]:
from dist_mbrl.utils.process_results import process_raw_results

# Aggregate the raw return data into summary statistics with the project
# helper: mean returns plus the half-widths later drawn as symmetric error bars.
metric_data = process_raw_results(raw_data)

# Pull out the fields that the plotting cell consumes, in a single unpack.
env_names, steps, mean_returns, ci_returns = (
    metric_data[field]
    for field in ("env_names", "steps", "mean_returns", "ci_returns")
)

### Plotting

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np

from dist_mbrl.utils.plot import (
    JMLR_PARAMS,
    handle_1D_axes_and_legend,
    plot_with_symmetric_intervals,
)

plt.rcParams.update(JMLR_PARAMS)

# Define a grid of subplots: one panel per environment, `ncols` panels per row.
ncols = 4
nrows = int(np.ceil(len(env_names) / ncols))
# Height is specified per row so that multi-row grids are not squeezed into the
# height of a single row (a one-row grid keeps the original 6.8 x 1.0 size).
row_height = 1.0
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(6.8, row_height * nrows),
    gridspec_kw={"wspace": 0.4, "hspace": 0.4},
)

# Reference axes by environment name, filling the grid in row-major order.
# Grid cells beyond len(env_names) get no environment.
ax_dict = dict(zip(env_names, axes.flatten()))

# Assign a colour to each agent type and a line style to each value of the
# second hyperparameter (one factor of the sample count) being compared.
cmap = plt.get_cmap("tab10")
colors = {"mean": cmap(0), "ofu": cmap(1)}
line_types = {"1": ":", "5": "--", "10": "-"}


# Legend-label formatting for the compared hyperparameter combinations.
def custom_process_label(params):
    """Build the legend label for one hyperparameter combination.

    ``params`` is indexed positionally: ``params[2]`` is the agent type, and
    ``int(params[0]) * int(params[1])`` is the number of sampled (s', a')
    pairs reported in the label. For example, ``("2", "5", "mean")`` yields
    ``"mean, \\#$(s',a')$={10}"``.
    """
    kind = params[2]
    pair_count = int(params[0]) * int(params[1])
    # Raw string keeps the LaTeX escape ``\#`` intact; braces wrap the count.
    count_text = r" \#$(s',a')$={" + str(pair_count) + "}"
    return f"{kind},{count_text}"


# Draw one curve (mean return with symmetric error band) per hyperparameter
# combination in each environment's panel. Each combination key selects its
# line style via hparams[1] and its colour via hparams[2]; the x-axis is
# converted from environment steps to episodes.
ep_length = 1000  # steps per episode -- presumably the DMC default; confirm
for env_name in env_names:
    env_means = mean_returns[env_name]
    for hparams in env_means:
        plot_with_symmetric_intervals(
            ax=ax_dict[env_name],
            x=steps[env_name][hparams] // ep_length,
            y=env_means[hparams],
            yerr=ci_returns[env_name][hparams],
            label=custom_process_label(hparams),
            ls=line_types[hparams[1]],
            title=env_name,
            color=colors[hparams[2]],
        )

# Permutation handed to the legend helper; presumably it indexes the curves in
# the order they were plotted above -- confirm against handle_1D_axes_and_legend.
order = [0, 1, 4, 5, 2, 3]
legend_kwargs = dict(legend_ncol=3, columnspacing=0.8, legend_offset=(2.4, -1.3))
handle_1D_axes_and_legend(axes=axes, order=order, **legend_kwargs)

### Save figures

In [None]:
# Save the figure into a ``figures/`` folder inside the parent of the notebook's
# working directory.
fig_dir = root_module.parent.joinpath("figures/dmc_num_samples.pdf")
# savefig does not create missing directories, so ensure the target exists.
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2024 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.