# `rlplay`-ing around

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

<br>

## Rollout collection

Rollout collection is designed to be as `plug-n-play` as possible, i.e. it
supports **arbitrarily structured nested containers** of arrays or tensors for
environment observations and actions. The actor, however, should **expose**
a certain API (described below).

In [None]:
from rlplay.engine import collect  # the collector's core

# print(collect.__doc__)

Its role is to serve as a *middle-man* between the **actor-environment** pair
and the **training loop**: to track the trajectory of the actor in the environment,
and properly record it into the data buffer.

For example, it is not responsible for seeding or randomization of environments
(I'm looking at you, `AtariEnv`), or for datatype casting (except for rewards,
which are cast to `fp32` automatically). In theory, there is **no need** for
special data preprocessing, except for, perhaps, casting data to proper dtypes,
like from `numpy.float64` observations to `float32` in `CartPole`.

#### Semantics

The collector just carefully records the trajectory by alternating between
the **REACT** and **STEP+EMIT** phases in the following fashion:

$$
    \cdots
        \longrightarrow t
        \overset{\mathrm{REACT}}{\longrightarrow} t + \tfrac12
        \overset{\mathrm{STEP+EMIT}}{\longrightarrow} t + 1
        \longrightarrow \cdots
    \,, $$

where the half-times $t + \tfrac12$ are commonly referred to as the `afterstates`:
the actor has chosen an action in response to the current observation, yet has
not interacted with the environment.

So the `time` advances in halves, and the proper names for the half times
in the diagram above are the `state`, the `afterstate` and the `next state`,
respectively.

The collected `fragment` data has the following structure:
* `.state` $z_t$ **the current "extended" observation**
  * `.stepno` $n_t$ the step counter
  * `.obs` $x_t$ **the current observation** emitted by transitioning to $s_t$
  * `.act` $a_{t-1}$ **the last action** which caused $s_{t-1} \longrightarrow s_t$
  * `.rew` $r_t$ **the previous reward** received by getting to $s_t$
  * `.fin` $d_t$ **the termination flag** indicating if $s_t$ is terminal in the env

* `.actor` $A_t$ auxiliary data from the actor due to **REACT**

* `.env` $E_{t+1}$ auxiliary data from the environment due to **STEP+EMIT**

* `.hx` $h_0$ the starting recurrent state of the actor

Here $s_t$ denotes **the unobserved true full state** of the environment.

The actor $\theta$ interacts with the environment and generates the following
<span style="color:orange">**tracked**</span> data during the rollout,
unobserved/non-tracked data <span style="color:red">**in red**</span> and
$t = 0..T-1$:

*  ${\color{orange}{h_0}}$, the starting recurrent state, is recorded in $\,.\!\mathtt{hx}$

* **REACT**: the actor performs the following update ($t \to t + \frac12$)

$$ 
    \bigl(
        \underbrace{
            .\!\mathtt{state}[\mathtt{t}]
        }_{{\color{orange}{z_t}}},\,
        {\color{red}{h_t}}
    \bigr)
        \overset{\text{Actor}_{\theta_{\text{old}}}}{\longrightarrow}
        \bigl(
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{act}[\mathtt{t+1}]
            }_{a_t \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{actor}[\mathtt{t}]
            }_{{\color{orange}{A_t}}},\,
            {\color{red}{h_{t+1}}}
        \bigr)
    \,, $$

* **STEP+EMIT**: the environment updates its unobserved state and emits
the observed data ($t + \frac12 \to t+1_-$)

$$ 
    \bigl(
        {\color{red}{s_t}},\,
        \underbrace{
            .\!\mathtt{state}.\!\mathtt{act}[\mathtt{t+1}]
        }_{a_t \leadsto {\color{orange}{z_{t+1}}}}
    \bigr)
        \overset{\text{Env}}{\longrightarrow}
        \bigl(
            {\color{red}{s_{t+1}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{obs}[\mathtt{t+1}]
            }_{x_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{rew}[\mathtt{t+1}]
            }_{r_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{fin}[\mathtt{t+1}]
            }_{d_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{env}[\mathtt{t}]
            }_{{\color{orange}{E_{t+1}}}}
        \bigr)
    \,, $$

* collect loop ($t + 1_- \to t+1$)

$$ 
    \bigl(
        {\color{orange}{n_t}},\,
        {\color{orange}{d_{t+1}}}
    \bigr)
        \longrightarrow
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{stepno}[\mathtt{t+1}]
            }_{n_{t+1} \leadsto {\color{orange}{z_{t+1}}}}
    \,. $$

Here $r_t$ is a scalar reward, $d_t = \top$ if $s_t$ is terminal, or $\bot$
otherwise, $n_{t+1} = 0$ if $d_{t+1} = \top$, else $1 + n_t$, and $a \leadsto b$
means that $a$ is recorded into $b$.

In general, we may treat $z_t$, the extended observation, as an ordinary
observation, by **suitably modifying** the environment: we can make it
recall the most recent action $a_{t-1}$ and compute the termination indicator
$d_t$ of the current state, and let it keep track of the interaction counter
$n_t$, and, finally, we can configure it to supply the most recent reward
$r_t$ as part of the emitted observation.

Hence we essentially consider the following POMDP setup:
\begin{align}
    a_t, h_{t+1}, A_t
        &\longleftarrow \operatorname{Actor}(z_t, h_t; \theta)
        \,, \\
    z_{t+1}, r_{t+1}, E_{t+1}, s_{t+1}
        &\longleftarrow \operatorname{Env}(s_t, a_t)
        \,.
\end{align}

Specifically, let $
(z_t)_{t=0}^T
    = (n_t, x_t, a_{t-1}, r_t, d_t)_{t=0}^T
$ be the trajectory fragment in `.state`, and $h_0$, `.hx`, be the starting
(not necessarily the initial) recurrent state of the actor at the beginning
of the rollout.
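
The REACT and STEP+EMIT alternation can be illustrated with a toy, dependency-free loop. This is only a sketch under stated assumptions, not the `rlplay` collector: `CountdownEnv`, `toy_actor` and `toy_collect` are hypothetical names, the environment is assumed to reset itself upon termination, and the placeholder values used for $a_{-1}$, $r_0$ and $d_0$ in $z_0$ are arbitrary.

```python
from collections import namedtuple

# the "extended" observation z_t = (n_t, x_t, a_{t-1}, r_t, d_t)
State = namedtuple("State", "stepno obs act rew fin")


class CountdownEnv:
    """Hypothetical toy env: an episode terminates after `horizon` steps."""

    def __init__(self, horizon=3):
        self.horizon, self.t = horizon, 0

    def reset(self):
        self.t = 0
        return float(self.t)

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        # ASSUMPTION: on termination we emit the observation of a fresh episode
        obs = self.reset() if done else float(self.t)
        return obs, 1.0, done, {}


def toy_actor(state):
    """REACT: pick an action in response to z_t (here a fixed rule)."""
    return state.stepno % 2


def toy_collect(env, actor, n_steps):
    # z_0 with placeholder a_{-1}, r_0, d_0 (assumed values)
    state = State(stepno=0, obs=env.reset(), act=0, rew=0.0, fin=False)
    fragment = [state]
    for _ in range(n_steps):
        act = actor(state)                 # REACT:      t -> t + 1/2
        obs, rew, fin, _ = env.step(act)   # STEP+EMIT:  t + 1/2 -> t + 1
        # collect loop: (n_t, d_{t+1}) -> n_{t+1}
        stepno = 0 if fin else state.stepno + 1
        state = State(stepno, obs, act, float(rew), fin)
        fragment.append(state)
    return fragment


fragment = toy_collect(CountdownEnv(horizon=3), toy_actor, n_steps=5)
```

Note that the fragment holds `n_steps + 1` records, $z_0, \ldots, z_T$, and that the action stored in record $t+1$ is the one that caused the transition into it.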

##### Requirements

* all nested containers **must be** built from pure python `dicts`, `lists`, `tuples` or `namedtuples`

* the environment communicates either in **numpy arrays** or in python **scalars**, but not in data types that are incompatible with pytorch (such as `str` or `bytes`)

```python
# example
obs = {
    'camera': {
        'rear': numpy.zeros((3, 320, 240)),   # the shape must be a tuple
        'front': numpy.zeros((3, 320, 240)),
    },
    'proximity': (+0.1, +0.2, -0.1, +0.0,),
    'other': {
        'fuel_tank': 78.5,
        'passenger': False,
    },
}
```

* the actor communicates in torch tensors **only**

* the environment produces **float scalar** rewards (other data may be communicated through auxiliary environment info-dicts)
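
Restricting the containers to pure python types makes them easy to traverse recursively. The following is a minimal sketch of such a traversal; `nested_apply` is a hypothetical helper written for illustration and is not the function that `rlplay` uses internally.

```python
from collections import namedtuple


def nested_apply(fn, obj):
    """Apply `fn` to every leaf of a nested dict/list/tuple/namedtuple."""
    if isinstance(obj, dict):
        return {k: nested_apply(fn, v) for k, v in obj.items()}

    # namedtuples are tuples with `_fields`, and are rebuilt positionally
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*(nested_apply(fn, v) for v in obj))

    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_apply(fn, v) for v in obj)

    # anything else is treated as a leaf
    return fn(obj)


Point = namedtuple("Point", "x y")

obs = {
    "proximity": (0.1, 0.2),
    "other": {"fuel_tank": 78.5, "passenger": False},
    "position": Point(1, 2),
}

# cast every python scalar leaf to float, keeping the structure intact
as_float = nested_apply(float, obs)
```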

### Creating the actors

Rollout collection relies on the following API of the actor:
* `.reset(j, hx)` reset the recurrent state of the j-th environment in the batch (if applicable)
  * `hx` contains tensors with shape `(n_lstm_layers * n_dir) x batch x hidden`, or is an empty tuple
  * the returned `hx` is the updated recurrent state


* `.step(stepno, obs, act, rew, fin, /, *, hx, virtual)` get the next action $a_t$, the recurrent state $h_{t+1}$, and
the **extra info** in response to $x_t$, $a_{t-1}$, $r_t$, $d_t$, $h_t$, and $n_t$, respectively.
  * extra info `dict` **might** include `value` key with a `T x B` tensor of state value estimates $
      v(s_t) \approx G_t = \mathbb{E} \sum_{j\geq t} \gamma^{j-t} r_{j+1}
  $.
  * MUST allocate new `hx` if the recurrent state is updated
  * MUST NOT change the inputs in-place


In [None]:
from rlplay.engine import BaseActorModule

# BaseActorModule.step??

`BaseActorModule` is essentially a thin sub-class of `torch.nn.Module`, that implements
the API through `.forward(obs, act, rew, fin, *, hx, stepno)`, which should return three things:

1. `actions` prescribed actions in the environment, with data of shape `n_steps x batch x ...`
  * can be a nested container of dicts, lists, and tuples


2. `hx` data with shape `n_steps x batch x ...`
  * can be a nested container of dicts, lists, and tuples
  * **if an actor is not recurrent**, then must return an empty container, e.g. a tuple `()`


3. `info` object, which might be a tensor or a nested object containing data in tensors
`n_steps x batch x ...`. For example, one may communicate the following data:
  * `value` -- the state value estimates $v(z_t)$
  * `logits` -- the policy logits $\log \pi(\cdot \mid z_t)$
  * `q` -- $Q(z_t, \cdot)$ values

Here is an example actor, that wraps a simple MLP policy.

In [None]:
from rlplay.utils.common import multinomial

class nonRecurrentPolicyWrapper(BaseActorModule):
    """Example wrapper for a non-recurrent policy.
    
    Details
    -------
    This example assumes flat `Discrete(n)` action space, and
    simple non-structured observation space, e.g. a python scalar
    or a `numpy.array`.
    """

    def __init__(self, policy, *, epsilon=0.1):
        super().__init__()
        self.policy, self.epsilon = policy, epsilon

    def forward(self, obs, act=None, rew=None, fin=None,
                *, hx=None, stepno=None, virtual=False):
        # Everything is  [T x B x ...]
        logits, hx = self.policy(obs, act, rew), ()

        # XXX eps-greedy?
        if self.training:
            unif = torch.tensor(1. / logits.shape[-1])

            prob = logits.detach().exp()
            prob.mul_(1 - self.epsilon)
            prob.add_(unif, alpha=self.epsilon)

            actions = multinomial(prob)

        else:
            actions = logits.argmax(dim=-1)

        # return just policy logits, for example, for REINFORCE algo
        return actions, hx, dict(logits=logits)

<br>
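
The exploration step in the wrapper above blends the policy with the uniform distribution, i.e. it samples from $(1 - \epsilon) \pi(\cdot \mid z_t) + \epsilon / n$. Below is a plain python sketch of this mixing for a single observation; `mix_with_uniform` is a hypothetical name, and the inputs are assumed to be normalized log-probabilities.

```python
import math


def mix_with_uniform(log_probs, epsilon):
    """Blend a categorical distribution (given by log-probs) with the uniform one."""
    n_actions = len(log_probs)
    return [
        (1 - epsilon) * math.exp(lp) + epsilon / n_actions
        for lp in log_probs
    ]


# a policy that prefers the first of two actions
probs = mix_with_uniform([math.log(0.9), math.log(0.1)], epsilon=0.1)
```

The result is still a valid distribution, and every action has probability at least $\epsilon / n$.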

### Manual rollout collection

We shall need the following procedures from the core of the engine:

In [None]:
from rlplay.engine.core import prepare, startup, collect

Manual collection requires an `actor` and a batch of environment instances `envs`.

Prepare the run-time context for the specified `actor` and the environments:
```python
# settings
sticky = False  # whether to stop interacting if an env resets mid-fragment
device = None   # specifies the device to put the actor's inputs and data onto
pinned = False  # whether to keep the running context in non-resizable pinned
                #  (non-paged) memory for faster host-device transfers

# initialize a buffer for one rollout fragment
buffer = prepare(envs[0], actor, n_steps, len(envs),
                 pinned=False, device=device)

# the running context for the actor and the envs (optionally pinned)
ctx, fragment = startup(envs, actor, buffer, pinned=pinned)

while not done:
    # collect the fragment
    collect(envs, actor, fragment, ctx, sticky=sticky, device=device)

    # fragment.pyt -- torch tensors, fragment.npy -- numpy arrays (aliased on-host)
    do_stuff(actor, fragment.pyt)
```

<br>

### Rollout collection (same-process)

Collect rollouts within the current process

In [None]:
from rlplay.engine.rollout import same

The parameters have the following meaning:
```python
it = same.rollout(
    envs,           # the batch of environment instances
    actor,          # the actor which interacts with the batch
    n_steps=51,     # the length of the rollout fragment
    sticky=False,   # whether to stop interacting if an env resets mid-fragment
    device=None,    # specifies the device to put the actor's inputs onto
)
```

`rollout()` returns an iterator, which has, roughly, the same logic,
as the manual collection above.

Inside the infinite loop it copies `fragment.pyt` onto `device`, before
yielding it to the user. It also does not spawn its own batch of environments,
unlike parallel variants.

The user has to manually limit the number of iterations using, for example,

```python
it = same.rollout(...)

for b, batch in zip(range(100), it):
    # train on batch
    pass

it.close()
```
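
An alternative way to bound the number of iterations, and to make sure that `.close()` is called even if training raises, is to combine `itertools.islice` with `contextlib.closing`. The sketch below uses a hypothetical generator `fake_rollout` as a stand-in for the endless rollout iterator, so that it runs without any environments.

```python
from contextlib import closing
from itertools import islice

cleaned_up = []


def fake_rollout():
    """Hypothetical stand-in for an endless rollout iterator."""
    step = 0
    try:
        while True:
            yield {"fragment": step}
            step += 1

    finally:
        # runs when the generator is closed
        cleaned_up.append(True)


# `closing` calls `it.close()` on exit, `islice` stops after three batches
with closing(fake_rollout()) as it:
    batches = list(islice(it, 3))
```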

<br>

### Rollout collection (single-process)

Single-actor rollout sampler running in a parallel process (double-buffered).

In [None]:
from rlplay.engine.rollout import single

Under the hood the function creates **two** rollout fragment buffers, maintains
a reference to the specified `actor`, makes a shared copy of it (on the host), and
then spawns one worker process.

The worker, in turn, makes its own local copy of the actor on the specified device,
initializes the environments and the running context. During collection it alternates
between the buffers, into which it records the rollout fragments it collects. Except
for double buffering, the logic is identical to `rollout`.

The local copies of the actor are **automatically updated** from the maintained reference.

```python
it = single.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of a rollout fragment
    n_envs,               # the number of independent environments in the batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    clone=True,           # should the worker use a local clone of the reference actor

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)

    device=None,          # the device on which to collect rollouts (the local actor is moved
                          #  onto this device)
)

# ...

it.close()
```
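
The double-buffered hand-off described above can be sketched with two queues that carry buffer indices rather than the data itself. This is an illustration only, and it makes several simplifying assumptions: it uses threads instead of processes, python lists instead of host-shared tensors, and the names `worker`, `empty` and `ready` are hypothetical.

```python
import queue
import threading


def worker(buffers, empty, ready, n_fragments):
    """Fill free buffers with 'fragments' and announce their indices."""
    for k in range(n_fragments):
        j = empty.get()                          # wait for a free buffer
        buffers[j][:] = [k] * len(buffers[j])    # "collect" fragment k into it
        ready.put(j)                             # hand over the index only

    ready.put(None)                              # signal the end of collection


buffers = [[0] * 4, [0] * 4]                     # two preallocated buffers
empty, ready = queue.Queue(), queue.Queue()
for j in range(len(buffers)):
    empty.put(j)

thread = threading.Thread(
    target=worker, args=(buffers, empty, ready, 5), daemon=True)
thread.start()

seen = []
while True:
    j = ready.get()
    if j is None:
        break

    seen.append(buffers[j][0])                   # consume the fragment
    empty.put(j)                                 # return the buffer to the worker

thread.join()
```

While the consumer works on one buffer, the worker is free to fill the other one, and a buffer is never overwritten before it has been returned through `empty`.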

<br>

### Rollout collection (multi-process)

A more load-balanced multi-actor multi-process sampler

In [None]:
from rlplay.engine.rollout import multi

This version of the rollout collector allocates several buffers and spawns
many parallel workers. Each worker creates its own local copy of the actor,
instantiates `n_envs` local environments and allocates a running context for
all of them. The rollout collection in each worker is **hardcoded to run on
the host device**.

```python
it = multi.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of each rollout fragment

    n_actors,             # the number of parallel actors
    n_per_actor,          # the number of independent environments run in each actor
    n_buffers,            # the size of the pool of buffers, into which rollout
                          #  fragments are collected. Should not be less than `n_actors`.
    n_per_batch,          # the number of fragments collated into a batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    pinned=False,

    clone=True,           # should the parallel actors use a local clone of the reference actor

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    device=None,          # the device onto which to move the rollout batches

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)
)

# ...

it.close()
```

<br>

### Evaluation (same-process)

In order to evaluate an actor in a batch of environments, one can use `evaluate`.

In [None]:
from rlplay.engine import evaluate as core_evaluate

The function *does not* collect the rollout data, except for the rewards.
Below is the intended use case.
* **NB** this is run in the same process, hence blocks until completion, which
might take considerable time (esp. if `n_steps` is unbounded)

In [None]:
# same process
def same_evaluate(
    factory, actor, n_envs=4,
    *, n_steps=None, close=True, render=False, device=None
):
    # spawn a batch of environments
    envs = [factory() for _ in range(n_envs)]

    try:
        while True:
            rewards, _ = core_evaluate(
                envs, actor, n_steps=n_steps,
                render=render, device=device)

            # get the accumulated rewards (gamma=1)
            yield sum(rewards)

    finally:
        if close:
            for e in envs:
                e.close()

<br>

### Evaluation (parallel process)

Like rollout collection, evaluation can (and probably should) be performed in
a parallel process, so that it does not burden the main thread with computations
not related to training.

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

<br>

## CartPole with REINFORCE or A2C

### the CartPole Environment

In [None]:
import gym

# hotfix for gym's unresponsive viz (spawns gl threads!)
import rlplay.utils.integration.gym

The environment factory

In [None]:
import time
from rlplay.zoo.env import NarrowPath


class FP32Observation(gym.ObservationWrapper):
    def observation(self, observation):
        obs = observation.astype(numpy.float32)
        obs[0] = 0.  # mask the position info
        return obs

#     def step(self, action):
#         obs, reward, done, info = super().step(action)
#         reward -= abs(obs[1]) / 10  # punish for non-zero speed
#         return obs, reward, done, info

class OneHotObservation(gym.ObservationWrapper):
    def observation(self, observation):
        return numpy.eye(1, self.env.observation_space.n,
                         k=observation, dtype=numpy.float32)[0]

def factory(unwrap=True):
    env = gym.make("CartPole-v0")
    env = env.unwrapped if unwrap else env
    return FP32Observation(env)

#     return gym.make("Taxi-v3").unwrapped
    # return OneHotObservation(NarrowPath())

In [None]:
from functools import partial

factory_eval = partial(factory, unwrap=True)

<br>

### the algorithms

Service functions for the algorithms

In [None]:
from rlplay.engine.utils.plyr import suply, xgetitem

def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects."""
    # use xgetitem to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    curr = suply(xgetitem, state, index=slice(None, -shift))

    # XXX `next[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    next = suply(xgetitem, state, index=slice(shift, None))

    return curr, next
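
To see what the slicing in `timeshift` does, here is the same idea applied to a plain python list. `timeshift_list` is a hypothetical stand-in, which ignores nested containers and `None` handling.

```python
def timeshift_list(seq, *, shift=1):
    """Return the current and the shifted slices of a flat sequence."""
    return seq[:-shift], seq[shift:]


# a fragment with T + 1 = 4 records z_0..z_3
curr, nxt = timeshift_list(["z0", "z1", "z2", "z3"])
```

Both slices have the same length, so `curr[t]` and `nxt[t]` pair each record with the one `shift` steps after it.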

The reinforce PG algo

In [None]:
from rlplay.algo.returns import pyt_returns

In [None]:
# @torch.enable_grad()
def reinforce(batch, module, *, gamma=0.99, C_entropy=1e-2,
              c_rho=float('inf')):
    r"""The REINFORCE algorithm (importance-weighted off-policy).

    The basic policy-gradient algorithm with a baseline $b_t$:
    $$
        \nabla_\theta J(s_t)
            = \mathbb{E}_{a \sim \beta(a\mid s_t)}
                \frac{\pi(a\mid s_t)}{\beta(a\mid s_t)}
                    \bigl( r_{t+1} + \gamma G_{t+1} - b_t \bigr)
                \nabla_\theta \log \pi(a\mid s_t)
        \,. $$
    """
    
    state, state_next = timeshift(batch.state)

    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        state.obs, state.act, state.rew, state.fin,
        hx=batch.hx, stepno=state.stepno)

    # The present value of the future rewards following `state[t]`:
    #    G_t = r_{t+1} + \gamma G_{t+1}
    ret = pyt_returns(state_next.rew, state_next.fin,
                      gamma=gamma, bootstrap=torch.tensor(0.))

    # Assume .act is unstructured: `act[t]` = a_{t+1} -->> T x B x 1
    act = state_next.act.unsqueeze(-1)

    # \pi is the target policy, \mu is the behaviour policy
    log_pi, log_mu = info['logits'], batch.actor['logits']

    # the importance weights
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)
    rho = log_mu_a.sub_(log_pi_a.detach())\
                  .neg_().exp_().clamp_(max=c_rho)

    # the policy surrogate score
    #    \frac1T \sum_t \rho_t (G_t - b_t) \log \pi(a_t \mid s_t)
    reinfscore = log_pi_a.mul(ret.sub(ret.mean(dim=0)).mul_(rho)).mean()

    # the policy entropy score (neg entropy)
    #   - H(\pi(•\mid s)) = - (-1) \sum_a \pi(a\mid s) \log \pi(a\mid s)
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # maximize the entropy and the reinforce score
    # \ell := - \frac1T \sum_t G_t \log \pi(a_t \mid s_t)
    #         - C \mathbb{H} \pi(\cdot \mid s_t)
    loss = C_entropy * negentropy - reinfscore
    return loss.mean(), dict(entropy=-float(negentropy),
                             policy_score=float(reinfscore),)
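
The present value recursion $G_t = r_{t+1} + \gamma G_{t+1}$ used above can be spelled out in plain python. This is a sketch and not the implementation of `pyt_returns`: `discounted_returns` is a hypothetical name, it works on a single trajectory of python scalars, and it assumes that a raised termination flag cuts off the bootstrapping of everything that follows.

```python
def discounted_returns(rew, fin, gamma, bootstrap=0.0):
    """Backward recursion ret[t] = rew[t] + gamma * (1 - fin[t]) * ret[t+1].

    `rew[t]` and `fin[t]` are the reward and the termination flag received
    on the transition out of step `t`; `bootstrap` stands in for `ret[T]`.
    """
    ret, out = bootstrap, []
    for r, d in zip(reversed(rew), reversed(fin)):
        ret = r + gamma * (0.0 if d else ret)
        out.append(ret)

    return out[::-1]


# the episode terminates on the last step, so the bootstrap is ignored
ret_term = discounted_returns(
    [1.0, 1.0, 1.0], [False, False, True], gamma=0.5, bootstrap=10.0)

# no termination: the bootstrap value is discounted into the returns
ret_boot = discounted_returns(
    [1.0, 1.0], [False, False], gamma=0.5, bootstrap=4.0)
```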

Actor-critic algo

In [None]:
import torch.nn.functional as F

# @torch.enable_grad()
def a2c(batch, module, *, gamma=0.99, C_entropy=1e-2, C_value=0.25, c_rho=1.0):
    r"""The Advantage Actor-Critic algorithm (importance-weighted off-policy).

    Close to REINFORCE, but uses a separate baseline estimate to compute
    advantages in the policy grad.
    $$
        \nabla_\theta J(s_t)
            = \mathbb{E}_{a \sim \beta(a\mid s_t)}
                \frac{\pi(a\mid s_t)}{\beta(a\mid s_t)}
                    \bigl( r_{t+1} + \gamma G_{t+1} - v(s_t) \bigr)
                \nabla_\theta \log \pi(a\mid s_t)
        \,, $$

    where the critic estimates the value function under the current policy
    $$
        v(s_t) \approx \mathbb{E}_{\pi_{\geq t}}
                        G_t(a_t, s_{t+1}, a_{t+1}, ... \mid s_t)
        \,. $$
    """
    state, state_next = timeshift(batch.state)

    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        state.obs, state.act, state.rew, state.fin,
        hx=batch.hx, stepno=state.stepno)
    # info['value'] = V(`.state[t]`)
    #               <<-->> v(x_t)
    #               \approx \mathbb{E}( G_t \mid x_t)
    #               \approx \mathbb{E}( r_{t+1} + \gamma r_{t+2} + ... \mid x_t)
    #               <<-->> npv(`.state[t+1:]`)
    # info['logits'] = \log \pi(... | .state[t] )
    #                <<-->> \log \pi( \cdot \mid x_t)

    # `.actor[t]` is actor's extra info in reaction to `.state[t]`, t=0..T
    bootstrap = batch.actor['value'][-1]
    #     `bootstrap` <<-->> `.value[-1]` = V(`.state[-1]`)

    # Future rewards following after `.state[t]` are recorded in `.state[t+1:]`
    # ret[t] = rew[t] + gamma * (1 - fin[t]) * (ret[t+1] or bootstrap)
    ret = pyt_returns(state_next.rew, state_next.fin,
                      gamma=gamma, bootstrap=bootstrap)

    # XXX post-mul by `1 - \gamma` fails to train, but seems appropriate
    # for the continuation/survival interpretation of the discount factor.
    #   <<-- but who says this is a good interpretation?
    # ret.mul_(1 - gamma)

    # Assume `.act` is unstructured: `act[t]` = a_{t+1} -->> T x B x 1
    act = state_next.act.unsqueeze(-1)  # actions taken during the rollout

    # \pi is the target policy, \mu is the behaviour policy
    log_pi, log_mu = info['logits'], batch.actor['logits']

    # the importance weights
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)
    rho = log_mu_a.sub_(log_pi_a.detach())\
                  .neg_().exp_().clamp_(max=c_rho)

    # the critic's score (negative mse)
    #  \frac1{2 T} \sum_t (G_t - v(s_t))^2
    value = info['value']
    critic_mse = 0.5 * F.mse_loss(value, ret, reduction='mean')
    # v(x_t) \approx \mathbb{E}( G_t \mid x_t )
    #        \approx G_t (one-point estimate)
    #        <<-->> ret[t]

    # the policy surrogate score
    #    \frac1T \sum_t \rho_t (G_t - v_t) \log \pi(a_t \mid s_t)
    a2c_score = log_pi_a.mul(ret.sub(value.detach()).mul_(rho)).mean()

    # the policy entropy score (neg entropy)
    #   - H(\pi(•\mid s)) = - (-1) \sum_a \pi(a\mid s) \log \pi(a\mid s)
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # maximize the entropy and the reinforce score, minimize the critic loss
    objective = C_entropy * negentropy + C_value * critic_mse - a2c_score
    return objective.mean(), dict(entropy=-float(negentropy),
                                  policy_score=float(a2c_score),
                                  value_loss=float(critic_mse))

For example, we can also use GAE:
```python
from rlplay.algo.returns import pyt_gae

# the positional arguments are $r_{t+1}$, $d_{t+1}$, and $v(s_t)$,
#  respectively, for $t=0..T-1$. The bootstrap is $v(s_T)$.
pyt_gae(batch.state.rew[1:], batch.state.fin[1:], batch.actor['value'][:-1],
        gamma=0.99, bootstrap=batch.actor['value'][-1])
```
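
For reference, the usual GAE recursion is $\delta_t = r_{t+1} + \gamma 1_{\neg d_{t+1}} v(s_{t+1}) - v(s_t)$ and $A_t = \delta_t + \gamma \lambda 1_{\neg d_{t+1}} A_{t+1}$. The sketch below implements it in plain python for a single trajectory. It is not the code of `pyt_gae`: the function name `gae` and the parameter name `lam` are assumptions made for this illustration.

```python
def gae(rew, fin, val, *, gamma, lam, bootstrap):
    """Generalized advantage estimates for one trajectory of python scalars.

    `rew[t]`, `fin[t]` are r_{t+1}, d_{t+1}; `val[t]` is v(s_t), t=0..T-1;
    `bootstrap` is v(s_T).
    """
    adv, v_next, out = 0.0, bootstrap, []
    for r, d, v in zip(reversed(rew), reversed(fin), reversed(val)):
        cont = 0.0 if d else 1.0
        delta = r + gamma * cont * v_next - v       # one-step td-error
        adv = delta + gamma * lam * cont * adv      # discounted sum of td-errors
        out.append(adv)
        v_next = v

    return out[::-1]


rew, fin, val = [1.0, 1.0], [False, False], [0.5, 0.25]

# lam=1 recovers `discounted return minus value`
adv_mc = gae(rew, fin, val, gamma=0.5, lam=1.0, bootstrap=4.0)

# lam=0 recovers the one-step td-errors
adv_td = gae(rew, fin, val, gamma=0.5, lam=0.0, bootstrap=4.0)
```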

#### D-DQN loss

Double DQN loss for contiguous trajectory fragments.

* `.state[t+1].rew` -- $r_{t+1}$
* `.state[t+1].fin` -- $d_{t+1}$
* `.state[t+1].act` -- $a_t$ which caused $
    s_t \longrightarrow (s_{t+1}, r_{t+1}, d_{t+1})
$
* `_, _, batch.actor[t] = actor(.state[t])` -- $
    Q(z_t, \cdot; h_t, \theta_{\text{old}})
$ -- used for rollout collection
* `_, _, info_module[t] = module(.state[t])` -- $
    Q(z_t, \cdot; h_t, \theta)
$ -- the current q-function, producing $(h_t)_{t=0}^{T+1}$
* `_, _, info_target[t] = target(.state[t])` -- $
    Q(z_t, \cdot; h_t, \theta_-)
$ -- the q-target with $(h^-_t)_{t=0}^{T+1}$

The td-error is
$$
\delta_t
    = r_{t+1} + \gamma 1_{\neg d_{t+1}} v^*(z_{t+1})
        - Q(z_t, a_t; h_t, \theta)
    \,, $$

where the approximate state value estimate $
    v^*(z_{t+1})
$ is one of
* $
    \max_a Q(z_{t+1}, a; h_{t+1}, \theta)
$ for the `Q-learning`
* $
    Q(z_{t+1}, \hat{a}; h_{t+1}, \theta_-)
$ for $
    \hat{a} = \arg \max_a Q(z_{t+1}, a; h_{t+1}, \theta)
$ in the `double DQN`
* $
    \max_a Q(z_{t+1}, a; h_{t+1}, \theta_-)
$ in the `DQN`
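
The three variants differ only in which network picks the action and which one evaluates it. Here is a plain python sketch of the target $r_{t+1} + \gamma 1_{\neg d_{t+1}} v^*(z_{t+1})$ for a single transition; `td_target` and its `kind` argument are hypothetical names, and the Q-values are given as lists indexed by action.

```python
def td_target(rew, fin, q_next, q_next_target, gamma, kind="double"):
    """The td-target for one transition.

    `q_next` are the Q-values at z_{t+1} under the current parameters,
    `q_next_target` are the Q-values under the frozen target parameters.
    """
    if fin:
        # nothing to bootstrap from after a terminal transition
        return rew

    if kind == "q-learning":
        value = max(q_next)

    elif kind == "dqn":
        value = max(q_next_target)

    elif kind == "double":
        # the current net selects the action, the target net evaluates it
        a_hat = max(range(len(q_next)), key=q_next.__getitem__)
        value = q_next_target[a_hat]

    else:
        raise ValueError(f"unknown kind `{kind}`")

    return rew + gamma * value


q_next, q_next_target = [1.0, 3.0], [5.0, 2.0]
targets = {
    kind: td_target(1.0, False, q_next, q_next_target, 0.5, kind)
    for kind in ("q-learning", "dqn", "double")
}
```

In this toy case the double DQN target is the smallest of the three, since the action preferred by the current network is not the one that the target network values most.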

Notice that currently we have two trajectories of $h_t \to h_{t+1}$:
* one computed using $\theta_-$ and the other using $\theta$
  * the two may diverge due to representational drift

One of the works related to DQN learning of recurrent agents is [Kapturowski et al. (2018)](https://openreview.net/forum?id=r1lyTjAqYX), who propose to use a burn-in period in
the contiguous trajectory fragment in order to compensate for the representation drift.
However, there is no clear-cut evidence suggesting that the hidden recurrent sequences
$h_t$ and $h^-_t$ yield significantly different results.

<br>

In [None]:
import copy

# @torch.enable_grad()
def rddqn(batch, module,
          *, target, gamma=0.95, double=True, weights=None, loss=F.mse_loss):
    r"""Compute the Double-DQN loss over a contiguous fragment of a trajectory.

    Details
    -------
    In Q-learning the action value function minimizes the TD-error
    $$
        r_{t+1} + \gamma v^*(z_{t+1}) 1_{\neg d_{t+1}}
            - Q(z_t, a_t; \theta)
        \,, $$
    w.r.t. Q-network parameters $\theta$ where $z_t$ is the actionable state, e.g.
    the current observation $x_t$ and the recurrent state $h_t$, the last
    action $a_{t-1}$, reward $r_t$, and termination flag $d_t$.

    In classic Q-learning there is no target network and the next state optimal
    value function is bootstrapped using the current Q-network (`module`):
    $$
        v^*(s_{t+1})
            \approx \max_a Q(s_{t+1}, a; \theta)
        \,. $$
    The DQN method, proposed by
        [Mnih et al. (2013)](https://arxiv.org/abs/1312.5602),
    uses a secondary Q-network to estimate the value of the next state:
    $$
        v^*(s_{t+1})
            \approx \max_a Q(s_{t+1}, a; \theta^-)
        \,, $$
    where $\theta^-$ are frozen parameters of the Q-network (`target`). The
    Double DQN algorithm of
        [van Hasselt et al. (2015)](https://arxiv.org/abs/1509.06461)
    unravels the $\max$ operator as $
        \max_k u_k
            \equiv u_{\arg \max_k u_k}
    $ and replaces the outer $u$ with the Q-values of the target net, while
    the inner $u$ (inside the $\arg\max$) is computed with the Q-values of the
    current Q-network. Specifically, the Double DQN value estimate is
    $$
        v^*(s_{t+1})
            \approx Q(s_{t+1}, \hat{a}_{t+1}; \theta^-)
        \,, $$
    for $
        \hat{a}_{t+1}
            = \arg \max_a Q(s_{t+1}, a; \theta)
    $ being the action taken by the current Q-network $\theta$ at $s_{t+1}$.
    """

    if target is None:
        # use $Q(\cdot; \theta)$ instead of $Q(\cdot; \theta^-)$
        target, double = module, False

    fragment = batch.state

    # get $Q(z_t, \cdot; h_t, \theta)$ for all t=0..T, and $(h_t)_t$
    _, _, info_module = module(
        fragment.obs, fragment.act,
        fragment.rew, fragment.fin,
        hx=batch.hx, stepno=fragment.stepno)

    # get $Q(z_t, \cdot; h^-_t, \theta^-)$ for all t=0..T, and $(h^-_t)_t$
    # XXX make sure no grads are computed for the target Q-value
    with torch.no_grad():
        info_target = info_module
        if target is not module:
            _, _, info_target = target(
                fragment.obs, fragment.act,
                fragment.rew, fragment.fin,
                hx=batch.hx, stepno=fragment.stepno)

    # $\hat{A}_t$: get the module's response to current and next state by shifting
    info_module_curr, info_module_next = timeshift(info_module)

    # get the next state `state[t+1]`, and the target's response to it
    state_next, info_target_next = suply(
        xgetitem, (fragment, info_target), index=slice(1, None))

    # get $Q(z_t, a_t; h_t, \theta)$ for all t=0..T-1
    q_replay = info_module_curr['q'].gather(-1, state_next.act.unsqueeze(-1))

    # get the target value
    with torch.no_grad():
        if double and module is not target:
            # get $\hat{a} = \arg \max_a Q(z_{t+1}, a; h_{t+1}, \theta)$
            # XXX can we reuse $\hat{a}$ from the REACT step?
            hat_a = info_module_next['q'].max(dim=-1).indices

            # get $\hat{v}(z_{t+1}) = Q(z_{t+1}, \hat{a}; h^-_{t+1}, \theta^-)$
            q_value = info_target_next['q'].gather(-1, hat_a.unsqueeze(-1))

        else:
            # get $\hat{v}(z_{t+1}) = \max_a Q(z_{t+1}, a; h^-_{t+1}, \theta^-)$
            q_value = info_target_next['q'].max(dim=-1, keepdim=True).values

        # get $r_{t+1} + \gamma 1_{\neg d_{t+1}} \hat{v}(z_{t+1})$ using inplace ops
        q_value.masked_fill_(state_next.fin.unsqueeze(-1), 0.)
        q_value.mul_(gamma).add_(state_next.rew.unsqueeze(-1))

        # compute the Temp. Diff. error (δ) for experience prioritization
        td_error = q_replay - q_value
    # end with

    # the (weighted) td-error loss
    if weights is None:
        value = loss(q_replay, q_value, reduction='mean')
        return value, {'td_error': td_error}

    values = loss(q_replay, q_value, reduction='none')
    return weights.mul(values).mean(), {'td_error': td_error}

<br>

### the Actor

The policy of the actor

A procedure and a layer, which converts the input integer data into its
little endian binary representation as float $\{0, 1\}^m$ vectors.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers to fixed-width binary floating point vectors"""
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # n_bits = {torch.int64: 63, torch.int32: 31, torch.int16: 15, torch.int8 : 7}

    # get mask of set bits
    pow2 = torch.tensor([1 << j for j in range(n_bits)]).to(input.device)
    x = input.unsqueeze(-1).bitwise_and(pow2).to(bool)

    # upcast bool to float to get one-hot
    return x.to(dtype)

class OneHotBits(torch.nn.Module):
    def __init__(self, n_bits=63, dtype=torch.float):
        assert 1 <= n_bits < 64  # same bound as in `onehotbits`, admits the default
        super().__init__()
        self.n_bits, self.dtype = n_bits, dtype

    def forward(self, input):
        return onehotbits(input, n_bits=self.n_bits, dtype=self.dtype)

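To make the little-endian layout concrete, here is a plain-Python sketch of the same encoding for a single integer. The helper `bits_le` is hypothetical and only mirrors the bit-masking idea of `onehotbits`; it is not part of `rlplay`.

```python
def bits_le(n, n_bits):
    """Little-endian fixed-width binary code of a non-negative integer."""
    # position j holds the coefficient of 2**j, so the lowest bit comes first
    return [float((n >> j) & 1) for j in range(n_bits)]


print(bits_le(5, 4))   # [1.0, 0.0, 1.0, 0.0]
print(bits_le(8, 4))   # [0.0, 0.0, 0.0, 1.0]
print(bits_le(16, 4))  # [0.0, 0.0, 0.0, 0.0]
```

The last line shows that values at or above `2 ** n_bits` alias smaller ones, so the width should cover the largest expected input.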
A more sophisticated recurrent learner:

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


class CartPoleActor(BaseActorModule):
    def __init__(self, epsilon=0.1, lstm=False, use_cudnn=True):
        super().__init__()
        self.epsilon, self.lstm, self.use_cudnn = epsilon, lstm, use_cudnn

        self.features = torch.nn.ModuleDict(dict(
            obs=torch.nn.Sequential(
                torch.nn.Linear(4, 64),
                torch.nn.ReLU(),
            ),
            act=torch.nn.Embedding(2, 4),
            rew=torch.nn.Sequential(
                torch.nn.Linear(1, 4),
                torch.nn.ReLU(),
            ),
            stepno=torch.nn.Sequential(
                OneHotBits(16),
                torch.nn.Linear(16, 16),
                torch.nn.ReLU(),
            ),
        ))

        n_features = 64 + 4 + 4 + 16
        if not self.lstm:
            self.core = torch.nn.Sequential(
                torch.nn.Linear(n_features, 64),
                torch.nn.ReLU(),
            )

        else:
            self.core = torch.nn.LSTM(n_features, 64, 1)

        self.policy = torch.nn.Sequential(
            torch.nn.Linear(64, 2),
            torch.nn.LogSoftmax(dim=-1),
        )

        self.baseline = torch.nn.Sequential(
            torch.nn.Linear(64, 1),
        )

    def forward(self, obs, act, rew, fin,
                *, hx=None, stepno=None, virtual=False):
        # Everything is  [T x B x ...]
        inputs = torch.cat([
            self.features['obs'](obs),
            self.features['act'](act),
            self.features['rew'](rew.unsqueeze(-1)),
            self.features['stepno'](stepno),
        ], dim=-1)
        
        if not self.lstm:
            output, hx = self.core(inputs), ()

        elif self.use_cudnn:
            # sequence padding (MUST have sampling with `sticky=True`)
            n_steps, n_env, *_ = fin.shape
            if n_steps > 1:
                # we assume sticky=True
                lengths = 1 + (~fin[1:]).sum(0).cpu()
                inputs = pack_padded_sequence(inputs, lengths, enforce_sorted=False)

            output, hx = self.core(inputs, hx)
            if n_steps > 1:
                output, lens = pad_packed_sequence(
                    output, batch_first=False, total_length=n_steps)

        else:
            # inputs is T x B x F, hx is either None, or a proper recurrent state
            outputs = []
            # manually step through the RNN core
            for input, mask in zip(inputs.unsqueeze(1), fin.unsqueeze(-1)):
                # zero if f indicates reset: multiplying by zero stops grad
                if hx is not None:
                    # stop hx grads if `reset` (mul-by-zero)
                    hx = suply(torch.Tensor.mul, hx, other=~mask)

                output, hx = self.core(input, hx)
                outputs.append(output)

            output = torch.cat(outputs, dim=0)

        # value must not have any trailing dims, i.e. T x B
        logits = self.policy(output)
        value = self.baseline(output).squeeze(-1)

        # XXX eps-greedy?
        if self.training:
            # blend the policy with a uniform distribution
            prob = logits.detach().exp().mul_(1 - self.epsilon)
            prob.add_(self.epsilon / logits.shape[-1])

            actions = multinomial(prob)

        else:
            actions = logits.argmax(dim=-1)

        return actions, hx, dict(
            value=value,
            logits=logits,
            # entropy
            # logit_at_the_chosen_action
        )

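In training mode the `forward` above samples actions from the policy blended with a uniform distribution. The arithmetic of that blend can be sketched in plain Python, assuming the logits are normalized log-probabilities (as `LogSoftmax` produces); `blend_with_uniform` is a hypothetical helper used for illustration only.

```python
import math


def blend_with_uniform(logits, epsilon):
    """Mix a policy given by log-probabilities with the uniform distribution."""
    n_actions = len(logits)
    return [(1 - epsilon) * math.exp(lp) + epsilon / n_actions for lp in logits]


logits = [math.log(0.75), math.log(0.25)]
prob = blend_with_uniform(logits, epsilon=0.5)
print(prob)       # approximately [0.625, 0.375]
print(sum(prob))  # still a probability distribution, up to rounding
```

With `epsilon=0` the policy is returned unchanged, and with `epsilon=1` every action becomes equally likely.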
Initialize the learner

In [None]:
# learner, sticky = nonRecurrentPolicyWrapper(policy()), False
learner = CartPoleActor(lstm=False, use_cudnn=False)
sticky = learner.use_cudnn and learner.lstm

learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

# prepare the optimizer for the learner
optim = torch.optim.Adam(learner.parameters(), lr=1e-3, weight_decay=1e-3)

Load a better trained agent

Initialize the sampler

In [None]:
# T, B = 120, 20
# T, B = 120, 4
T, B = 21, 4

Pick one collector
* The NetHack environment `nle` does not work well with the `fork` start method, so we should use `spawn`, which is not notebook-friendly :(
  * so it is better to prototype in a notebook with `same.rollout`, and then write a non-interactive submodule script with `multi.rollout`
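
Which start methods are available depends on the platform, and the standard library can report them. The sketch below picks the first available method from a preference list; `pick_start_method` is a hypothetical helper, and whether its result suits a particular environment is an assumption to verify.

```python
import multiprocessing as mp


def pick_start_method(preferred=("spawn", "forkserver", "fork")):
    """Return the first preferred start method supported on this platform."""
    available = mp.get_all_start_methods()
    for method in preferred:
        if method in available:
            return method

    raise RuntimeError(f"none of {preferred} is available, got {available}")


print(mp.get_all_start_methods())
print(pick_start_method())
```

The returned string could then be passed as the `start_method` argument of the collector below.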

In [None]:
# generator of rollout batches
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=8,
    n_per_actor=B,
    n_buffers=16,
    n_per_batch=2,
    sticky=sticky,  # so that we can leverage cudnn's fast RNN implementations
    pinned=False,
    clone=False,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

Implement your favourite training method

In [None]:
import tqdm
from torch.nn.utils import clip_grad_norm_

# pytorch loves to hog all threads on some linux systems
torch.set_num_threads(1)

gamma = 0.99
losses, rewards = [], []

# generator of evaluation rewards
# test_it = test(factory_eval, learner, n_envs=4, n_steps=500, device=device_)
test_it = evaluate(factory_eval, learner, n_envs=4, n_steps=500,
                   clone=False, device=device_, start_method='fork')

# the training loop
exclude = {'returns'}
ewm, alpha = None, 0.5
for epoch in tqdm.tqdm(range(400)):
    for j, batch in zip(range(100), batchit):
        loss, info = a2c(batch, learner, gamma=gamma, c_rho=1.5)

        optim.zero_grad()
        loss.backward()

        grad_norm = clip_grad_norm_(learner.parameters(), max_norm=1e2)
        optim.step()

        losses.append({k: float(v) for k, v in info.items() if k not in exclude})
        losses[-1]['grad'] = float(grad_norm)

    # fetch the evaluation results (lag by one inner loop!)
    rewards.append(next(test_it))

    # ewm: track the exponentially weighted mean of the minimal evaluation reward
    low = rewards[-1].min()
    if ewm is None:
        ewm = low

    else:
        ewm += alpha * (low - ewm)

    if ewm > 498:
        break

# close the generators
batchit.close()
test_it.close()

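The stopping rule in the loop above relies on an exponentially weighted mean of the worst evaluation reward. Here is a plain-Python sketch of that update on made-up numbers; `ewm_update` is a hypothetical helper that restates the two branches of the loop.

```python
def ewm_update(ewm, value, alpha=0.5):
    """One step of the exponentially weighted mean; `None` means 'not started'."""
    if ewm is None:
        return value

    return ewm + alpha * (value - ewm)


ewm = None
for low in [100.0, 300.0, 500.0, 500.0]:
    ewm = ewm_update(ewm, low)
    print(ewm)
# prints 100.0, 200.0, 350.0, 425.0
```

Because the mean lags behind the latest value, a single good evaluation does not trigger the `ewm > 498` break; several consecutive ones are needed.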
In [None]:
batchit.close()

<br>

In [None]:
def collate(records):
    """collate identically keyed dicts"""
    out = {}
    for record in records:
        for k, v in record.items():
            out.setdefault(k, []).append(v)
    
    return out

In [None]:
data = {k: numpy.array(v) for k, v in collate(losses).items()}

In [None]:
if 'value_loss' in data:
    plt.semilogy(data['value_loss'])

In [None]:
plt.plot(data['entropy'])

In [None]:
plt.plot(data['policy_score'])

In [None]:
plt.semilogy(data['grad'])

In [None]:
rewards = numpy.stack(rewards, axis=0)

In [None]:
rewards

In [None]:
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

ax.plot(numpy.mean(rewards, axis=-1))
ax.plot(numpy.median(rewards, axis=-1))
ax.plot(numpy.min(rewards, axis=-1))
ax.plot(numpy.std(rewards, axis=-1))
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)

plt.show()

In [None]:
with factory_eval() as env:
    learner.eval()
    eval_rewards, info = core_evaluate([
        env
    ], learner, render=True, n_steps=1e4, device=device_)

print(sum(eval_rewards))

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

td_target = eval_rewards + gamma * info['value'][1:]
td_error = td_target - info['value'][:-1]
# td_error = npy_deltas(
#     eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool), info['value'][:-1],
#     gamma=gamma, bootstrap=info['value'][-1])

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.semilogy(abs(td_error) / abs(td_target))
# ax.set_ylim(-0.0001, 0.002);
ax.set_title('relative td(1)-error');

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# plt.plot(
#     npy_returns(eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool),
#                 gamma=gamma, bootstrap=info['value'][-1]))
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(info['value'])
ax.axhline(1 / (1 - gamma), c='k', alpha=0.5);

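The horizontal line at `1 / (1 - gamma)` is the limit of the discounted return when every step yields a unit reward, which is assumed here to match the `CartPole` reward. A plain-Python sketch of how a finite episode approaches that ceiling (`discounted_return` is a hypothetical helper):

```python
def discounted_return(rewards, gamma):
    """Discounted sum of rewards, accumulated backwards in time."""
    ret = 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret

    return ret


gamma = 0.99
ceiling = 1 / (1 - gamma)
for n_steps in (10, 100, 500):
    print(n_steps, discounted_return([1.0] * n_steps, gamma), ceiling)
```

Under this assumption, value estimates well above the line suggest that the baseline overestimates.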
In [None]:
import math
from scipy.special import softmax, expit, entr

*head, n_actions = info['logits'].shape
proba = softmax(info['logits'], axis=-1)

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(entr(proba).sum(-1)[:, 0])
ax.axhline(math.log(n_actions), c='k', alpha=0.5);

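The reference line at `math.log(n_actions)` is the entropy of the uniform policy, the largest value any distribution over `n_actions` outcomes can reach. A small plain-Python check with a hypothetical `entropy` helper:

```python
import math


def entropy(proba):
    """Shannon entropy in nats, with the convention 0 * log(0) = 0."""
    return sum(-p * math.log(p) for p in proba if p > 0)


print(entropy([0.5, 0.5]), math.log(2))  # uniform policy attains the maximum
print(entropy([0.9, 0.1]))               # a peaked policy has lower entropy
```

An entropy curve that hugs the line means the policy is still close to uniform, while a curve near zero means it has become almost deterministic.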
In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.hist(info['logits'][..., 1] - info['logits'][..., 0], bins=51);  # log-ratio

<br>

In [None]:
# stepno = batch.state.stepno
stepno = torch.arange(16384)

In [None]:
with torch.no_grad():
    out = learner.features['stepno'](stepno)

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(8, 8), dpi=200,
                         sharex=True, sharey=True)

for j, ax in zip(range(out.shape[1]), axes.flat):
    ax.plot(out[:, j], lw=1)

fig.tight_layout()

In [None]:
with torch.no_grad():
    plt.imshow(abs(learner.core[0].weight[:, 64+4+4:]).T)

In [None]:
with torch.no_grad():
    plt.imshow(abs(learner.features.stepno[1].weight))

In [None]:
abs(learner.features.stepno[1].weight)

<br>

In [None]:
assert False

<br>

In [None]:
import matplotlib.pyplot as plt

p_l, v_l, ent = zip(*losses)

plt.plot(p_l)
plt.plot(ent)

In [None]:
plt.plot(v_l)

Run in the environment

In [None]:
plt.plot([
    sum(evaluate(factory, learner, render=False))
    for _ in range(200)
])

<br>

In [None]:
assert False

In [None]:
class Bar:
    def __init__(self, parent):
        self.parent = parent
        self._range = range(self.parent.n)
        self._it = iter(self._range)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)

class Foo:
    def __init__(self, n=10):
        self.n = n

    def __iter__(self):
        return Bar(self)


In [None]:
list(Foo())

In [None]:
class Bar:
    def __init__(self, parent):
        self.parent = parent

    def __iter__(self):
        yield from range(self.parent.n)

class Foo:
    def __init__(self, n=10):
        self.n = n

    def __iter__(self):
        return iter(Bar(self))


<br>