# Patchy overview of `rlplay` with REINFORCE

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

<br>

## Rollout collection

Rollout collection is designed to be as `plug-n-play` as possible, i.e. it
supports **arbitrarily structured nested containers** of arrays or tensors for
environment observations and actions. The actor, however, should **expose**
a certain API (described below).

In [None]:
from rlplay.engine import core

# help(core.collect)

The role of `core.collect` is to serve as a *middle-man* between the **actor-environment**
pair and the **training loop**: to track the trajectory of the actor in the environment,
and properly record it into the data buffer.

For example, it is not responsible for seeding or randomization of environments
(I'm looking at you, `AtariEnv`), or for datatype casting (except for rewards,
which are cast to `fp32` automatically). In theory, there is **no need** for
special data preprocessing, except for, perhaps, casting data to proper dtypes,
like from `numpy.float64` observations to `float32` in `CartPole`.

#### Semantics

The collector just carefully records the trajectory by alternating between
the **REACT** and **STEP+EMIT** phases in the following fashion:

$$
    \cdots
        \longrightarrow t
        \overset{\mathrm{REACT}}{\longrightarrow} t + \tfrac12
        \overset{\mathrm{STEP+EMIT}}{\longrightarrow} t + 1
        \longrightarrow \cdots
    \,, $$

where the half-times $t + \tfrac12$ are commonly referred to as the `afterstates`:
the actor has chosen an action in response to the current observation, yet has
not interacted with the environment.

So the `time` advances in halves, and the proper names for the half times
in the diagram above are the `state`, the `afterstate` and the `next state`,
respectively.

The collected `fragment` data has the following structure:
* `.state` $z_t$ **the current "extended" observation**
  * `.stepno` $n_t$ the step counter
  * `.obs` $x_t$ **the current observation** emitted by transitioning to $s_t$
  * `.act` $a_{t-1}$ **the last action** which caused $s_{t-1} \longrightarrow s_t$ in the env
  * `.rew` $r_t$ **the previous reward** received by getting to $s_t$
  * `.fin` $d_t$ **the termination flag** indicating if $s_t$ is terminal in the env

* `.actor` $A_t$ auxiliary data from the actor due to **REACT**

* `.env` $E_{t+1}$ auxiliary data from the environment due to **STEP+EMIT**

* `.hx` $h_0$ the starting recurrent state of the actor

Here $s_t$ denotes **the unobserved true full state** of the environment.

The actor $\theta$ interacts with the environment and generates the following
<span style="color:orange">**tracked**</span> data during the rollout,
unobserved/non-tracked data <span style="color:red">**in red**</span>
and $t = 0..T-1$:

*  ${\color{orange}{h_0}}$, the starting recurrent state, is recorded in $\,.\!\mathtt{hx}$

* **REACT**: the actor performs the following update ($t \to t + \frac12$)

$$ 
    \bigl(
        \underbrace{
            .\!\mathtt{state}[\mathtt{t}]
        }_{{\color{orange}{z_t}}},\,
        {\color{red}{h_t}}
    \bigr)
        \overset{\text{Actor}_{\theta_{\text{old}}}}{\longrightarrow}
        \bigl(
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{act}[\mathtt{t+1}]
            }_{a_t \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{actor}[\mathtt{t}]
            }_{{\color{orange}{A_t}}},\,
            {\color{red}{h_{t+1}}}
        \bigr)
    \,, $$

* **STEP+EMIT**: the environment updates its unobserved state and emits
the observed data ($t + \frac12 \to t+1_-$)

$$ 
    \bigl(
        {\color{red}{s_t}},\,
        \underbrace{
            .\!\mathtt{state}.\!\mathtt{act}[\mathtt{t+1}]
        }_{a_t \leadsto {\color{orange}{z_{t+1}}}}
    \bigr)
        \overset{\text{Env}}{\longrightarrow}
        \bigl(
            {\color{red}{s_{t+1}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{obs}[\mathtt{t+1}]
            }_{x_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{rew}[\mathtt{t+1}]
            }_{r_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{fin}[\mathtt{t+1}]
            }_{d_{t+1} \leadsto {\color{orange}{z_{t+1}}}},\,
            \underbrace{
                .\!\mathtt{env}[\mathtt{t}]
            }_{{\color{orange}{E_{t+1}}}}
        \bigr)
    \,, $$

* collect loop ($t + 1_- \to t+1$)

$$ 
    \bigl(
        {\color{orange}{n_t}},\,
        {\color{orange}{d_{t+1}}}
    \bigr)
        \longrightarrow
            \underbrace{
                .\!\mathtt{state}.\!\mathtt{stepno}[\mathtt{t+1}]
            }_{n_{t+1} \leadsto {\color{orange}{z_{t+1}}}}
    \,. $$

Here $r_t$ is a scalar reward, $d_t = \top$ if $s_t$ is terminal, or $\bot$
otherwise, $n_{t+1} = 0$ if $d_{t+1} = \top$, else $1 + n_t$, and $a \leadsto b$
means $a$ being recorded into $b$.

In general, we may treat $z_t$, the extended observation, as an ordinary
observation, by **suitably modifying** the environment: we can make it
recall the most recent action $a_{t-1}$ and compute the termination indicator
$d_t$ of the current state, and let it keep track of the interaction counter
$n_t$, and, finally, we can configure it to supply the most recent reward
$r_t$ as part of the emitted observation.
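
A minimal sketch of such a modification, assuming a gym-like `reset`/`step`
interface. `CountdownEnv` and `ExtendedObservation` are hypothetical names used
for illustration only and are not part of `rlplay`; for brevity the wrapper only
counts steps within one episode and does not handle resets.

```python
from collections import namedtuple

# the extended observation z_t = (n_t, x_t, a_{t-1}, r_t, d_t)
State = namedtuple('State', ['stepno', 'obs', 'act', 'rew', 'fin'])


class CountdownEnv:
    """A hypothetical toy env: unit reward per step, terminates after `n` steps."""
    def __init__(self, n=3):
        self.n = n

    def reset(self):
        self.left = self.n
        return self.left

    def step(self, action):
        self.left -= 1
        return self.left, 1.0, self.left <= 0, {}


class ExtendedObservation:
    """Make the env recall the last action, the reward, and the step counter."""
    def __init__(self, env, act0=0):
        self.env, self.act0 = env, act0

    def reset(self):
        self.stepno = 0
        # there is no previous action or reward at t=0, so use placeholders
        return State(0, self.env.reset(), self.act0, 0.0, False)

    def step(self, action):
        obs, rew, fin, info = self.env.step(action)
        self.stepno += 1
        return State(self.stepno, obs, action, float(rew), fin), info


env = ExtendedObservation(CountdownEnv(2))
z0 = env.reset()
z1, _ = env.step(1)
z2, _ = env.step(0)
```

Each `z_t` now carries everything the actor's `.step` receives, which is why the
setup can be treated as an ordinary POMDP over extended observations.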

Hence we essentially consider the following POMDP setup:
\begin{align}
    a_t, h_{t+1}, A_t
        &\longleftarrow \operatorname{Actor}(z_t, h_t; \theta)
        \,, \\
    z_{t+1}, r_{t+1}, E_{t+1}, s_{t+1}
        &\longleftarrow \operatorname{Env}(s_t, a_t)
        \,.
\end{align}

Specifically, let $
(z_t)_{t=0}^T
    = (n_t, x_t, a_{t-1}, r_t, d_t)_{t=0}^T
$ be the trajectory fragment in `.state`, and $h_0$, `.hx`, be the starting
(not necessarily the initial) recurrent state of the actor at the beginning
of the rollout.

#### Requirements

* all nested containers **must be** built from pure python `dicts`, `lists`, `tuples` or `namedtuples`

* the environment communicates either in **numpy arrays** or in python **scalars**, but not in data types that are incompatible with pytorch (such as `str` or `bytes`)

```python
# example
obs = {
    'camera': {
        'rear': numpy.zeros((3, 320, 240)),   # the shape is passed as a tuple
        'front': numpy.zeros((3, 320, 240)),
    },
    'proximity': (+0.1, +0.2, -0.1, +0.0,),
    'other': {
        'fuel_tank': 78.5,
        'passenger': False,
    },
}
```

* the actor communicates in torch tensors **only**

* the environment produces **float scalar** rewards (other data may be communicated through auxiliary environment info-dicts)
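
The container requirements above can be checked with a short recursive
procedure. This is only a sketch: `check_structure` is a hypothetical helper,
not a part of `rlplay`, and it is more lenient than the text in that it also
accepts subclasses of `dict`, `list` and `tuple`.

```python
def check_structure(obj, path='obs'):
    """Raise TypeError on leaves that cannot be converted to tensors."""
    if isinstance(obj, dict):
        for key, value in obj.items():
            check_structure(value, f'{path}[{key!r}]')

    elif isinstance(obj, (list, tuple)):  # namedtuples are tuples, too
        for j, value in enumerate(obj):
            check_structure(value, f'{path}[{j}]')

    elif isinstance(obj, (str, bytes)):
        raise TypeError(f'{path}: `{type(obj).__name__}` is not supported')

    # anything else is treated as a leaf, e.g. an array or a python scalar


# a valid nested observation passes silently
check_structure({'proximity': (0.1, 0.2), 'other': {'passenger': False}})

# a string leaf is reported along with its location in the container
try:
    check_structure({'other': {'name': 'bus'}})
    message = None
except TypeError as e:
    message = str(e)
```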

### Container support with `.plyr`

One of the core tools used in `rlplay` is a high-performance procedure that traverses
containers of `list`, `dict` and `tuple` and calls the specified function with the
non-container objects found in the containers as arguments (like `map`, but not
an iterator, and applicable to arbitrarily structured objects).

See `README.md` and `rlplay.engine.utils.plyr.apply` for docs.

The `apply` procedure has slightly faster specialized versions, `suply` and `tuply`,
which do not waste time on validating the structure of the containers. They differ
in the manner in which they call the specified function: the first passes positional
arguments, while the second passes all arguments in one tuple (think of the builtin
`map` and of `starmap` from `itertools`).
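
To make the calling conventions concrete, here is a pure-python sketch of such
a traversal. `nested_apply` and `nested_apply_tuple` are hypothetical names that
only mimic the behaviour described above; they are not the `rlplay`
implementation, and, like the specialized versions, they do not validate that
the containers have identical structure.

```python
def nested_apply(fn, *objects):
    """Call `fn` with the matching leaves of identically structured containers."""
    main = objects[0]
    if isinstance(main, dict):
        return {k: nested_apply(fn, *(o[k] for o in objects)) for k in main}

    if isinstance(main, (list, tuple)):
        values = [nested_apply(fn, *items) for items in zip(*objects)]
        if hasattr(main, '_fields'):  # namedtuples take positional fields
            return type(main)(*values)
        return type(main)(values)

    # non-container objects are passed as positional arguments
    return fn(*objects)


def nested_apply_tuple(fn, *objects):
    """Same traversal, but `fn` gets all the leaves in one tuple."""
    return nested_apply(lambda *args: fn(args), *objects)


a = {'x': (1, 2), 'y': [3]}
b = {'x': (10, 20), 'y': [30]}

added = nested_apply(lambda u, v: u + v, a, b)
summed = nested_apply_tuple(sum, a, b)
```

Both calls rebuild the same nested structure with the leaves added pairwise.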

In [None]:
# appliers of functions to nested objects
from rlplay.engine.utils.plyr import apply, suply, tuply

# `getitem` and `setitem` functions with argument order specialized for `apply`
from rlplay.engine.utils.plyr import xgetitem, xsetitem

# help(apply)

How to use `suply` to reset the recurrent state `hx` returned by `torch.nn.LSTM`:

```python
# the mask of inputs just after env resets
fin = torch.randint(2, size=(10, 4), dtype=bool)

# the tensors in `hx` must have the same 2nd dim as `fin`
hx = torch.randn(2, 1, 4, 32, requires_grad=False).unbind()
h0 = torch.zeros(2, 1, 4, 32, requires_grad=True).unbind()
# XXX h0 and hx are tuples of tensors (but we're just as good with dicts)

# get the mask at step 2, and make it broadcastable with 3d hx
m = ~fin[2].unsqueeze(-1)  # `m` is False where fin == True, i.e. at resets

# multiply by zero the current `hx` (diff-able reset and grad stop)
suply(
    m.mul,  # `.mul` method of the mask upcasts from bool to float if necessary
    hx,     # arg `other` of `.mul`
)

# replace the reset batch elements by a diff-able init value
suply(
    torch.add,            # .add(input, other, *, alpha=1.)
    suply(m.mul, hx),     # arg `input` of `.add`
    suply((~m).mul, h0),  # arg `other` of `.add`: `~m` marks the reset elements
    # alpha=1.            # pass other `alpha` if we want
)

# XXX `torch.where` does not have an `easily` callable interface
suply(
    lambda a, b: torch.where(m, a, b),  # or `a.where(m, b)`
    hx, h0,
)
```

For example, this is used to manually run the recurrent network loop:
```python
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


if use_cudnn and sticky:
    # sequence padding (MUST have sampling with `sticky=True`)
    n_steps, n_env, *_ = fin.shape
    inputs = input  # a single step needs no packing
    if n_steps > 1:
        # we assume sticky=True
        lengths = 1 + (~fin[1:]).sum(0).cpu()  # first observation's fin should be ignored
        inputs = pack_padded_sequence(input, lengths, enforce_sorted=False)

    output, hx = self.core(inputs, hx)
    if n_steps > 1:
        output, lens = pad_packed_sequence(
            output, batch_first=False, total_length=n_steps)

else:
    # input is T x B x F, hx is either None, or a proper recurrent state
    outputs = []
    for x, m in zip(input.unsqueeze(1), ~fin.unsqueeze(-1)):
        # `m` indicates if no reset took place, otherwise
        #  multiply by zero to stop the grads
        if hx is not None:
            hx = suply(m.mul, hx)

        output, hx = self.core(x, hx)
        outputs.append(output)

    output = torch.cat(outputs, dim=0)
```
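
The effect of the reset mask in the manual loop can be seen on a toy scalar
"recurrent core" that just accumulates its inputs. This is an illustrative
sketch in plain python: `run_with_resets` is a hypothetical name, and the
running sum stands in for `self.core`.

```python
def run_with_resets(inputs, fin, h0=0.0):
    """Run h_t = h_{t-1} + x_t, zeroing the carried state wherever `fin` is set."""
    h, outputs = h0, []
    for x, f in zip(inputs, fin):
        m = 0.0 if f else 1.0  # `m` is one if no reset took place
        h = m * h              # multiply by zero to forget the past episode
        h = h + x              # the toy recurrent update
        outputs.append(h)

    return outputs, h


# the env was reset just before the third input
outputs, h = run_with_resets([1, 2, 3, 4], [False, False, True, False])
```

The third output no longer depends on the first two inputs, which is what
multiplying `hx` by the mask achieves for each element of the batch.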

<br>

### Creating the actors

Rollout collection relies on the following API of the actor:
* `.reset(j, hx)` reset the recurrent state of the j-th environment in the batch (if applicable)
  * `hx` contains tensors with shape `(n_lstm_layers * n_dir) x batch x hidden`, or is an empty tuple
  * the returned `hx` is the updated recurrent state


* `.step(stepno, obs, act, rew, fin, /, *, hx, virtual)` get the next action $a_t$, the recurrent state $h_{t+1}$, and
the **extra info** in response to $n_t$, $x_t$, $a_{t-1}$, $r_t$, $d_t$, and $h_t$ respectively.
  * extra info `dict` **might** include `value` key with a `T x B` tensor of state value estimates $
      v_t(z_t) \approx G_t = \mathbb{E} \sum_{j\geq t} \gamma^{j-t} r_{j+1}
  $.
  * MUST allocate new `hx` if the recurrent state is updated
  * MUST NOT change the inputs in-place


In [None]:
from rlplay.engine import BaseActorModule

help(BaseActorModule.reset)

In [None]:
help(BaseActorModule.step)

`BaseActorModule` is essentially a thin sub-class of `torch.nn.Module`, that implements
the API through `.forward(obs, act, rew, fin, *, hx, stepno)`, which should return three things:

1. `actions` prescribed actions in the environment, with data of shape `n_steps x batch x ...`
  * can be a nested container of dicts, lists, and tuples


2. `hx` data with shape `n_steps x batch x ...`
  * can be a nested container of dicts, lists, and tuples
  * **if an actor is not recurrent**, then must return an empty container, e.g. a tuple `()`


3. `info` object, which might be a tensor or a nested object containing data in tensors
`n_steps x batch x ...`. For example, one may communicate the following data:
  * `value` -- the state value estimates $v(z_t)$
  * `logits` -- the policy logits $\log \pi(\cdot \mid z_t)$
  * `q` -- $Q(z_t, \cdot)$ values

Here is an example actor, that wraps a simple MLP policy.

In [None]:
from rlplay.utils.common import multinomial

class PolicyWrapper(BaseActorModule):
    """A non-recurrent policy for a flat `Discrete(n)` action space."""

    def __init__(self, policy):
        super().__init__()
        self.policy = policy

        # for updating the exploration epsilon in the clones
        # self.register_buffer('epsilon', torch.tensor(epsilon))

    def forward(self, obs, act=None, rew=None, fin=None,
                *, hx=None, stepno=None, virtual=False):
        # Everything is  [T x B x ...]
        logits = self.policy(locals())

        actions = multinomial(logits.detach().exp())

        return actions, (), dict(logits=logits)

<br>

### Manual rollout collection

We shall need the following procedures from the core of the engine:

In [None]:
from rlplay.engine.core import prepare, startup, collect

Manual collection requires an `actor` and a batch of environment instances `envs`.

Prepare the run-time context for the specified `actor` and the environments:
```python
# settings
sticky = False  # whether to stop interacting if an env resets mid-fragment
device = None   # specifies the device to put the actor's inputs and data onto
pinned = False  # whether to keep the running context in non-resizable pinned
                #  (non-paged) memory for faster host-device transfers

# initialize a buffer for one rollout fragment
buffer = prepare(envs[0], actor, n_steps, len(envs),
                 pinned=False, device=device)

# the running context for the actor and the envs (optionally pinned)
ctx, fragment = startup(envs, actor, buffer, pinned=pinned)

while not done:
    # collect the fragment
    collect(envs, actor, fragment, ctx, sticky=sticky, device=device)

    # fragment.pyt -- torch tensors, fragment.npy -- numpy arrays (aliased on-host)
    do_stuff(actor, fragment.pyt)
```

<br>

### Rollout collection (same-process)

Collect rollouts within the current process

In [None]:
from rlplay.engine.rollout import same

The parameters have the following meaning:
```python
it = same.rollout(
    envs,           # the batch of environment instances
    actor,          # the actor which interacts with the batch
    n_steps=51,     # the length of the rollout fragment
    sticky=False,   # whether to stop interacting if an env resets mid-fragment
    device=None,    # specifies the device to put the actor's inputs onto
)
```

`rollout()` returns an iterator, which has roughly the same logic
as the manual collection above.

Inside the infinite loop it copies `fragment.pyt` onto `device`, before
yielding it to the user. It also does not spawn its own batch of environments,
unlike parallel variants.

The user has to manually limit the number of iterations using, for example,

```python
it = same.rollout(...)

for b, batch in zip(range(100), it):
    # train on batch
    pass

it.close()
```
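
The same pattern can be tried out on a stand-in generator, which also shows why
calling `.close()` matters. This is a sketch with a hypothetical
`endless_batches` generator in place of the rollout iterator; it assumes only
the standard python generator semantics.

```python
import itertools


def endless_batches(log):
    """A stand-in for a rollout iterator: yields forever, cleans up on close."""
    try:
        for n in itertools.count():
            yield n

    finally:
        # this is where a sampler would release its resources
        log.append('closed')


log = []
it = endless_batches(log)

# take exactly three batches from the infinite iterator
batches = list(itertools.islice(it, 3))

# `.close()` raises GeneratorExit inside the generator and runs `finally`
it.close()
```

Note that `range` comes first in `zip(range(100), it)`: `zip` stops as soon as
its first argument is exhausted, so no extra batch is fetched and then dropped.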

<br>

### Rollout collection (single-process)

Single-actor rollout sampler running in a parallel process (double-buffered).

In [None]:
from rlplay.engine.rollout import single

Under the hood the function creates **two** rollout fragment buffers, maintains
a reference to the specified `actor`, makes a shared copy of it (on the host), and
then spawns one worker process.

The worker, in turn, makes its own local copy of the actor on the specified device,
initializes the environments and the running context. During collection it alternates
between the buffers, into which it records the rollout fragments it collects. Except
for double buffering, the logic is identical to `same.rollout`.

The local copies of the actor are **automatically updated** from the maintained reference.

```python
it = single.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of a rollout fragment
    n_envs,               # the number of independent environments in the batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    clone=True,           # should the worker use a local clone of the reference actor

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)

    device=None,          # the device on which to collect rollouts (the local actor is moved
                          #  onto this device)
)

# ...

it.close()
```

<br>

### Rollout collection (multi-process)

A more load-balanced multi-actor multi-process sampler.

In [None]:
from rlplay.engine.rollout import multi

This version of the rollout collector allocates several buffers and spawns
many parallel workers. Each worker creates its own local copy of the actor,
instantiates `n_envs` local environments and allocates a running context for
all of them. The rollout collection in each worker is **hardcoded to run on
the host device**.

```python
it = multi.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of each rollout fragment

    n_actors,             # the number of parallel actors
    n_per_actor,          # the number of independent environments run in each actor
    n_buffers,            # the size of the pool of buffers, into which rollout
                          #  fragments are collected. Should not be less than `n_actors`.
    n_per_batch,          # the number of fragments collated into a batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    pinned=False,

    clone=True,           # should the parallel actors use a local clone of the reference actor

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    device=None,          # the device onto which to move the rollout batches

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)
)

# ...

it.close()
```

<br>

### Evaluation (same-process)

In order to evaluate an actor in a batch of environments, one can use `evaluate`.

In [None]:
from rlplay.engine import core

# help(core.evaluate)

The function *does not* collect the rollout data, except for the rewards.
Below is the intended use case.
* **NB** this is run in the same process, hence blocks until completion, which
might take considerable time (esp. if `n_steps` is unbounded)

In [None]:
# same process
def same_evaluate(
    factory, actor, n_envs=4,
    *, n_steps=None, close=True, render=False, device=None
):
    # spawn a batch of environments
    envs = [factory() for _ in range(n_envs)]

    try:
        while True:
            rewards, _ = core.evaluate(
                envs, actor, n_steps=n_steps,
                render=render, device=device)

            # get the accumulated rewards (gamma=1)
            yield sum(rewards)

    finally:
        if close:
            for e in envs:
                e.close()

<br>

### Evaluation (parallel process)

Like rollout collection, evaluation can (and probably should) be performed in
a parallel process, so that it does not burden the main thread with computations
not related to training.

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

<br>

## CartPole with REINFORCE

### the CartPole Environment

In [None]:
import gym

# hotfix for gym's unresponsive viz (spawns gl threads!)
import rlplay.utils.integration.gym

The environment factory

In [None]:
class FP32Observation(gym.ObservationWrapper):
    def observation(self, observation):
        return observation.astype(numpy.float32)
#         obs[0] = 0.  # mask the position info
#         return obs  # observation.astype(numpy.float32)

def factory():
    return FP32Observation(gym.make("CartPole-v0").unwrapped)

<br>

### the algorithms

Service functions for the algorithms

In [None]:
from rlplay.engine.utils.plyr import apply, suply, xgetitem


def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects."""
    # use xgetitem to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    curr = suply(xgetitem, state, index=slice(None, -shift))

    # XXX `next[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    next = suply(xgetitem, state, index=slice(shift, None))

    return curr, next
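
The slicing done by `timeshift` is easy to check on a flat list. This is a
sketch with a hypothetical `timeshift_list` that handles a single time-major
sequence rather than a nested container.

```python
def timeshift_list(seq, shift=1):
    """Return the (current, shifted) slices of a time-major sequence."""
    return seq[:-shift], seq[shift:]


rew = [0.0, 0.1, 0.2, 0.3]  # r_0 .. r_3, i.e. T = 3
curr, nxt = timeshift_list(rew)

# curr[t] is r_t and nxt[t] is r_{t+1} for t = 0..T-1
```

Both slices have length `T`, so `curr[t]` and `nxt[t]` line up as the two ends
of the same transition.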

The REINFORCE policy-gradient algorithm

In [None]:
from rlplay.algo.returns import pyt_returns

# @torch.enable_grad()
def reinforce(fragment, module, *, gamma=0.99, C_entropy=1e-2,
              c_rho=float('inf')):
    r"""The REINFORCE algorithm (importance-weighted off-policy).

    The basic policy-gradient algorithm with a baseline $b_t$:
    $$
        \nabla_\theta J(s_t)
            = \mathbb{E}_{a \sim \beta(a\mid s_t)}
                \frac{\pi(a\mid s_t)}{\beta(a\mid s_t)}
                    \bigl( r_{t+1} + \gamma G_{t+1} - b_t \bigr)
                \nabla_\theta \log \pi(a\mid s_t)
        \,. $$
    """
    
    state, state_next = timeshift(fragment.state)

    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        state.obs, state.act, state.rew, state.fin,
        hx=fragment.hx, stepno=state.stepno)

    # \pi is the target policy, \mu is the behaviour policy
    log_pi, log_mu = info['logits'], fragment.actor['logits']

    # Assume .act is unstructured: `state_next.act[t]` = a_t -->> T x B x 1
    act = state_next.act.unsqueeze(-1)

    # the importance weights
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)

    bootstrap = torch.tensor(0.)
    # XXX bootstrap with the perpetual last reward?
    # bootstrap = state_next.rew[-1].div(1 - gamma)

    # The present value of the future rewards following `state[t]`:
    #    G_t = r_{t+1} + \gamma G_{t+1}
    ret = pyt_returns(state_next.rew, state_next.fin,
                      gamma=gamma, bootstrap=bootstrap)
    # ret.sub_(ret.mean(dim=0))
    # ret.div_(ret.std(dim=0))

    # the policy surrogate score (max)
    #    \frac1T \sum_t \rho_t (G_t - b_t) \log \pi(a_t \mid s_t)
    rho = log_mu_a.sub_(log_pi_a.detach()).neg_().exp_().clamp_(max=c_rho)
    reinfscore = log_pi_a.mul(ret.mul_(rho)).mean()

    # the policy neg-entropy score (min)
    #   - H(\pi(•\mid s)) = - (-1) \sum_a \pi(a\mid s) \log \pi(a\mid s)
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # maximize the entropy and the reinforce score
    # \ell := - \frac1T \sum_t G_t \log \pi(a_t \mid s_t)
    #         - C \mathbb{H} \pi(\cdot \mid s_t)
    loss = C_entropy * negentropy - reinfscore
    return loss.mean(), dict(
        entropy=-float(negentropy),
        policy_score=float(reinfscore),
    )
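
The recursion for the returns, used in `reinforce`, can be sketched in plain
python. `discounted_returns` is a hypothetical stand-in and not the
implementation of `pyt_returns`; it assumes that a set termination flag
$d_{t+1}$ cuts off everything that follows, including the bootstrap value.

```python
def discounted_returns(rew, fin, gamma=0.99, bootstrap=0.0):
    """G_t = r_{t+1} + gamma * (1 - d_{t+1}) * G_{t+1}, with G_T = bootstrap.

    Here `rew[t]` is r_{t+1} and `fin[t]` is d_{t+1}, as in `state_next`.
    """
    ret, g = [], bootstrap
    for r, d in zip(reversed(rew), reversed(fin)):
        # nothing is carried over from beyond a terminal state
        g = r + gamma * (0.0 if d else g)
        ret.append(g)

    return ret[::-1]


# three steps, the episode ends after the second one
ret = discounted_returns(
    [1.0, 1.0, 1.0], [False, True, False], gamma=0.5, bootstrap=4.0)
```

Going backwards: the last return is `1 + 0.5 * 4`, the middle one is just the
reward, since the episode ended there, and the first one is `1 + 0.5 * 1`.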

<br>

### the Actor

A procedure and a layer, which convert the input integer data into its
little endian binary representation as float $\{0, 1\}^m$ vectors.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers to fixed-width binary floating point vectors"""
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # n_bits = {torch.int64: 63, torch.int32: 31, torch.int16: 15, torch.int8 : 7}

    # get mask of set bits
    pow2 = torch.tensor([1 << j for j in range(n_bits)]).to(input.device)
    x = input.unsqueeze(-1).bitwise_and(pow2).to(bool)

    # upcast bool to float to get one-hot
    return x.to(dtype)


class OneHotBits(torch.nn.Module):
    def __init__(self, n_bits=63, dtype=torch.float):
        assert 1 <= n_bits < 64
        super().__init__()
        self.n_bits, self.dtype = n_bits, dtype

    def forward(self, input):
        return onehotbits(input, n_bits=self.n_bits, dtype=self.dtype)
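
The encoding itself can be checked on single integers in plain python. This is
a sketch with a hypothetical `bits_little_endian` helper, which mirrors the
bitwise-and with powers of two done by `onehotbits` for one scalar.

```python
def bits_little_endian(n, n_bits=8):
    """Fixed-width little endian binary expansion of a non-negative integer."""
    return [float(bool(n & (1 << j))) for j in range(n_bits)]


five = bits_little_endian(5, n_bits=4)  # 5 = 1 + 4
six = bits_little_endian(6, n_bits=4)   # 6 = 2 + 4
```

The least significant bit comes first, so `five` starts with a one and `six`
starts with a zero.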

A special module dictionary, which applies itself to the input dict of tensors

In [None]:
from typing import Optional, Mapping
from torch.nn import Module, ModuleDict as BaseModuleDict


class ModuleDict(BaseModuleDict):
    """The ModuleDict, that applies itself to the input dicts."""
    def __init__(
        self,
        modules: Optional[Mapping[str, Module]] = None,
        dim: Optional[int]=-1
    ) -> None:
        super().__init__(modules)
        self.dim = dim

    def forward(self, input):
        # enforce concatenation in the order of the declaration in  __init__
        return torch.cat([
            m(input[k]) for k, m in self.items()
        ], dim=self.dim)

A policy which uses many inputs.

In [None]:
from torch.nn import Sequential
from torch.nn import Embedding, Linear, Identity
from torch.nn import ReLU, LogSoftmax

def policy():
    return Sequential(
        ModuleDict(dict(
#             stepno=Sequential(
#                 OneHotBits(), Linear(63, 4, bias=False)
#             ),
            obs=Identity(),
            act=Embedding(2, 4),
        )),
        Linear(0 + 4 + 4, 32),
        ReLU(),
        Linear(32, 2),
        LogSoftmax(dim=-1),
    )

The discount factor and the other hyperparameters

In [None]:
gamma = 0.99
C_entropy = 0.01
c_rho = float('inf')
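
The role of `c_rho` is to truncate the importance weights
$\rho_t = \min\{c_\rho, \frac{\pi(a_t \mid z_t)}{\mu(a_t \mid z_t)}\}$, so the
infinite default means no truncation. A scalar sketch in plain python, with
a hypothetical `truncated_rho` helper that takes log-probabilities as inputs:

```python
import math


def truncated_rho(log_pi_a, log_mu_a, c_rho=float('inf')):
    """The importance weight pi(a|z) / mu(a|z), clamped from above by `c_rho`."""
    return min(c_rho, math.exp(log_pi_a - log_mu_a))


# the target policy is twice as likely to take the action as the behaviour one
log_pi_a, log_mu_a = math.log(0.5), math.log(0.25)

rho_raw = truncated_rho(log_pi_a, log_mu_a)             # about 2
rho_cut = truncated_rho(log_pi_a, log_mu_a, c_rho=1.0)  # clamped to 1
```

When the actors use an up-to-date copy of the learner, the two policies
coincide and the weight is one regardless of `c_rho`.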

Initialize the learner and the factories

In [None]:
from functools import partial

factory_eval = partial(factory)

learner, sticky = PolicyWrapper(policy()), False

learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

# prepare the optimizer for the learner
optim = torch.optim.Adam(learner.parameters(), lr=1e-2)

Initialize the sampler

In [None]:
T, B = 25, 8

Pick one collector
* The NetHack environment `nle` does not like the `fork` method, so we should use `spawn`, which is not notebook friendly :(
  * essentially, it is better to prototype in a notebook with `same.rollout`, and then write a non-interactive script with `multi.rollout`

In [None]:
# generator of rollout batches
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=16,
    n_per_actor=B,
    n_buffers=24,
    n_per_batch=2,
    sticky=sticky,  # so that we can leverage cudnn's fast RNN implementations
    pinned=False,
    clone=True,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

Generator of evaluation rewards

In [None]:
# test_it = test(factory_eval, learner, n_envs=4, n_steps=500, device=device_)
test_it = evaluate(factory_eval, learner, n_envs=4, n_steps=500,
                   clone=False, device=device_, start_method='fork')

Implement your favourite training method

In [None]:
import tqdm
# from math import log, exp
from torch.nn.utils import clip_grad_norm_

# pytorch loves to hog all threads on some linux systems
torch.set_num_threads(1)

# the training loop
losses, rewards = [], []
# decay = -log(2) / 25  # exploration epsilon halflife
for epoch in tqdm.tqdm(range(100)):
    for j, batch in zip(range(100), batchit):
        loss, info = reinforce(batch, learner, gamma=gamma,
                               C_entropy=C_entropy, c_rho=c_rho)

        optim.zero_grad()
        loss.backward()
        grad_norm = clip_grad_norm_(learner.parameters(), max_norm=numpy.inf)
        optim.step()
    
        losses.append(dict(
            loss=float(loss),
            grad=float(grad_norm),
            **info
        ))

    # fetch the evaluation results (lag by one inner loop!)
    rewards.append(next(test_it))

    # learner.epsilon.mul_(exp(decay)).clip_(0.1, 1.0)

In [None]:
# close the generators
batchit.close()
test_it.close()

import pdb; pdb.pm()

<br>

In [None]:
def collate(records):
    """collate identically keyed dicts"""
    out, n_records = {}, 0
    for record in records:
        for k, v in record.items():
            out.setdefault(k, []).append(v)
    
    return out


data = {k: numpy.array(v) for k, v in collate(losses).items()}

In [None]:
if 'loss' in data:
    plt.plot(data['loss'])

In [None]:
if 'entropy' in data:
    plt.plot(data['entropy'])

In [None]:
if 'policy_score' in data:
    plt.plot(data['policy_score'])

In [None]:
plt.semilogy(data['grad'])

In [None]:
rewards = numpy.stack(rewards, axis=0)

In [None]:
rewards

In [None]:
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

ax.plot(numpy.mean(rewards, axis=-1))
ax.plot(numpy.median(rewards, axis=-1))
ax.plot(numpy.min(rewards, axis=-1))
ax.plot(numpy.std(rewards, axis=-1))
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)

plt.show()

<br>

The ultimate evaluation run

In [None]:
with factory_eval() as env:
    learner.eval()
    eval_rewards, info = core.evaluate([
        env
    ], learner, render=True, n_steps=1e4, device=device_)

print(sum(eval_rewards))

<br>

Let's analyze the performance

In [None]:
import math
from scipy.special import softmax, expit, entr

*head, n_actions = info['logits'].shape
proba = softmax(info['logits'], axis=-1)

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(entr(proba).sum(-1)[:, 0])
ax.axhline(math.log(n_actions), c='k', alpha=0.5, lw=1);

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.hist(info['logits'][..., 1] - info['logits'][..., 0], bins=51);  # log-ratio

<br>

In [None]:
assert False

<br>

In [None]:
# stepno = batch.state.stepno
stepno = torch.arange(8192)

In [None]:
with torch.no_grad():
    out = learner.policy[0]['stepno'](stepno)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(8, 8), dpi=200,
                         sharex=True, sharey=True)

for j, ax in zip(range(out.shape[1]), axes.flat):
    ax.plot(out[:, j], lw=1)

fig.tight_layout(pad=0, h_pad=0, w_pad=0)

In [None]:
with torch.no_grad():
    # policy[3] is the last Linear layer (policy[4] is LogSoftmax and has no weight)
    plt.imshow(abs(learner.policy[3].weight) @ abs(learner.policy[1].weight))

In [None]:
with torch.no_grad():
    plt.imshow(abs(learner.policy[0]['stepno'][-1].weight)[:, :16].T)

<br>