In [None]:
import random

import numpy as np
import scipy.stats as stats

from dist_mbrl.envs.toy_mdp import ToyMDP1

# Seed every RNG the notebook draws from: Python's `random` module and NumPy's
# legacy global generator (scipy.stats `.rvs` without `random_state` and
# `np.random.choice` both use the latter).
seed = 0
random.seed(seed)
np.random.seed(seed)

### Functions to build posterior MDPs and compute values

In [None]:
# Posterior over toy MDPs: one ToyMDP1 per sampled alpha, all sharing the same beta
def build_posterior(alphas, beta):
    """Instantiate one ``ToyMDP1(alpha, beta)`` per sampled ``alpha``.

    Args:
        alphas: iterable of sampled alpha values.
        beta: beta value shared by every model.

    Returns:
        ``(p, r)``: NumPy arrays stacking each model's ``p`` and ``r`` attributes
        along a new leading (model) axis, in the order of ``alphas``.
    """
    # Generator keeps construction and attribute reads interleaved per model,
    # exactly as a plain loop would.
    models = (ToyMDP1(alpha, beta) for alpha in alphas)
    pairs = [(model.p, model.r) for model in models]
    transitions = np.array([pair[0] for pair in pairs])
    rewards = np.array([pair[1] for pair in pairs])
    return transitions, rewards


def compute_values(p_ensemble, r_ensemble, discount=0.99):
    """Exact state values for every model in an ensemble.

    Solves the Bellman equation ``(I - discount * P) v = r`` for each model with a
    single batched linear solve. This is numerically preferable to forming the
    inverse explicitly.

    Args:
        p_ensemble: transition matrices, shape ``(num_models, S, S)``.
        r_ensemble: per-state rewards, shape ``(num_models, S)``.
        discount: discount factor.

    Returns:
        Array of shape ``(num_models, S)`` with the value of every state per model.
    """
    p_ensemble = np.asarray(p_ensemble)
    r_ensemble = np.asarray(r_ensemble)
    num_states = p_ensemble.shape[-1]
    lhs = np.eye(num_states) - discount * p_ensemble
    # Add a trailing axis so `solve` treats each reward vector as a column
    # (required for batched right-hand sides since NumPy 2.0), then drop it.
    return np.linalg.solve(lhs, r_ensemble[..., np.newaxis])[..., 0]

### Example posteriors to test

In [None]:
beta = 0.0
num_samples = 5000


def _clipped_normal(loc, scale, size):
    """Draw ``size`` samples from N(loc, scale**2), clipped to [0, 1]."""
    return np.clip(
        stats.norm.rvs(loc=loc, scale=scale, size=size), a_min=0.0, a_max=1.0
    )


def _two_mode_mixture(first, second, size):
    """Equal-weight mixture of two clipped Gaussians given as ``(loc, scale)`` pairs.

    The two halves sum to exactly ``size`` samples; the first half takes the extra
    sample when ``size`` is odd. The first component is drawn before the second,
    so the global RNG is consumed in a fixed order.
    """
    first_size = size - size // 2
    samples_first = _clipped_normal(*first, size=first_size)
    samples_second = _clipped_normal(*second, size=size // 2)
    return np.concatenate((samples_first, samples_second), axis=0)


# List elements are evaluated left to right, so the draw order is fixed by the seed.
alphas = [
    # First: a single Gaussian, N(0.5, 0.1^2)
    _clipped_normal(0.5, 0.1, num_samples),
    # Second: multimodal, with two well-separated narrow modes
    _two_mode_mixture((0.3, 0.03), (0.6, 0.05), num_samples),
    # Third: "heavy-tailed", a narrow mode plus a much wider component
    _two_mode_mixture((0.3, 0.03), (0.5, 0.15), num_samples),
]

# Sampled MDPs and their exact values for each posterior over alpha
p_all, r_all, vf_true_all = [], [], []
for alpha in alphas:
    p, r = build_posterior(alpha, beta)
    p_all.append(p)
    r_all.append(r)
    vf_true_all.append(compute_values(p, r))

#### Value-Distributional Value Iteration

In [None]:
NUM_QUANTILES = 10
ENSEMBLE_SIZE = 1
# Number of states, read explicitly from the first posterior's transition tensor
# rather than from whatever `p` a previous loop happened to leave behind.
NUM_STATES = p_all[0].shape[1]
# Quantile midpoints tau_j = (2j + 1) / (2 * NUM_QUANTILES)
tau = (2 * np.arange(NUM_QUANTILES) + 1) / (2.0 * NUM_QUANTILES)


def get_quantiles(values, num_quantiles):
    """Select ``num_quantiles`` evenly spaced empirical quantiles of ``values``.

    Sorted sample k is assigned the midpoint level (2k + 1) / (2n). For each target
    level (2j + 1) / (2 * num_quantiles), the sample whose level is closest is
    returned. Ties go to the lower sample.
    """
    ordered = np.sort(values)
    n = len(ordered)
    sample_levels = (2 * np.arange(n) + 1) / (2.0 * n)
    target_levels = (2 * np.arange(num_quantiles) + 1) / (2.0 * num_quantiles)
    # Row j holds the distance of every sample level to target level j.
    distances = np.abs(sample_levels[np.newaxis, :] - target_levels[:, np.newaxis])
    return ordered[distances.argmin(axis=1)]


def dist_value_iteration(
    alphas, beta, num_quantiles, max_iter=int(1e4), gamma=0.99, lr=5e-3, epsilon=1e-8
):
    """Stochastic quantile-regression value iteration over a posterior of MDPs.

    Each step resamples ``ENSEMBLE_SIZE`` alphas from ``alphas`` and builds those
    models with ``build_posterior``. It then moves the quantile estimates along the
    quantile-regression (sub)gradient against the one-step targets
    ``r + gamma * P @ theta``.

    Args:
        alphas: 1-D array of posterior samples of alpha to resample from.
        beta: beta passed to ``build_posterior``.
        num_quantiles: number of quantiles per state.
        max_iter: maximum number of gradient steps.
        gamma: discount factor.
        lr: step size.
        epsilon: stop once no entry changes by more than this in a single step.

    Returns:
        ``(theta_i, theta_list)``. ``theta_i`` is the final estimate, with shape
        ``(num_quantiles, NUM_STATES)``. ``theta_list`` holds every iterate, starting
        with the initial guess. On early stopping, ``theta_list[-1]`` is the last
        (sub-epsilon) update rather than ``theta_i``.
    """
    # Quantile midpoints, computed from the argument so that `num_quantiles` is
    # honoured even when it differs from the global NUM_QUANTILES.
    tau = (2 * np.arange(num_quantiles) + 1) / (2.0 * num_quantiles)

    # Random init guess of the value distribution function
    # Force terminal state to have a value of zero
    # NOTE(review): the last state is only zeroed here; the updates below may move
    # it away from zero. Confirm whether it should be re-zeroed after every step.
    theta_i = 1 * np.sort(np.random.rand(num_quantiles, NUM_STATES), axis=0)
    theta_i[:, -1] *= 0

    theta_list = [theta_i]
    for _ in range(max_iter):
        # First: sample models from the posterior to estimate the gradient
        alphas_ens = np.random.choice(alphas, size=ENSEMBLE_SIZE)
        p, r = build_posterior(alphas_ens, beta)

        # Second: Compute the gradient of the quantile regression loss
        # Targets r + gamma * P theta, shape (E, num_quantiles, NUM_STATES)
        theta_j = np.expand_dims(r, axis=1) + gamma * np.einsum(
            "eij, mj -> emi", p, theta_i
        )
        theta_j = np.expand_dims(theta_j, axis=1)
        tmp = np.expand_dims(theta_i, axis=(0, 2))
        # indicator_fn[e, q, k, s] = 1 when target sample k (model e) lies below
        # the current quantile estimate q of state s
        indicator_fn = (theta_j - tmp < 0).astype(float)
        grad_loss = np.expand_dims(tau, axis=(0, -2, -1)) - indicator_fn
        # Average over models and target samples -> (num_quantiles, NUM_STATES)
        grad_loss = np.mean(grad_loss, axis=(0, 2))

        # Update the params by taking a step in the direction of the gradient
        new_theta_i = theta_i + lr * grad_loss
        theta_list.append(new_theta_i)
        if np.any(np.abs(theta_i - new_theta_i) > epsilon):
            theta_i = new_theta_i
        else:
            break

    return theta_i, theta_list


# For each posterior: empirical quantiles of the true V(s_0) samples, and the full
# iterate history of distributional value iteration (used by the convergence plot).
true_quantiles, pred_quantiles = [], []
for alpha, vf_true in zip(alphas, vf_true_all):
    state0_values = vf_true[:, 0]
    true_quantiles.append(get_quantiles(state0_values, NUM_QUANTILES))
    iterate_history = dist_value_iteration(alpha, beta, NUM_QUANTILES)[1]
    pred_quantiles.append(iterate_history)

### Compare quantile-regression (QR) errors across values of $\beta$

In [None]:
# For each posterior over alpha, sweep beta and measure the 1-Wasserstein distance
# between the true and learned quantiles of V(s_0).
# The loop variable is `beta_i` so that the global `beta` used by earlier cells is
# not overwritten.
betas = np.linspace(0, 1, 50)
wass_ipms = []
for alpha in alphas:
    wass_ipm = np.zeros_like(betas)
    for i, beta_i in enumerate(betas):
        p_beta, r_beta = build_posterior(alpha, beta_i)
        vf_true = compute_values(p_beta, r_beta)
        true_q = get_quantiles(vf_true[:, 0], NUM_QUANTILES)
        theta_history = dist_value_iteration(alpha, beta_i, NUM_QUANTILES)[-1]
        # Last iterate, column 0 = learned quantiles of V(s_0)
        pred_q = theta_history[-1][:, 0]
        # 1-Wasserstein distance between two equal-weight sets of NUM_QUANTILES
        # Diracs: mean absolute difference of the *sorted* quantiles. Learned
        # quantiles may cross, so sort them; `true_q` is sorted by construction.
        wass_ipm[i] = np.mean(np.abs(true_q - np.sort(pred_q)))
    wass_ipms.append(wass_ipm)

### Plot

In [None]:
%matplotlib widget
import warnings

import matplotlib.pyplot as plt

from dist_mbrl.utils.plot import JMLR_PARAMS

warnings.filterwarnings("ignore", category=DeprecationWarning)

plt.rcParams.update(JMLR_PARAMS)

# One column per posterior. Rows:
#   (0) KDE of the sampled true V(s_0)
#   (1) per-quantile error over the training iterates (beta = 0)
#   (2) QR error as a function of beta
fig, axes = plt.subplots(
    nrows=3, ncols=3, figsize=(6.5, 4.5), gridspec_kw={"wspace": 0.30, "hspace": 0.7}
)
cmap = plt.get_cmap("tab10")
x = np.linspace(-1, 1, 1000)
for i, (vf_true, true_q, pred_q) in enumerate(
    zip(vf_true_all, true_quantiles, pred_quantiles)
):
    # Plot the value posterior
    kernel_true = stats.gaussian_kde(vf_true[:, 0])
    axes[0, i].plot(
        x, kernel_true.pdf(x), color="tab:blue", linewidth=2.0, label="true"
    )

    # pred_q is the iterate history, so theta has shape
    # (num_iterates, NUM_QUANTILES, NUM_STATES).
    theta = np.array(pred_q)
    # x-axis in units of 1e4 gradient steps, matching the axis label. Unlike
    # np.linspace(0, 1, ...), this stays correct if iteration stopped early.
    x_grad = np.arange(theta.shape[0]) / 1e4
    for j in range(NUM_QUANTILES):
        axes[1, i].plot(
            x_grad,
            true_q[j] - theta[:, j, 0],
            c=cmap(j),
            linestyle="-",
            lw=1.5,
            alpha=0.4,
        )
    # Zero-error reference line, drawn once per panel
    axes[1, i].axhline(0.0, c="k", linestyle="--", lw=1.5)

    axes[0, i].set_xlabel(r"$V(s_0)$", labelpad=-0.5)
    axes[1, i].ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    axes[1, i].set_xlabel(r"Gradient steps $(\times 10^4)$")

for i, wass_ipm in enumerate(wass_ipms):
    axes[2, i].plot(betas, wass_ipm, linestyle="-", lw=1.5)
    axes[2, i].set_xlabel(r"$\beta$")

axes[0, 0].set_ylabel(r"$\mu(s_0)$" + "\n" + r"($\beta=0$)")
axes[1, 0].set_ylabel("Quantile error" + "\n" + r"($\beta=0$)")
axes[2, 0].set_ylabel("QR error")

plt.show()

### Save figures

In [None]:
from pathlib import Path

# Save next to the notebook folder, under ../figures/
root_module = Path.cwd()
fig_dir = root_module.parent.joinpath("figures/tabular_eqr_performance.pdf")
# Create the output directory if needed so savefig does not fail on a fresh checkout
fig_dir.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir, bbox_inches="tight", transparent=False)

# License

>Copyright (c) 2024 Robert Bosch GmbH
>
>This program is free software: you can redistribute it and/or modify <br>
>it under the terms of the GNU Affero General Public License as published<br>
>by the Free Software Foundation, either version 3 of the License, or<br>
>(at your option) any later version.<br>
>
>This program is distributed in the hope that it will be useful,<br>
>but WITHOUT ANY WARRANTY; without even the implied warranty of<br>
>MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the<br>
>GNU Affero General Public License for more details.<br>
>
>You should have received a copy of the GNU Affero General Public License<br>
>along with this program.  If not, see <https://www.gnu.org/licenses/>.