### Read data from file

In [None]:
import pickle
from pathlib import Path

# Raw results are read from <cwd>/../data/dmc_action_cost.pkl, i.e. a "data"
# directory that sits next to the notebook's working directory.
root_module = Path.cwd()
ext = ".pkl"
file_name = "dmc_action_cost"
load_dir = root_module.parent.joinpath("data/")
file_dir = load_dir.joinpath(file_name + ext)
# Use a context manager so the file handle is closed once loading finishes.
# NOTE: unpickling can execute arbitrary code - only load trusted result files.
with open(file_dir, "rb") as results_file:
    raw_data = pickle.load(results_file)

### Extract the metrics needed for plotting

In [None]:
from dist_mbrl.utils.process_results import process_raw_results

# Aggregate the raw per-run results into per-environment summary statistics
# (mean returns plus the `ci_returns` spread used for the shaded bands).
metric_data = process_raw_results(raw_data)

# Pull out the fields the plotting cell relies on, in a fixed order.
env_names, steps, mean_returns, ci_returns = (
    metric_data[field] for field in ("env_names", "steps", "mean_returns", "ci_returns")
)

### Plotting

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

from dist_mbrl.utils.plot import JMLR_PARAMS, LIGHT_GREY

plt.rcParams.update(JMLR_PARAMS)

# 2x3 grid: one row per environment, one column per action-cost setting.
fig, axes = plt.subplots(
    nrows=2, ncols=3, figsize=(6.0, 3.0), gridspec_kw={"wspace": 0.3, "hspace": 0.2}
)

# (environment name, action-cost string) -> subplot. The cost strings must
# match element 3 of the result keys used in the plotting loop.
ax_dict = {
    ("cartpole-swingup_sparse", "0.0"): axes[0, 0],
    ("cartpole-swingup_sparse", "0.001"): axes[0, 1],
    ("cartpole-swingup_sparse", "0.003"): axes[0, 2],
    ("pendulum-swingup", "0.0"): axes[1, 0],
    ("pendulum-swingup", "0.01"): axes[1, 1],
    ("pendulum-swingup", "0.03"): axes[1, 2],
}

# One tab10 colour per quantile level; "nan" marks runs that are labelled
# without a quantile suffix. The loop variable is named `level` rather than
# `type`, so the builtin is not shadowed for the rest of the notebook.
quantile_levels = ["50", "70", "90", "nan"]
cmap = plt.get_cmap("tab10")
colors = {level: cmap(i) for i, level in enumerate(quantile_levels)}

# Environment steps per episode; step counts are divided by this to get the
# "Episodes" x-axis.
ep_length = 1000


# Custom processing of labels
def custom_process_label(params):
    r"""Build a LaTeX ``\texttt`` legend label from a result key.

    ``params[1]`` is the quantile level as a string (e.g. ``"50"``) or
    ``"nan"``. ``params[2]`` is the agent type. A numeric level is appended
    as a fraction, e.g. ``(_, "50", "agent", ...)`` -> ``\texttt{agent-0.5}``.
    A ``"nan"`` level yields just ``\texttt{agent}``.
    """
    level = params[1]
    agent_type = params[2]
    name = agent_type if level == "nan" else f"{agent_type}-{int(level) / 100.0}"
    return rf"\texttt{{{name}}}"


# Draw each run's mean return curve with a shaded +/- CI band on the subplot
# selected by its environment and action cost (element 3 of the key). The
# colour is selected by quantile level (element 1 of the key).
for env_name in env_names:
    for idx in mean_returns[env_name]:
        run_label = custom_process_label(idx)
        ax = ax_dict[(env_name, idx[3])]
        episodes = steps[env_name][idx] // ep_length
        mean = mean_returns[env_name][idx]
        ci = ci_returns[env_name][idx]
        run_color = colors[idx[1]]

        ax.plot(
            episodes,
            mean,
            linestyle="-",
            linewidth=1.5,
            label=run_label,
            c=run_color,
        )
        ax.fill_between(episodes, mean - ci, mean + ci, alpha=0.2, color=run_color)
        ax.grid(color=LIGHT_GREY)

# Row labels: environment name above the "Return" axis label.
row_ylabels = (
    (axes[0, 0], r"\textbf{{cartpole-swingup}}" + "\n \n Return"),
    (axes[1, 0], r"\textbf{{pendulum}}" + "\n \n Return"),
)
for row_ax, ylabel in row_ylabels:
    row_ax.set_ylabel(ylabel)

# Column titles (top row only): action-cost multiplier for each column.
column_titles = (
    r"Action cost $0\times$",
    r"Action cost $1\times$",
    r"Action cost $3\times$",
)
for col_ax, title in zip(axes[0, :], column_titles):
    col_ax.set_title(title)

# The axes share an x-range visually, so only the bottom row shows tick labels.
for ax in axes[0, :]:
    ax.set_xticklabels([])

for ax in axes[1, :]:
    ax.set_xlabel("Episodes")
    ax.set_title("")

# Common y-range across all panels so returns are directly comparable.
for ax in axes.flat:
    ax.set_ylim([-50, 900])

axes[0, 0].legend(loc="lower center", ncol=4, bbox_to_anchor=(1.5, -2.0), frameon=False)

### Save figures

In [None]:
# Save to <cwd>/../figures/dmc_action_cost.pdf. Create the output directory
# first so savefig does not raise FileNotFoundError on a fresh checkout.
fig_dir = root_module.parent.joinpath("figures/dmc_action_cost.pdf")
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2024 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.