In [None]:
import numpy as np
import scipy.stats as stats
from gymnasium.wrappers.compatibility import EnvCompatibility
from rlberry.envs.benchmarks.grid_exploration.nroom import NRoom
from rlberry.seeding import Seeder, safe_reseed

import dist_mbrl.tabular.util as util
from dist_mbrl.mbrl.util import ReplayBuffer
from dist_mbrl.tabular.agents import TabularAgent
from dist_mbrl.tabular.config import base_default_parameters

### Setup Env

In [None]:
# Shared default hyper-parameters. This binds the imported object itself (no copy),
# so any in-place change to `parameters` is also visible through
# `base_default_parameters`.
parameters = base_default_parameters
# 3-room grid world with 5x5 rooms; success_probability=1.0 is passed to NRoom
# (presumably making transitions deterministic -- confirm against rlberry docs).
# EnvCompatibility wraps the rlberry environment for the gymnasium API.
env = EnvCompatibility(
    NRoom(
        nrooms=3,
        room_size=5,
        success_probability=1.0,
        initial_state_distribution="center",
    )
)
# Shapes handed to ReplayBuffer for a single observation / action entry.
obs_shape = (1,)
act_shape = (1,)

# Fix RNG
seed = 42
rng = util.fix_rng(seed=seed)
# Reseed the wrapped rlberry environment (`env.env`) with the same seed.
safe_reseed(env.env, Seeder(seed))

# Task horizon
# Used both as the per-episode replay buffer capacity and as the step limit
# of the rollout loop.
steps_per_episode = 20

# Tabular agent sized from the discrete observation and action spaces.
agent = TabularAgent(
    env.observation_space.n, env.action_space.n, params=parameters["agent"]
)

### Load pre-trained optimal policy

In [None]:
import pickle
from pathlib import Path

# Locate ../data/nroom_policy.pkl relative to the notebook's working directory.
root_module = Path.cwd()
ext = ".pkl"
name = "nroom_policy"
file_dir = root_module.parent.joinpath("data/" + name + ext)

# NOTE: unpickling can execute arbitrary code -- only load policy files you trust.
# The context manager guarantees the file handle is closed, which the previous
# `pickle.load(open(...))` form did not.
with open(file_dir, "rb") as policy_file:
    data = pickle.load(policy_file)
optimal_policy = data["opt_policy"]

### Collect Env data and train model

In [None]:
from copy import deepcopy

# Number of initial episodes run with a random policy before switching to the
# pre-trained optimal policy (0 disables warm-up).
warmup_episodes = 0
num_episodes = 100
# 1-based post-warm-up episode counts after which a snapshot of the agent is kept.
episode_to_record = [1, 10, 100]
reward_scale = 1
agents = []
# Discounted return of every episode, warm-up episodes first.
ep_returns = np.zeros(warmup_episodes + num_episodes)
agent.pi = optimal_policy
for i in range(warmup_episodes + num_episodes):
    if i < warmup_episodes:
        # NOTE(review): this draws from NumPy's global RNG rather than `rng`;
        # whether it is seeded depends on what util.fix_rng does -- confirm.
        random_actions = np.random.choice(agent.num_actions, agent.num_states)
        agent.pi = agent.actions_to_policy_matrix(random_actions)
    else:
        agent.pi = optimal_policy

    # Fresh buffer per episode so the posterior update only sees new transitions.
    ep_buffer = ReplayBuffer(
        steps_per_episode,
        obs_shape,
        act_shape,
        obs_type=np.int32,
        action_type=np.int32,
        rng=rng,
    )
    obs, _ = env.reset()
    # Overwritten every episode; later cells index value arrays with the state
    # observed at the start of the last episode.
    initial_state = obs
    ep_step = 0
    terminated = False
    while (not terminated) and (ep_step < steps_per_episode):
        action = agent.act(obs)
        # Anything returned after `terminated` (truncation flag, info) is ignored.
        next_obs, reward, terminated, *_ = env.step(action)
        # Receiving a reward of 1 ends the episode regardless of the env's flag.
        if reward == 1:
            terminated = True
        ep_buffer.add(obs, action, next_obs, reward_scale * reward, terminated)
        ep_returns[i] += (agent.gamma**ep_step) * (reward_scale * reward)
        obs = next_obs
        ep_step += 1

    # Update agent's MDP posterior
    agent.update_posterior_mdp(ep_buffer)

    # Record agent
    if (i - warmup_episodes + 1) in episode_to_record:
        agents.append(deepcopy(agent))

# Get empirical estimate of value at initial state.
# Only episodes run with the optimal policy count: warm-up episodes use a random
# policy and would bias the estimate. With warmup_episodes == 0 this is the mean
# over all episodes, exactly as before.
opt_value = np.mean(ep_returns[warmup_episodes:])

### Ground-truth value distribution

In [None]:
from typing import Tuple

# Number of models requested from each agent's posterior when building the
# ensemble of value functions in this cell.
num_samples = 1000


def compute_value_ensemble(
    num_samples, agent
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample models from the agent's posterior and evaluate the agent's policy on each.

    Args:
        num_samples: number of models requested from
            ``agent.sample_ensemble_from_posterior``.
        agent: object exposing ``sample_ensemble_from_posterior``, ``pi`` and
            ``gamma``.

    Returns:
        ``(v_ensemble, p_ensemble, r_ensemble)``: the V-functions stacked along a
        new leading axis (one entry per sampled model), followed by the sampled
        transition and reward ensembles they were computed from.
    """
    p_ensemble, r_ensemble = agent.sample_ensemble_from_posterior(num_samples)
    num_models = p_ensemble.shape[0]
    vfs = []
    for model_idx in range(num_models):
        # solve_bellman_eq returns (V, Q); only V is part of this function's result.
        v_model, _ = util.solve_bellman_eq(
            p_ensemble[model_idx], r_ensemble[model_idx], agent.pi, agent.gamma
        )
        vfs.append(v_model)
    v_ensemble = np.stack(vfs, axis=0)
    return v_ensemble, p_ensemble, r_ensemble


# One ensemble of sampled value functions per recorded agent snapshot.
# Only the first element (v_ensemble) of compute_value_ensemble's result is kept.
# A comprehension keeps the loop variable out of the notebook namespace, so the
# trained `agent` is no longer silently rebound to a snapshot.
vf = [
    compute_value_ensemble(num_samples, recorded_agent)[0]
    for recorded_agent in agents
]

### Value distribution via Quantile regression

In [None]:
# Number of quantile atoms used to represent the value distribution.
NUM_QUANTILES = 100
# Number of posterior models sampled per gradient step in dist_value_iteration.
ENSEMBLE_SIZE = 1
# Size of the second axis of the first recorded value ensemble; used as the
# state dimension of the quantile parameters.
NUM_STATES = vf[0].shape[1]
# Quantile midpoints (2k + 1) / (2 * NUM_QUANTILES) for k = 0 .. NUM_QUANTILES - 1.
tau = (2 * np.arange(NUM_QUANTILES) + 1) / (2.0 * NUM_QUANTILES)


def get_quantiles(values, num_quantiles):
    """Select `num_quantiles` order statistics of `values` at midpoint levels.

    The sorted samples are assigned the levels ``(2k + 1) / (2n)`` and the
    requested levels are ``(2k + 1) / (2 * num_quantiles)``. For every requested
    level, the sample whose level is nearest (the first one on ties) is returned.
    """
    ordered = np.sort(values)
    n_values = len(ordered)
    sample_levels = (2 * np.arange(n_values) + 1) / (2.0 * n_values)
    wanted_levels = (2 * np.arange(num_quantiles) + 1) / (2.0 * num_quantiles)

    def nearest_sample(level):
        # argmin returns the first index among equally close samples.
        return np.abs(sample_levels - level).argmin()

    picks = list(map(nearest_sample, wanted_levels))
    return ordered[picks]


def dist_value_iteration(
    agent: "TabularAgent",
    num_quantiles,
    max_iter=int(1e4),
    gamma=0.99,
    lr=5e-1,
    epsilon=1e-8,
    ensemble_size=None,
):
    """
    Fit per-state quantiles of the value distribution by stochastic quantile regression.

    At every step a small ensemble of models is sampled from the agent's posterior,
    the current quantiles are pushed through the Bellman backup under ``agent.pi``,
    and the quantiles are moved along ``tau - 1{target < theta}`` averaged over the
    sampled models and the target atoms.

    Args:
        agent: object exposing ``pi`` (indexed as [state, action]) and
            ``sample_ensemble_from_posterior(num_samples=...)`` returning ``(p, r)``
            indexed as [model, state, action, next_state] and [model, state, action].
        num_quantiles: number of quantile atoms per state. The quantile levels are
            the midpoints ``(2k + 1) / (2 * num_quantiles)``.
        max_iter: maximum number of gradient steps.
        gamma: discount factor used in the backup.
        lr: step size of the quantile update.
        epsilon: stop once no quantile moved by more than this amount in a step.
        ensemble_size: models sampled per step; ``None`` falls back to the
            module-level ``ENSEMBLE_SIZE``.

    Returns:
        ``(theta, theta_list)``: ``theta`` is the last accepted iterate of shape
        (num_quantiles, num_states); ``theta_list`` holds the initial guess followed
        by the candidate produced at every step (including a final candidate that
        triggered the convergence stop).
    """
    if ensemble_size is None:
        ensemble_size = ENSEMBLE_SIZE
    # Derive sizes and levels from the arguments instead of notebook globals so the
    # function stays correct for any `num_quantiles` / agent.
    num_states = np.shape(agent.pi)[0]
    quantile_levels = (2 * np.arange(num_quantiles) + 1) / (2.0 * num_quantiles)

    # Random init guess of the value distribution function
    # Force terminal state (taken to be the last state index) to have a value of zero
    theta_i = 1 * np.sort(np.random.rand(num_quantiles, num_states), axis=0)
    theta_i[:, -1] = 0.0

    theta_list = [theta_i]

    for _ in range(max_iter):
        # First: sample models from the posterior to estimate the gradient
        p, r = agent.sample_ensemble_from_posterior(num_samples=ensemble_size)
        r_pi = np.einsum("eij, ij -> ei", r, agent.pi)
        p_pi = np.einsum("eijk, ij -> eik", p, agent.pi)

        # Second: Compute the gradient of the quantile regression loss.
        # Bellman targets, shape (model, target atom, state).
        theta_j = np.expand_dims(r_pi, axis=1) + gamma * np.einsum(
            "eij, mj -> emi", p_pi, theta_i
        )
        # Broadcast to (model, current atom, target atom, state).
        theta_j = np.expand_dims(theta_j, axis=1)
        current = np.expand_dims(theta_i, axis=(0, 2))
        indicator_fn = (theta_j - current < 0).astype(float)
        grad_loss = np.expand_dims(quantile_levels, axis=(0, -2, -1)) - indicator_fn
        # Average over sampled models and target atoms.
        grad_loss = np.mean(grad_loss, axis=(0, 2))

        # Update the params by taking a step in the direction of the gradient
        new_theta_i = theta_i + lr * grad_loss
        new_theta_i[:, -1] = 0.0
        theta_list.append(new_theta_i)
        if np.any(np.abs(theta_i - new_theta_i) > epsilon):
            theta_i = new_theta_i
        else:
            break

    return theta_i, theta_list


# Per recorded agent: reference quantiles of the sampled values at the initial
# state, the full history of fitted quantiles, and the per-step error between them.
true_quantiles = []
pred_quantiles = []
wass_ipm = []
# Iterating with a dedicated loop variable avoids rebinding the notebook-global
# `agent` to a snapshot.
for recorded_agent, v_samples in zip(agents, vf):
    target_quantiles = get_quantiles(v_samples[:, initial_state], NUM_QUANTILES)
    true_quantiles.append(target_quantiles)

    # dist_value_iteration returns (final_theta, theta_list); keep the whole history.
    theta_history = dist_value_iteration(
        recorded_agent, NUM_QUANTILES, gamma=recorded_agent.gamma
    )[-1]
    pred_quantiles.append(theta_history)

    # Stack to (gradient step, quantile, state) and compare at the initial state:
    # mean absolute difference between matching quantiles, one value per step.
    theta_history = np.array(theta_history)
    qr_error = (1 / NUM_QUANTILES) * np.sum(
        np.abs(target_quantiles - theta_history[:, :, initial_state]), axis=1
    )
    wass_ipm.append(qr_error)

# Final error for each recorded agent.
for qr_error in wass_ipm:
    print(qr_error[-1])

### Plot

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

from dist_mbrl.utils.plot import JMLR_PARAMS

# Apply the project's JMLR plotting parameters to matplotlib's global rcParams.
plt.rcParams.update(JMLR_PARAMS)

fig = plt.figure(figsize=(6.5, 3.5))

# Two 3x3 grid specs over the same figure: the top one supplies rows 0 and 1 with
# no vertical gap, the bottom one supplies row 2 with its own spacing. This gives
# independent control over padding between the stacked rows and the last row.
gs_top = plt.GridSpec(nrows=3, ncols=3, hspace=0.0, wspace=0.3, bottom=0.13)
gs_bottom = plt.GridSpec(nrows=3, ncols=3, hspace=0.6, wspace=0.3)
row_grids = (gs_top, gs_top, gs_bottom)

axes = np.empty((3, 3), dtype=object)
for col in range(3):
    # Create the column's axes top to bottom, each row from its own grid spec.
    for row, grid in enumerate(row_grids):
        axes[row, col] = fig.add_subplot(grid[row, col])

cmap = plt.get_cmap("tab10")

# Evaluation grid for the density plots; identical for every column, so it is
# built once outside the loop.
x = np.linspace(-15, 15, 5000)

for i in range(len(agents)):
    # History of fitted quantiles, indexed as (gradient step, quantile, state).
    theta = np.array(pred_quantiles[i])
    v_samples = vf[i][:, initial_state]

    # First plot: ground-truth value distribution vs. the fitted quantiles,
    # both shown as Gaussian kernel density estimates.
    kernel_true = stats.gaussian_kde(v_samples)
    axes[0, i].plot(
        x,
        kernel_true.pdf(x),
        color="tab:blue",
        lw=2.0,
        ls="--",
        dashes=(5, 1),
        label=r"$\mu^{\pi^\star}(s_0)$",
    )
    kernel_pred = stats.gaussian_kde(theta[-1, :, initial_state])
    axes[0, i].plot(
        x, kernel_pred.pdf(x), color="tab:orange", lw=2.0, label=r"$\mu_q(s_0)$"
    )
    axes[0, i].axvline(opt_value, c="g", ls=":", label=r"$v^{\pi^\star, p}(s_0)$")
    # Renders as \textbf{Episode <n>}; the previous format string left the
    # \textbf group without its closing brace.
    axes[0, i].set_title(rf"\textbf{{Episode {episode_to_record[i]}}}")
    axes[0, i].set_xticks([-15, 15])
    axes[0, i].set_xlim([-15, 15])
    axes[0, i].set_xticklabels([])

    # Second plot: ground-truth CDF versus our method's estimate
    # Quantiles are averaged over the last N recorded iterates (N = 1: final only).
    N = 1
    mean_quantiles = np.mean(np.sort(theta[-N:, :, initial_state]), axis=0)
    axes[1, i].plot(
        np.sort(v_samples),
        np.linspace(0, 1, len(v_samples), endpoint=False),
        ls="--",
        lw=2.0,
        dashes=(5, 1),
        c="tab:blue",
    )
    axes[1, i].plot(mean_quantiles, tau, c="tab:orange", lw=2.0)
    axes[1, i].axvline(opt_value, c="g", ls=":")
    axes[1, i].set_xlabel(r"$V^{\pi^\star}(s_0)$", labelpad=-10)
    axes[1, i].set_xlim([-15, 15])
    axes[1, i].set_ylim(top=1.2)
    axes[1, i].set_xticks([-15, 15])

    # Third plot: wasserstein distance between our prediction and the true projection
    # NOTE(review): the x axis is normalised to [0, 1] over the recorded steps, so
    # the "x 10^4" label is only exact when all 1e4 gradient steps were run.
    x_grad = np.linspace(0, 1, theta.shape[0])
    axes[2, i].plot(x_grad, wass_ipm[i], c="tab:red", ls="-", lw=1.0)
    axes[2, i].ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    axes[2, i].set_ylim(top=2, bottom=0)
    axes[2, i].set_xlabel(r"Gradient steps $(\times 10^4)$")

# Single shared legend, attached to the top-left axes and pushed below the grid
# through the anchor offset.
legend_options = dict(
    loc="lower center", ncol=3, bbox_to_anchor=(1.8, -3.0), frameon=False
)
axes[0, 0].legend(**legend_options)

# Row labels go on the left-most column only.
axes[0, 0].set_ylabel("Prob.\ndensity")
axes[1, 0].set_ylabel("Cumulative\nProb.")
axes[2, 0].set_ylabel("QR error")

plt.show()

### Save figures

In [None]:
# Full path of the exported figure (despite the name, this is the PDF file itself).
fig_dir = root_module.parent.joinpath("figures/tabular_nroom.pdf")
# savefig does not create missing directories, so make sure `figures/` exists.
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2024 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.