# `rlplay`-ing around with the Double-Dee QN

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

import gym

See example.ipynb for the overview of `rlplay`

<br>

## Tabular Taxi with Q-learning

### The environment

A version of the Taxi environment that decomposes the integer observation into its components: the taxi's row and column, the passenger location, and the destination.

In [None]:
from gym.spaces import Dict, Discrete


class StructuredTaxiEnv(gym.ObservationWrapper):
    """Expose Taxi's flat integer state as a dict of its named components.

    The flat state index is decoded as the mixed-radix number

        n = ((row * 5 + col) * 5 + location) * 4 + destination

    and the four digits are returned under the keys `row`, `col`,
    `location` and `destination`.
    """

    def __init__(self, env):
        super().__init__(env)
        spaces = {
            'row': Discrete(5),
            'col': Discrete(5),
            'location': Discrete(5),
            'destination': Discrete(4),
        }
        self.observation_space = Dict(spaces)

    @staticmethod
    def observation(n):
        # peel off the digits from the least significant one upwards; only
        #  `%` and `//` are used, so this works for ints and integer tensors
        destination, n = n % 4, n // 4
        location, n = n % 5, n // 5
        col, row = n % 5, n // 5
        return dict(row=row, col=col, location=location, destination=destination)

A wrapper that makes Taxi's `.render` compatible with `core.evaluate`: it returns `True` whenever the wrapped `.render` returns `None`.

In [None]:
class RenderPatch(gym.Wrapper):
    """Return `True` from `render` when the wrapped env's `render` gives `None`."""

    def render(self, mode='human'):
        # pass any non-None rendering (e.g. a string or an array) through
        frame = self.env.render(mode)
        return True if frame is None else frame

The environment factory

In [None]:
from gym.envs.toy_text import TaxiEnv


def base_factory(struct=False, seed=None):
    """Build a `Taxi-v3` environment wrapped in `RenderPatch`.

    Parameters
    ----------
    struct : bool
        If True, additionally wrap the env in `StructuredTaxiEnv` to get dict
        observations instead of the flat state index.
    seed : optional
        Accepted but currently unused. NOTE(review): presumably kept so that
        callers may pass a seed; it is not applied to the environment.
    """
    env = gym.make("Taxi-v3")
    env = StructuredTaxiEnv(env) if struct else env
    return RenderPatch(env)

<br>

### the Actor

A function and a layer that convert integer input data into its
little-endian binary representation, as float vectors in $\{0, 1\}^m$.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers to fixed-width binary floating point vectors.

    The encoding is little-endian: entry `j` of the new trailing dimension is
    bit `j` (the bit with value `2**j`) of the corresponding input integer.

    Parameters
    ----------
    input : torch.Tensor
        An integer tensor of any shape.
    n_bits : int
        The number of low-order bits to extract, `0 < n_bits < 64`.
    dtype : torch.dtype
        The dtype of the returned 0/1 tensor.

    Returns
    -------
    torch.Tensor
        A tensor of shape `(*input.shape, n_bits)` with entries in {0, 1}.
    """
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # n_bits = {torch.int64: 63, torch.int32: 31, torch.int16: 15, torch.int8 : 7}

    # allocate the powers of two directly on the input's device, rather than
    #  building them on the host and copying them over on every call
    pow2 = torch.tensor([1 << j for j in range(n_bits)],
                        dtype=torch.int64, device=input.device)

    # get mask of set bits
    x = input.unsqueeze(-1).bitwise_and(pow2).to(bool)

    # upcast bool to the requested dtype to get one-hot
    return x.to(dtype)


class OneHotBits(torch.nn.Module):
    """Layer form of `onehotbits` with a fixed bit width and output dtype."""

    def __init__(self, n_bits=63, dtype=torch.float):
        # validate the width before any Module state is set up
        assert 1 <= n_bits < 64
        super().__init__()
        self.n_bits = n_bits
        self.dtype = dtype

    def forward(self, input):
        width, out_dtype = self.n_bits, self.dtype
        return onehotbits(input, n_bits=width, dtype=out_dtype)

A special module dictionary, which applies itself to the input dict of tensors

In [None]:
from typing import Optional, Mapping
from torch.nn import Module, ModuleDict as BaseModuleDict


class ModuleDict(BaseModuleDict):
    """A `torch.nn.ModuleDict` that can be called on a dict of inputs.

    Each submodule is applied to the input entry stored under the same key,
    and the results are concatenated along `dim`, in the order in which the
    container iterates over its submodules. Extra keys in the input are
    ignored.
    """
    def __init__(
        self,
        modules: Optional[Mapping[str, Module]] = None,
        dim: Optional[int] = -1
    ) -> None:
        super().__init__(modules)
        self.dim = dim

    def forward(self, input):
        # walk this container rather than `input`, so the concatenation order
        #  is set by the declaration, not by the key order of the input dict
        outputs = []
        for key, module in self.items():
            outputs.append(module(input[key]))

        return torch.cat(outputs, dim=self.dim)

A simple tabular q-learner actor

In [None]:
from rlplay.engine import BaseActorModule


class TaxiActor(BaseActorModule):
    """Q-function actor for Taxi-v3 with epsilon-greedy exploration.

    Parameters
    ----------
    duelling : bool
        If True, the network emits one extra output, the state value, and the
        q-values are assembled from it and the action advantages as in the
        dueling architecture of Wang et al. (2016).
    struct : bool
        If True, expect dict observations with `row`, `col`, `location` and
        `destination` fields, embed each field and mix the embeddings with a
        linear layer. Otherwise, embed the flat state index (500 states).
    epsilon : float
        The initial exploration rate used in training mode. It is kept in a
        buffer, so it moves with the module and can be updated in place
        (for updating the exploration epsilon in the clones).
    """
    def __init__(self, duelling=False, struct=False, epsilon=0.1):
        super().__init__()
        self.duelling = duelling

        # the table is actually an embedding of the state index
        n_dim = 6 + (1 if duelling else 0)
        if not struct:
            self.table = torch.nn.Embedding(500, n_dim)

        else:
            self.table = torch.nn.Sequential(
                ModuleDict(dict(
                    row=torch.nn.Embedding(5, 4),
                    col=torch.nn.Embedding(5, 4),
                    location=torch.nn.Embedding(5, 4),
                    destination=torch.nn.Embedding(4, 4),
                )),
                torch.nn.Linear(16, n_dim)
            )

        # for updating the exploration epsilon in the clones
        self.register_buffer('epsilon', torch.tensor(epsilon))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Pick actions for a batch of observations.

        Returns the actions, an empty recurrent state `()`, and a dict with
        the q-values `q` and `value`. `value` is the value head's output when
        duelling and the maximal q-value otherwise. In training mode each
        greedy action is replaced, with probability `epsilon`, by an action
        drawn uniformly at random.
        """
        qv, hx = self.table(obs), ()
        if self.duelling:
            # dueling net [Wang et al. (2016)](https://arxiv.org/abs/1511.06581)
            val, adv = torch.split_with_sizes(qv, [1, 6], dim=-1)
            qv = adv + (val - adv.mean(dim=-1, keepdim=True))

            val, actions = val.squeeze(-1), qv.max(dim=-1).indices

        else:
            val, actions = qv.max(dim=-1)

        if self.training:
            *head, n_actions = qv.shape
            # draw the exploration noise on the device of the q-values, so
            #  that `where` does not mix devices when the actor is on a GPU
            keep = torch.rand(head, device=qv.device).gt(self.epsilon)
            noise = torch.randint(n_actions, size=head, device=qv.device)
            actions = actions.where(keep, noise)

        return actions, hx, dict(q=qv, value=val)

<br>

### D-DQN loss

Service functions for the algorithms

In [None]:
from plyr import apply, suply, xgetitem


def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects.

    Every leaf of `state` is sliced with a `slice` index (so, presumably,
    along its leading time axis). For `T + 1` time steps, both outputs have
    `T + 1 - shift` steps, and `next[t]` corresponds to `curr[t + shift]`.

    Parameters
    ----------
    state : nested container
        A structure accepted by `plyr.suply`, possibly with `None` leaves.
    shift : int
        The non-negative lag `H` between the two slices. A zero lag returns
        the whole sequence twice.

    Returns
    -------
    curr, next : nested containers with the same structure as `state`.
    """
    # `slice(None, -0)` would be empty, so a zero lag must stop at `None`
    stop = -shift if shift else None

    # use `xgetitem` to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    curr = suply(xgetitem, state, index=slice(None, stop))

    # XXX `nxt[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    nxt = suply(xgetitem, state, index=slice(shift, None))

    return curr, nxt

Double DQN loss for contiguous trajectory fragments. 

* `.state[t+1].rew` -- $r_{t+1}$
* `.state[t+1].fin` -- $d_{t+1}$
* `.state[t+1].act` -- $a_t$ which caused $
    (s_t, z_t, a_t) \longrightarrow (s_{t+1}, z_{t+1}, r_{t+1}, d_{t+1})
$
* `_, _, fragment.actor[t] = actor(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta_{\text{old}})
$ -- used for rollout collection
* `_, _, info_module[t] = module(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta)
$ -- the current q-function, producing $(h_t)_{t=0}^{T+1}$
* `_, _, info_target[t] = target(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta_-)
$ -- the q-target with $(h^-_t)_{t=0}^{T+1}$

The current q-network minimizes the $\mathrm{TD}(0)$-error

$$
\delta_t(\theta)
    = \bigl(
        r_{t+1}
        + \gamma 1_{\{\neg d_{t+1}\}} v^*(z_{t+1})
    \bigr) - q(z_t, h_t, a_t; \theta)
    \,, $$

where the approximate state value estimate $
    v^*(z_{t+1})
$ is one of
* `Q-learning`: $
    \max_a q(z_{t+1}, h_{t+1}, a; \theta)
$
* `DQN`: $
    \max_a q(z_{t+1}, h_{t+1}, a; \theta_-)
$
* `double DQN`: $
    q(z_{t+1}, h_{t+1}, \hat{a}_{t+1}; \theta_-)
$ for $
    \hat{a}_{t+1} = \arg \max_a q(z_{t+1}, h_{t+1}, a; \theta)
$

One of the works related to DQN learning of recurrent agents is [Kapturowski et al. (2018)](https://openreview.net/forum?id=r1lyTjAqYX), who propose to use a burn-in period in
the contiguous trajectory fragment in order to compensate for the representation drift,
due to different RNN parameters $\theta_-$ and $\theta$.

There is no clear-cut evidence suggesting that the hidden recurrent sequences $h_t$
and $h^-_t$ yield significantly different results.

In [None]:
import torch.nn.functional as F


# @torch.enable_grad()
def ddq_learn(fragment, module, *, gamma=0.95, target=None, double=False):
    r"""Compute the Double-DQN loss over a _contiguous_ fragment of a trajectory.

    Parameters
    ----------
    fragment : object
        A rollout fragment with `.state`, a nested record with the fields
        `.obs`, `.act`, `.rew`, `.fin` and `.stepno` (presumably with time
        along the leading axis), and `.hx`, the recurrent state at the start
        of the fragment.
    module : callable
        The current Q-network $\theta$, called as
        `module(obs, act, rew, fin, hx=..., stepno=...)`. It must return a
        triple whose last item is a dict with the q-values under `'q'`.
    gamma : float
        The discount factor $\gamma$.
    target : callable, optional
        The frozen target Q-network $\theta^-$ with the same call signature.
        If `None`, the classic Q-learning target, bootstrapped from `module`,
        is used and `double` is ignored.
    double : bool
        If `True` and `target` is given, use the Double DQN value estimate.

    Returns
    -------
    loss : torch.Tensor
        The sum of squared TD(0) errors over all transitions in the fragment.
    info : dict
        An empty dict of extra diagnostics.

    Details
    -------
    In Q-learning the action value function minimizes the TD-error

    $$
        r_{t+1}
            + \gamma 1_{\neg d_{t+1}} v^*(z_{t+1})
            - q(z_t, a_t; \theta)
        \,, $$

    w.r.t. Q-network parameters $\theta$, where $z_t$ is the actionable state
    and $r_{t+1}$ is the reward for the $s_t \to s_{t+1}$ transition. The
    actionable state $z_t$ includes the current observation $x_t$, the
    recurrent state $h_t$, the last action $a_{t-1}$, the last reward $r_t$,
    and the termination flag $d_t$.

    In classic Q-learning there is no target network, and the optimal state
    value function of the next state is bootstrapped using the current
    Q-network (`module`):

    $$
        v^*(z_{t+1})
            \approx \max_a q(z_{t+1}, a; \theta)
        \,. $$

    The DQN method, proposed by

        [Mnih et al. (2013)](https://arxiv.org/abs/1312.5602),

    uses a secondary Q-network (`target`) to estimate the value of the next
    state:

    $$
        v^*(z_{t+1})
            \approx \max_a q(z_{t+1}, a; \theta^-)
        \,, $$

    where $\theta^-$ are the frozen parameters of the target Q-network. The
    Double DQN algorithm of

        [van Hasselt et al. (2015)](https://arxiv.org/abs/1509.06461)

    unravels the $\max$ operator as
    $
        \max_k u_k \equiv u_{\arg \max_k u_k}
    $
    and replaces the outer $u$ with the Q-values of the target Q-network, while
    computing the inner $u$ (inside the $\arg\max$) with the current Q-network.
    Specifically, the Double DQN value estimate is

    $$
        v^*(z_{t+1})
            \approx q(z_{t+1}, \hat{a}_{t+1}; \theta^-)
            \,,
            \hat{a}_{t+1}
                = \arg \max_a q(z_{t+1}, a; \theta)
        \,, $$

    where $\hat{a}_{t+1}$ is the greedy action of the current Q-network
    $\theta$ at $z_{t+1}$.

    Recurrent DQN
    -------------
    The key problem with the recurrent state $h_t$ in $z_t$ is its
    representation drift. The endogenous states used for collecting
    trajectory data during the rollout are produced by an actor with stale
    parameters $\theta_{\text{old}}$. They might therefore differ
    substantially from the recurrent states produced by the current Q-network
    $\theta$ or by the target $\theta^-$. To mitigate this,

        [Kapturowski et al. (2018)](https://openreview.net/forum?id=r1lyTjAqYX)

    proposed to spend a `burn-in` slice of the recorded trajectory on
    aligning the recurrent representation. Specifically, starting with $h_0$
    (contained in `fragment.hx`), they propose to launch two sequences $h_t$
    and $h^-_t$ from the same $h^-_0 = h_0$, using $q(\cdot; \theta)$ and
    $q(\cdot; \theta^-)$, respectively.
    """

    trajectory, hx = fragment.state, fragment.hx
    obs, act, rew, fin = trajectory.obs, trajectory.act, trajectory.rew, trajectory.fin

    # get $q(z_t, h_t, \cdot; \theta)$ for all t=0..T
    _, _, info_module = module(
        obs, act, rew, fin, hx=hx, stepno=trajectory.stepno)

    # get the next state `state[t+1]` $z_{t+1}$ to access $a_t$, $r_{t+1}$
    #  and $d_{t+1}$
    state_next = suply(xgetitem, trajectory, index=slice(1, None))

    # split the module's q-values into `curr`, $q(z_t, h_t, \cdot; \theta)$,
    #  and `next`, $q(z_{t+1}, h_{t+1}, \cdot; \theta)$, for t=0..T-1
    info_module_curr, info_module_next = timeshift(info_module)

    # get $q(z_t, h_t, a_t; \theta)$ for all t=0..T-1
    q_replay = info_module_curr['q'].gather(-1, state_next.act.unsqueeze(-1))

    # get $\hat{v}_{t+1}(z_{t+1}) = ...$; no gradient flows through it
    with torch.no_grad():
        if target is None:
            # get $... = \max_a q(z_{t+1}, h_{t+1}, a; \theta)$
            q_value = info_module_next['q'].max(dim=-1, keepdim=True).values

        else:
            _, _, info_target = target(
                obs, act, rew, fin, hx=hx, stepno=trajectory.stepno)

            # $q(z_{t+1}, h^-_{t+1}, \cdot; \theta^-)$, sliced like the module's
            _, info_target_next = timeshift(info_target)
            if not double:
                # get $... = \max_a q(z_{t+1}, h^-_{t+1}, a; \theta^-)$
                q_value = info_target_next['q'].max(dim=-1, keepdim=True).values

            else:
                # get $\hat{a}_{t+1} = \arg \max_a q(z_{t+1}, h_{t+1}, a; \theta)$
                hat_act = info_module_next['q'].max(dim=-1).indices.unsqueeze(-1)

                # get $... = q(z_{t+1}, h^-_{t+1}, \hat{a}_{t+1}; \theta^-)$
                q_value = info_target_next['q'].gather(-1, hat_act)

        # get $r_{t+1} + \gamma 1_{\neg d_{t+1}} \hat{v}_{t+1}(z_{t+1})$; the
        #  in-place ops are safe, since `q_value` is a freshly allocated tensor
        q_value.masked_fill_(state_next.fin.unsqueeze(-1), 0.)
        q_value.mul_(gamma).add_(state_next.rew.unsqueeze(-1))

    # td-error ell-2 loss
    return F.mse_loss(q_replay, q_value, reduction='sum'), {}

<br>

The following is my incomplete take on [Bellemare et al. (2017)](http://proceedings.mlr.press/v70/bellemare17a.html)

The distributional Bellman operator for a policy $\pi_\theta$ at $(s, a)$ is
defined as a random variable on $\mathbb{R}$ with the following law

$$
(T Q_\theta)(s, a)
    \overset{D}{=}
        r + \gamma Q_\theta(s', a')
    \,, \text{ for }
    (r, s') \sim p(r, s'\mid s, a)
    \,, \text{ and }
    a' \sim \pi_\theta(a \mid s')
    \,.
$$

Formally, this means the following (omitting the dependence on $(s, a)$):

\begin{align}
  \mathbb{P}\bigl(
    T Q_\theta \in U
  \bigr)
    &= \mathbb{P}\bigl(
      R + \gamma Q_\theta(S', A') \in U
    \bigr)
    \\
    &= \mathbb{E}_{r, s'\sim p(r, s'\mid s, a)}
        \mathbb{E}_{a'\sim \pi_\theta(a \mid s')}
          \mathbb{P}\bigl(
            R + \gamma Q_\theta(S', A') \in U
            \,\big \vert\, S'=s', R=r, A'=a'
          \bigr)
    \\
    &= \mathbb{E}_{r, s'\sim p(r, s'\mid s, a)}
        \mathbb{E}_{a'\sim \pi_\theta(a \mid s')}
          \mathbb{P}\bigl(
            Q_\theta(s', a') \in \frac{U - r} \gamma
          \bigr)
    \,,
\end{align}

or, in other words, $T$ acts on the conditional distribution $
    Q_\theta(\cdot \mid s, a)
$ thus: for any bounded Borel measurable $f$

\begin{align}
  \int (T Q_\theta)(dv \mid s, a) f(v)
    &= \int
        p(ds', dr\mid s, a)  % clump p and R into one cond-distrib
        \pi_\theta(da' \mid s')
        Q_\theta(dv\mid s', a')
        f(r + \gamma v)
    \\
    &=  % Fubini
        \mathbb{E}_{r, s'\sim p(r, s'\mid s, a)}
        \mathbb{E}_{a'\sim \pi_\theta(a \mid s')}
        \mathbb{E}_{v\sim Q_\theta(v\mid s', a')}
            f(r + \gamma v)
    \,.
\end{align}

This suggests several approximations, depending on the tractability of the inner expectations.

Ideally, we would like to find a distribution such that
$$
(T Q_\theta)(s, a)
    \overset{D}{=}
    Q_\theta(s, a)
    \,, $$
for all or most $s$ and $a$.

If $f$ is $L$-Lipschitz, then
\begin{align}
\biggl\lvert
    \int \bigl(
        (T H)(dv \mid s, a) - (T G)(dv \mid s, a)
        \bigr) f(v)
    \biggr\rvert
    &\leq \int
        p(ds', dr\mid s, a)
        \pi_\theta(da' \mid s')
        H(dh\mid s', a')
        G(dg \mid s', a')
        \bigl\lvert
            f(r + \gamma h) - f(r + \gamma g)
        \bigr\rvert
    \\
    &\leq L \gamma \int
        p(ds', dr\mid s, a)
        \pi_\theta(da' \mid s')
        H(dh \mid s', a')
        G(dg \mid s', a')
        \bigl\lvert h - g \bigr\rvert
    \\
    &\leq L \gamma \int
        p(ds', dr\mid s, a)
        \pi_\theta(da' \mid s')
        \int
            H(dh \mid s', a')
            G(dg \mid s', a')
            \bigl\lvert h - g \bigr\rvert
    \\
    &\leq L \gamma \sup_{s, a}
        \int
            H(dh \mid s, a)
            G(dg \mid s, a)
            \bigl\lvert h - g \bigr\rvert
    \,,
\end{align}

so there is hope for a contraction, or even convergence; however, it is better to
consult [Bellemare et al. (2017)](http://proceedings.mlr.press/v70/bellemare17a.html)
first.

<!--
$$
\biggl\lvert
    \int \bigl(
        (T Q_\theta)(dv \mid s, a) - Q_\theta(dv \mid s, a)
        \bigr) f(v)
    \biggr\rvert
    \leq \int
        p(ds', dr\mid s, a)
        \pi_\theta(da' \mid s')
        Q_\theta(dv'\mid s', a')
        Q_\theta(dv\mid s, a)
        \bigl\lvert
            f(r + \gamma v') - f(v)
        \bigr\rvert
    \,. $$
-->

Below is the loss for a version of the distributional DQN, specific to Gaussian approximations
and maxent regularization.

We minimize (in expectation over transition $s, a \to s'$) the Kullback-Leibler
divergence between distributions $
    (T Q_\theta)(dv \mid s, a)
$ and $
    Q_\theta(dv \mid s, a)
$, while at the same time encouraging $
    Q_\theta(dv \mid s, a)
$ to exhibit high entropy. The actions $a' \sim \pi_\theta(a\mid s')$
are sampled greedily as $
    \arg \max_a \int Q_\theta(dv \mid s', a) v
$ (assuming finite action space).

In [None]:
def gdqn(fragment, module, *, gamma, C=1e-3):
    r"""Gaussian distributional DQN loss over a contiguous trajectory fragment.

    Both the current return distribution $Z_\theta(\cdot \mid z_t, a_t)$ and
    its Bellman image are modelled as univariate Gaussians, with the target

    $$
        (T Z_\theta)(\cdot \mid z_t, a_t)
            \approx \mathcal{N}\bigl(
                r_{t+1} + \gamma_t \mu(z_{t+1}, a'_{t+1}),
                \gamma_t^2 \sigma^2(z_{t+1}, a'_{t+1})
            \bigr)
        \,, $$

    where $\gamma_t = \gamma$, or $10^{-4}$ if $d_{t+1}$, and $a'_{t+1}$ is
    the action that `module` itself returns at $z_{t+1}$.

    Parameters
    ----------
    fragment : object
        A rollout fragment with `.state` (fields `.obs`, `.act`, `.rew`,
        `.fin`, `.stepno`) and `.hx`.
    module : callable
        The actor, called as `module(obs, act, rew, fin, hx=..., stepno=...)`.
        It must return `(actions, hx, info)`, with the per-action Gaussian
        means and scales in `info['loc']` and `info['scl']`.
    gamma : float
        The discount factor.
    C : float
        The weight of the negative-entropy regularizer.

    Returns
    -------
    loss : torch.Tensor
        The mean KL divergence from the (fixed) target to $Z_\theta$, plus
        `C` times the mean negative entropy of $Z_\theta$.
    info : dict
        The floats `'dst'` (the KL term) and `'reg'` (the regularizer).
    """
    # imported here, so that this function does not depend on a later
    #  notebook cell having been run
    from torch.distributions import Normal, kl_divergence

    trajectory, hx = fragment.state, fragment.hx
    obs, act, rew, fin = trajectory.obs, trajectory.act, trajectory.rew, trajectory.fin

    # get $Q(z_t, h_t, \cdot; \theta)$ for all t=0..T, and the module's own
    #  action choice at every $z_t$
    greedy, _, info_module = module(
        obs, act, rew, fin, hx=hx, stepno=trajectory.stepno)

    # get the next state `state[t+1]` $z_{t+1}$ (to access $a_t$, $r_{t+1}$,
    #  $d_{t+1}$) and the module's action $a'_{t+1}$ at $z_{t+1}$
    state_next, act_next = suply(
        xgetitem, (trajectory, greedy.unsqueeze(-1)), index=slice(1, None))

    # the module's response to the current and next states: `curr` is
    #  $q(z_t, h_t, \cdot; \theta)$ and `next` is $q(z_{t+1}, h_{t+1}, \cdot; \theta)$
    info_module_curr, info_module_next = timeshift(info_module)

    # the discount $\gamma_t$: terminal transitions get a tiny positive factor
    #  instead of zero, presumably to keep the target scale strictly positive,
    #  as `Normal` requires
    factor = torch.where(state_next.fin, 1e-4, gamma)  # XXX some leakage from the next traj.
    with torch.no_grad():
        # get $T Z_\theta(\cdot \mid z_t, a_t) \overset{D}{=}
        #  N(\cdot \mid r_{t+1} + \gamma_t \mu_{t+1}, (\gamma_t \sigma_{t+1})^2)$
        #  at the next action $a'_{t+1}$
        zed = info_module_next['loc'].gather(-1, act_next).squeeze(-1)
        loc = state_next.rew + factor * zed
        # XXX should the termination not affect the scale?
        scl = factor * info_module_next['scl'].gather(-1, act_next).squeeze(-1)
        q = Normal(loc, scl)

    # get $Z_\theta(\cdot \mid z_t, a_t)$ at the replayed actions $a_t$
    act_replay = state_next.act.unsqueeze(-1)
    p = Normal(
        info_module_curr['loc'].gather(-1, act_replay).squeeze(-1),
        info_module_curr['scl'].gather(-1, act_replay).squeeze(-1),
    )

    # distance from T z_\theta to z_\theta on s-a-r-s' samples
    dst = kl_divergence(q, p).mean()  # q is fixed target.
    # XXX fwd kl means covering, not mode-seeking.

    # use kl from N(0, 1) prior, but can we also use entropy?
    # pi = Normal(*torch.tensor([0., 1.])).expand(loc.shape)
    # kl_reg = kl_divergence(p, pi).mean()
    reg = -p.entropy().mean()

    return dst + C * reg, {
        'dst': float(dst),
        'reg': float(reg),
    }

Let's define an actor for the distributional Q-learning with the Gaussian approximation:

$$
Q_\theta(
    v \mid z, a
    ) = \mathcal{N}\bigl(
        v \,\big\vert \,
        \mu(z, a),
        \sigma^2(z, a)
    \bigr)
    \,. $$

The greedy action sampling must take into account the distribution of $
    j^* = \arg\max_i Z_i
$ for a collection of independent $
    Z_i \sim \mathcal{N}(\mu_i, \sigma^2_i)
$, not just the means. Specifically, we must make sure that $j^* = j$ with probability

$$
\mathbb{P}(Z_j \geq  \max_{i\neq j} Z_i)
    = \mathbb{E}
        \mathbb{P}(Z_j \geq \max_{i\neq j} Z_i \big \vert Z_j)
    = \mathbb{E}_{\xi \sim \mathcal{N}(0, 1)}
        \prod_{i\neq j} \Phi\biggl(
            \frac{\mu_j - \mu_i}{\sigma_i}
            + \frac{\sigma_j}{\sigma_i} \xi
        \biggr)
    \,, $$
    
wherein we used independence and the fact that the maximum of independent Gaussian random variables satisfies

$$
\mathbb{P}(\max_i Z_i \leq c)
    = \prod_i \mathbb{P}(Z_i \leq c)
    = \prod_i \Phi\biggl(
        \frac{c - \mu_i}{\sigma_i}
    \biggr)
    \,, $$

for $\Phi$ -- the univariate standard Gaussian CDF.

But this is hard to compute... So we use the largest mean heuristic.
Maybe we can use extreme value asymptotics here?

In [None]:
from rlplay.engine import BaseActorModule
from torch.distributions import Independent, Normal, kl_divergence


class GaussTaxiActor(BaseActorModule):
    """Taxi-v3 actor with an independent Gaussian return model per action.

    For every state, the network emits a mean `loc` and a positive scale
    `scl` for each of the 6 actions. Actions are chosen greedily by the mean.

    NOTE(review): the `epsilon` buffer is registered (and decayed by the
    training loop), but `forward` never reads it, so this actor does not
    explore epsilon-greedily.
    """
    def __init__(self, *, struct=False, epsilon=0.1):
        super().__init__()
        self.struct = struct

        n_out = 2 * 6  # a (mean, raw scale) pair for each of the 6 actions
        if struct:
            fields = ModuleDict(dict(
                row=torch.nn.Embedding(5, 3),
                col=torch.nn.Embedding(5, 3),
                location=torch.nn.Embedding(5, 3),
                destination=torch.nn.Embedding(4, 3),
            ))
            self.table = torch.nn.Sequential(
                fields,
                torch.nn.Linear(12, 32),
                torch.nn.ReLU(),
                torch.nn.Linear(32, n_out),
            )

        else:
            self.table = torch.nn.Embedding(500, n_out)

        # for updating the exploration epsilon in the clones
        self.register_buffer('epsilon', torch.tensor(epsilon))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        params = self.table(obs)

        # the first half are the means; the second is mapped to a positive
        #  scale, floored at 1e-3
        loc, raw = torch.chunk(params, 2, dim=-1)
        scl = F.softplus(raw).clamp(min=1e-3)

        # the largest-mean heuristic; sampling from Normal(loc, scl) and
        #  taking the argmax would respect the spreads, but is not used
        val, act = loc.max(dim=-1)

        return act, (), dict(loc=loc, scl=scl, q=loc, value=val)

<br>

### Run!

Choose the discount factor and the variant of the algorithm

In [None]:
# the discount factor
gamma = 0.99

# `use_dist` selects the Gaussian distributional loss `gdqn` with
#  `GaussTaxiActor`; otherwise `ddq_learn` with `TaxiActor` is used, and it
#  is configured by the three flags that follow
use_dist = True
use_target, use_double, use_duelling = False, False, False

# `target` does not work for some reason at all with taxi, maybe the freeze schedule is off?
#  or contiguous fragments work to the detriment of learning

# `duelling` also fails for both target and double, and sorta for ordinary q

Initialize the learner and the environment factories

In [None]:
from functools import partial


structured = False

# the training and evaluation environments come from the same factory
factory_eval = partial(base_factory, struct=structured)
factory = partial(base_factory, struct=structured)

learner = (
    GaussTaxiActor(struct=structured) if use_dist else
    TaxiActor(duelling=use_duelling, epsilon=1., struct=structured)
)
learner.train()

device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

optim = torch.optim.Adam(learner.parameters(), lr=1e-2)

Initialize the sampler

In [None]:
T = 25  # steps per rollout fragment (`n_steps`)
B = 8   # environments per actor (`n_per_actor`)

In [None]:
from rlplay.engine.rollout import multi

# the parallel rollout generator of training fragments; NOTE(review):
#  presumably `n_actors` workers, each stepping `B` envs for `T` steps per
#  fragment -- see rlplay's `multi.rollout` for the exact semantics
batchit = multi.rollout(
    factory, learner,
    n_steps=T, n_actors=8, n_per_actor=B,
    n_buffers=16, n_per_batch=2,
    sticky=False, pinned=False, clone=True, close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

A generator of evaluation rewards

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

# a generator of evaluation results; the training loop pulls one item from
#  it per epoch (NOTE(review): presumably the rewards of `n_envs` episodes)
test_it = evaluate(
    factory_eval, learner,
    n_envs=4, n_steps=500,
    clone=False, device=device_, start_method='fork',
)

Implement your favourite training method

In [None]:
import tqdm
import copy
from math import log, exp
from torch.nn.utils import clip_grad_norm_

torch.set_num_threads(1)

# exploration epsilon halflife of 50 epochs: each epoch multiplies epsilon
#  by exp(decay) = 2 ** (-1 / 50)
decay = -log(2) / 50

# the training loop: 400 epochs of (at most) 100 updates each
losses, rewards = [], []
for epoch in tqdm.tqdm(range(400)):
    # freeze a snapshot of the learner as the q-target for this epoch
    target = copy.deepcopy(learner) if use_target else None

    for j, batch in zip(range(100), batchit):
        if not use_dist:
            loss, info = ddq_learn(batch, learner, target=target,
                                   gamma=gamma, double=use_double)

        else:
            loss, info = gdqn(batch, learner, gamma=gamma)

        optim.zero_grad()
        loss.backward()
        grad = clip_grad_norm_(learner.parameters(), max_norm=1e3)
        optim.step()

        losses.append(dict(loss=float(loss), grad=float(grad), **info))

    # decay the exploration rate, keeping it within [0.1, 1.0]
    learner.epsilon.mul_(exp(decay)).clip_(0.1, 1.0)

    # fetch the evaluation results lagged by one inner loop!
    rewards.append(next(test_it))

In [None]:
# close the generators
batchit.close()
test_it.close()

<br>

In [None]:
def collate(records):
    """Collate identically keyed dicts into a dict of lists.

    Parameters
    ----------
    records : iterable of dict
        The records to collate, e.g. the per-update training logs.

    Returns
    -------
    dict
        Maps each key to the list of its values, in record order. A key that
        is missing from some records simply gets a shorter list.
    """
    out = {}
    for record in records:
        for k, v in record.items():
            out.setdefault(k, []).append(v)

    return out


# one numpy array per logged metric (`loss`, `grad`, and any extra info keys)
data = {}
for key, values in collate(losses).items():
    data[key] = numpy.array(values)

In [None]:
# the per-update training loss on a log scale; NOTE: with `gdqn` the loss
#  includes `C * reg`, which may be negative, and a log scale cannot show
#  non-positive values
plt.semilogy(data['loss'])

In [None]:
# the total gradient norm reported by `clip_grad_norm_` (before clipping)
plt.semilogy(data['grad'])

In [None]:
# the KL term, logged only by the distributional loss `gdqn`
if 'dst' in data:
    plt.semilogy(data['dst'])

In [None]:
# the regularizer of `gdqn`, i.e. the mean negative entropy of the Gaussians;
#  it is typically negative, so plot it on a linear scale (a log scale would
#  not show non-positive values)
if 'reg' in data:
    plt.plot(data['reg'])

In [None]:
# stack the per-epoch evaluation results into one array, epochs first
rewards = numpy.stack(rewards, axis=0)

In [None]:
# display the raw evaluation results
rewards

In [None]:
# per-epoch median and standard deviation over the last axis of `rewards`
#  (used only by the commented-out confidence band in the plot)
m = numpy.median(rewards, axis=-1)
s = rewards.std(axis=-1)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

# summary statistics of the evaluation results over the last axis of
#  `rewards` (NOTE(review): presumably the evaluation envs), one point
#  per training epoch; labelled, so the four curves can be told apart
ax.plot(numpy.mean(rewards, axis=-1), label='mean')
ax.plot(numpy.median(rewards, axis=-1), label='median')
ax.plot(numpy.min(rewards, axis=-1), label='min')
ax.plot(numpy.std(rewards, axis=-1), label='std')
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)
ax.set_xlabel('epoch')
ax.legend(fontsize='x-small')

plt.show()

<br>

The ultimate evaluation run

In [None]:
from rlplay.engine import core

# a rendered evaluation run in eval mode (greedy for `TaxiActor`); the step
#  budget is passed as an int -- `1e2` is a float, which a step counter or
#  `range` may reject (NOTE(review): confirm against `core.evaluate`)
with factory_eval() as env:
    learner.eval()
    eval_rewards, info = core.evaluate([
        env
    ], learner, render=True, n_steps=100, device=device_)

print(sum(eval_rewards))

<br>

Let's analyze the performance

In [None]:
# the shape of the per-action q-values relative to the state value
(info['q'] - numpy.expand_dims(info['value'], -1)).shape

In [None]:
# q-values minus the state value at index 0 of axis 1 (NOTE(review):
#  presumably the first env), one curve per action
plt.plot((info['q'] - numpy.expand_dims(info['value'], -1))[:, 0])

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# the one-step (TD(0)) error of the state-value estimates along the
#  evaluation run, ignoring termination; NOTE(review): assumes that
#  `eval_rewards` is one step shorter than `info['value']` -- confirm
td_target = eval_rewards + gamma * info['value'][1:]
td_error = td_target - info['value'][:-1]
# td_error = npy_deltas(
#     eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool), info['value'][:-1],
#     gamma=gamma, bootstrap=info['value'][-1])

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.semilogy(abs(td_error) / abs(td_target))
ax.set_title('relative td(0)-error');

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# the learner's state-value estimates along the evaluation run
# plt.plot(
#     npy_returns(eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool),
#                 gamma=gamma, bootstrap=info['value'][-1]))
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(info['value']);

In [None]:
with torch.no_grad():
    # query the learner on all 500 Taxi states at once
    obs = torch.arange(500).unsqueeze(1)
    if structured:
        obs = StructuredTaxiEnv.observation(obs)

    _, _, info = learner(obs, act=None, rew=None, fin=None)

    # the flat state index is ((row * 5 + col) * 5 + location) * 4 + dest,
    #  so reshape to `r c p d a`, then shuffle the dims to `a d p r c`
    grid, order = (5, 5, 5, 4, -1), (4, 3, 2, 0, 1)
    if 'loc' not in info:
        loc = info['q'].reshape(grid).permute(order)

    else:
        loc = info['loc'].reshape(grid).permute(order)
        scl = info['scl'].reshape(grid).permute(order)
        p = Normal(loc, scl)
        ent = p.entropy()

In [None]:
d = 'R'
dst = {'R': 0, 'G': 1, 'Y': 2, 'B': 3}

# +---------+
# |R: | : :G|
# | : | : : |
# | : : : : |
# | | : | : |
# |Y| : |B: |
# +---------+

# fix the destination, then reorder `a p r c` -> `p r a c`
x = loc[:, dst[d]].permute(1, 2, 0, 3)  # .reshape(5 * 5, 6 * 5)

# append, as a 7th column block, the best of the four move actions (s, n, e, w)
best_move = x[:, :, :4].max(2, keepdim=True).values
plt.imshow(torch.cat([x, best_move], dim=2).reshape(5 * 5, 7 * 5))

# cols: action snewpd, rows: location RGYBT
plt.xlabel('action + pool')
plt.xticks([5 * k + 2 for k in range(7)], [*'snewpd', 'max'])
plt.ylabel('location')
plt.yticks([5 * k + 2 for k in range(5)], list('RGYBT'))
plt.title(f'destination: {d}');

<br>

In [None]:
# NOTE(review): presumably a deliberate stop, so that "Run All" halts here
assert False