In [None]:
from pathlib import Path

import numpy as np
import torch
from torch.functional import cartesian_prod

from dist_mbrl.agents import EQRSAC
from dist_mbrl.utils.common import PathType

### Load agents

In [None]:
# Training epochs whose checkpoints are loaded and visualized; each entry becomes
# one figure column, titled "<epoch>K steps" in the plotting cell.
epochs_to_plot = [4, 5, 10, 11, 12]


def load_agent(model_dir: PathType, fname: str, device: str = "cpu") -> EQRSAC:
    """Load a serialized agent checkpoint from disk.

    Args:
        model_dir: Directory holding the checkpoint; may be a ``str`` or a ``Path``.
        fname: File name of the checkpoint inside ``model_dir``.
        device: Device string handed to ``torch.load`` as ``map_location``.

    Returns:
        The object stored in the checkpoint file (annotated as an ``EQRSAC`` agent).
    """
    # Coerce to Path so that plain-string directories also support the `/` operator.
    checkpoint_path = Path(model_dir) / fname
    # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
    # only load checkpoints that come from a trusted source.
    return torch.load(checkpoint_path, map_location=torch.device(device))


# Checkpoints live in a data folder next to the notebook's working directory.
root_module = Path.cwd()
checkpoints_dir = root_module.parent / "data/trained_mountaincar_checkpoints"

# Load one agent per requested epoch, keyed by epoch number.
prefix, ext = "agent", "pth"
agents = {
    epoch: load_agent(checkpoints_dir, f"{prefix}_{epoch}.{ext}")
    for epoch in epochs_to_plot
}

### For each agent, collect a few trajectories to overlay with value plot

In [None]:
import dist_mbrl.utils.common as utils_common

# Environment and RNG setup for the rollouts.
env_params = {"name": "MountainCarContinuous-v0", "reward_scale": 0.5}
env, eval_env = utils_common.get_env(env_params)
seed = 42
device = "cuda" if torch.cuda.is_available() else "cpu"
rng, generator = utils_common.fix_rng(env, eval_env=eval_env, seed=seed, device=device)

# low == high: the reset options carry a single value
# (presumably this pins the initial position; confirm against the env's reset()).
reset_bounds = {"low": -0.5, "high": -0.5}
trajectories_per_agent = 10
steps_per_epoch = 1000  # upper bound on the number of steps in one rollout
trajectories = {}  # (epoch, rollout index) -> array of visited observations
empirical_values = {}  # epoch -> mean discounted return over all rollouts

for episode in epochs_to_plot:
    # Point the agent and its actor at the rollout device through their `device` attributes.
    agents[episode].device = device
    agents[episode].actor.device = device
    # Accumulate returns across ALL rollouts of this agent. Resetting this list inside
    # the rollout loop would make the mean below cover only the last rollout.
    returns = []
    for i in range(trajectories_per_agent):
        obs, _ = env.reset(options=reset_bounds)
        terminated = False
        truncated = False
        episode_return = 0
        step = 0
        trajectory = []
        # Stop on truncation as well: an env must be reset, not stepped, once an
        # episode has been truncated.
        while not (terminated or truncated or step >= steps_per_epoch):
            action = agents[episode].act(obs)
            trajectory.append(obs)
            next_obs, reward, terminated, truncated, info = env.step(action)
            # Discounted return, using the agent's own discount factor.
            episode_return += (agents[episode].gamma ** step) * reward
            obs = next_obs
            step += 1

        # Record the final observation so the trajectory includes the end state.
        trajectory.append(obs)
        returns.append(episode_return)
        trajectories[(episode, i)] = np.array(trajectory)
    empirical_values[episode] = np.mean(returns)

### Build inputs for the agent's networks

In [None]:
# Resolution and extent of the (position, velocity) evaluation grid.
# NOTE(review): the limits presumably follow MountainCar's observation range; confirm.
num_pos, num_vels = 100, 100
min_pos, max_pos = -1.2, 0.5
max_vel = 0.07  # the velocity axis is symmetric around zero

positions = np.linspace(min_pos, max_pos, num=num_pos)
vels = np.linspace(-max_vel, max_vel, num=num_vels)


def build_function_inputs(disc_pos, disc_vels):
    """Return every (position, velocity) combination of the two discretized axes.

    Both arguments are converted with ``torch.Tensor`` and combined through
    ``cartesian_prod``, giving one row per pair with the position varying slowest.
    """
    grid_axes = (torch.Tensor(disc_pos), torch.Tensor(disc_vels))
    return cartesian_prod(*grid_axes)


# Inputs to be fed into the distributional critic: one row per (position, velocity)
# grid point, i.e. num_pos * num_vels rows in total.
inputs = build_function_inputs(positions, vels)

### Extract data from agents

In [None]:
# The plots need a state-value function, but the critic is a state-action value
# function. The state value is approximated by evaluating the critic at the actor's
# mean action, so what gets visualized are the Q-values of the mean actions.
# NOTE(review): `inputs` is never moved to `device` while `init_state` is, and
# `.numpy()` needs a CPU tensor; this cell looks like it assumes device == "cpu". Confirm.
mean_values, std_values, quantiles_init_state = {}, {}, {}

# Initial state from the evaluation env, as a float batch of size one on `device`.
init_state, _ = eval_env.reset(options=reset_bounds)
init_state = torch.from_numpy(init_state).float().to(device).unsqueeze(dim=0)

for epoch in epochs_to_plot:
    agent = agents[epoch]

    # Critic output on the whole grid, summarized over its last axis
    # (presumably the quantile axis).
    mean_actions, log_std = agent.actor.forward(inputs)
    quantiles = agent.get_min_q(inputs, mean_actions, agent.critic)
    mean_values[epoch] = quantiles.mean(dim=-1)
    std_values[epoch] = quantiles.std(dim=-1)

    # also get quantiles at initial state
    greedy_action = agent.act(init_state, sample=False)
    act = torch.from_numpy(greedy_action).float().to(device)
    init_quantiles = agent.critic(init_state, act)
    quantiles_init_state[epoch] = init_quantiles.detach().numpy().squeeze()

### Plot

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import scipy.stats as stats
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.pyplot import cm

from dist_mbrl.utils.plot import JMLR_PARAMS

# Apply the project's shared plotting defaults (JMLR_PARAMS) to matplotlib's global
# rcParams, so they affect every figure created from here on.
plt.rcParams.update(JMLR_PARAMS)


# Helper functions
def fix_white_lines(cnt):
    """Give contour patches an edge color equal to their face color.

    Works with both flavors of the object returned by ``contourf``: newer
    Matplotlib versions, where the contour set is itself a collection exposing
    ``set_edgecolor``, and older ones, where the individual patches are listed
    in ``cnt.collections``.

    Args:
        cnt: Contour set returned by ``Axes.contourf``.

    Returns:
        The same ``cnt`` object, so the call can wrap a ``contourf`` expression.
    """
    if hasattr(cnt, "set_edgecolor"):
        # Matplotlib >= 3.8: style the contour set directly. Its `.collections`
        # attribute is deprecated there and removed in 3.10, so avoid touching it.
        cnt.set_edgecolor("face")
    else:
        # Older Matplotlib: style each per-level collection.
        for c in cnt.collections:
            c.set_edgecolor("face")
    return cnt


def plot_function_contour(
    ax: Axes,
    values: torch.Tensor,
    colormap: Colormap = cm.coolwarm,
    minmax_values: tuple = None,
    show_colorbar: bool = True,
):
    """Draw a filled contour map of grid values on ``ax``.

    Args:
        ax: Axes to draw on.
        values: Tensor that is reshaped to ``(num_pos, num_vels)`` before plotting.
        colormap: Colormap handed to ``contourf``.
        minmax_values: Optional ``(min, max)`` pair fixing the range of the contour
            levels; when falsy, the range is taken from ``values`` itself.
        show_colorbar: When False, the colorbar is removed again right after creation.
    """
    num_levels = 100
    # Transpose so rows follow velocity and columns follow position, as in the meshgrid.
    grid = values.reshape(num_pos, num_vels).detach().numpy().T
    pos_mesh, vel_mesh = np.meshgrid(positions, vels)

    if minmax_values:
        lo, hi = minmax_values[0], minmax_values[1]
    else:
        lo, hi = np.min(grid), np.max(grid)

    filled = ax.contourf(
        pos_mesh,
        vel_mesh,
        grid,
        levels=np.linspace(lo, hi, num_levels),
        cmap=colormap,
    )
    fix_white_lines(filled)

    # Colorbar ticks at both ends of the range and at its midpoint.
    colorbar = plt.colorbar(
        filled,
        ax=ax,
        format=lambda x, _: f"{x:.1f}",
        ticks=[lo, np.mean([lo, hi]), hi],
    )
    if not show_colorbar:
        colorbar.remove()


fig = plt.figure(figsize=(6.3, 3.5))
ncols = len(epochs_to_plot)

# Alternate way of plotting on grid to have independent control over axis padding, width, etc.
# Two GridSpecs are used: the top row takes its cells from `gs_top`, the two bottom
# rows from `gs_bottom`. Width ratios are derived from `ncols` so the layout keeps
# working when the number of plotted epochs changes.
gs_top = plt.GridSpec(
    nrows=3,
    ncols=ncols + 1,
    hspace=1.0,
    wspace=0.2,
    width_ratios=[1] * ncols + [0.1],  # trailing narrow column is left empty
)
gs_bottom = plt.GridSpec(
    nrows=3,
    ncols=ncols,
    hspace=0.15,
    wspace=0.09,
    # Last column is wider (presumably to leave room for the colorbar drawn there).
    width_ratios=[1] * (ncols - 1) + [1.2],
)
axes = np.empty((3, ncols), dtype=object)
for i in range(ncols):
    axes[0, i] = fig.add_subplot(gs_top[0, i])
    axes[1, i] = fig.add_subplot(gs_bottom[1, i])
    axes[2, i] = fig.add_subplot(gs_bottom[2, i])

# Row aliases: value distribution at the initial state, mean values, std of values.
quantiles_axes = axes[0, :]
mean_values_axes = axes[1, :]
std_values_axes = axes[2, :]

# Global extrema of the mean and std maps across all epochs, so that every
# column can share one color scale (and hence one colorbar).
def _value_range(tensors):
    """Return (smallest minimum, largest maximum) over a collection of tensors."""
    lows = [t.min().item() for t in tensors]
    highs = [t.max().item() for t in tensors]
    return np.min(lows), np.max(highs)


minmax_mean_value = _value_range(mean_values.values())
min_mean_value, max_mean_value = minmax_mean_value

minmax_std_value = _value_range(std_values.values())
min_std_value, max_std_value = minmax_std_value

# Mean map (middle row) and std map (bottom row) for every epoch; only the last
# column keeps its colorbar, which is valid for all columns thanks to the shared range.
last_column = len(epochs_to_plot) - 1
for i, episode in enumerate(epochs_to_plot):
    show_color_bar = i == last_column
    mean_ax, std_ax = mean_values_axes[i], std_values_axes[i]
    plot_function_contour(
        mean_ax,
        mean_values[episode],
        minmax_values=minmax_mean_value,
        show_colorbar=show_color_bar,
    )
    plot_function_contour(
        std_ax,
        std_values[episode],
        colormap=cm.viridis,
        minmax_values=minmax_std_value,
        show_colorbar=show_color_bar,
    )

# Overlay each agent's collected rollouts (first vs. second observation component)
# on top of its mean-value map
for ax, episode in zip(mean_values_axes, epochs_to_plot):
    for j in range(trajectories_per_agent):
        traj = trajectories[(episode, j)]
        ax.plot(traj[:, 0], traj[:, 1], c="k", alpha=0.2)

# Top row: kernel density estimate of the critic's quantiles at the initial state,
# with the empirical discounted return marked as a dashed vertical line
x = np.linspace(-100, 100, 5000)
for i, episode in enumerate(epochs_to_plot):
    ax = quantiles_axes[i]
    density = stats.gaussian_kde(quantiles_init_state[episode])
    ax.plot(x, density.pdf(x), color="tab:red", lw=2.0)
    ax.axvline(empirical_values[episode], c="g", ls="--")

# Top row: density plots. Only the first column gets a y-label; y ticks are hidden
# everywhere, and the x-range treatment differs between the first two columns and the rest.
axes[0, 0].set_ylabel("Prob.\ndensity")
for col, (epoch, ax) in enumerate(zip(epochs_to_plot, quantiles_axes)):
    ax.set_xlabel(r"$V(s_0)$", labelpad=-2)
    ax.yaxis.set_ticklabels([])
    ax.yaxis.set_ticks([])
    if col < 2:
        ax.set_xticks([-0, 100])
    else:
        ax.set_xlim(right=50, left=-10)
    ax.set_title(rf"\textbf{{{epoch}K steps}}")

# Two bottom rows: contour maps. The y-label appears only on the first column
# and y tick labels are hidden on all other columns.
for row_axes in (mean_values_axes, std_values_axes):
    row_axes[0].set_ylabel(r"$\dot{x}$ [m/s]")
    for ax in row_axes[1:]:
        ax.yaxis.set_ticklabels([])

# The middle row hides its x tick labels; the bottom row carries the x-label.
for ax in mean_values_axes:
    ax.xaxis.set_ticklabels([])

for ax in std_values_axes:
    ax.set_xlabel(r"$x$ [m]")

### Save figures

In [None]:
# Destination of the exported figure. Despite its name, `fig_dir` is the path of
# the PDF file itself, not of a directory.
fig_dir = root_module.parent.joinpath("figures/mountaincar_value_viz.pdf")
# savefig does not create missing folders, so make sure the target directory exists.
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2024 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.