# `rlplay`-ing around

In [None]:
import torch
import numpy

<br>

## Rollout collection

Rollout collection is designed to be as much `plug-n-play` as possible, i.e. it supports
arbitrarily structured nested containers of arrays or tensors for environment observations
and actions.

**Assumptions**
* the environment communicates either in python scalars or in numpy arrays
* the nested containers are built from `dict`, `list`, `tuple`, or `namedtuple` objects

In theory, there is no need for special data preprocessing, except for casting data to
the correct dtypes (like obs to `float32` in `CartPole`).
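
To make "arbitrarily structured nested containers" concrete, here is a minimal sketch of a recursive map over such a structure. The helper `apply` and the `State` namedtuple are hypothetical names used for illustration only; they are not claimed to be part of `rlplay`.

```python
from collections import namedtuple


def apply(fn, struct):
    """Apply `fn` to every leaf of a nested container, keeping its structure."""
    if isinstance(struct, dict):
        return {k: apply(fn, v) for k, v in struct.items()}

    if isinstance(struct, tuple) and hasattr(struct, "_fields"):
        # namedtuples are rebuilt through their own constructor
        return type(struct)(*(apply(fn, v) for v in struct))

    if isinstance(struct, (list, tuple)):
        return type(struct)(apply(fn, v) for v in struct)

    # a leaf: python scalar, array, or tensor
    return fn(struct)


# hypothetical observation structure
State = namedtuple("State", ["obs", "rew"])
nested = State(obs={"pos": [1, 2], "vel": (3.0,)}, rew=0.5)

doubled = apply(lambda x: 2 * x, nested)
```

The same traversal could carry a dtype cast (for example, to `float32`) instead of the doubling used here.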

The actor should also support a certain API, described below.

### Creating the actors

#### Semantics

The actor performs the following update
$$
    (x_t, a_{t-1}, r_t, d_t, h_t)
        \overset{\mathrm{Actor}}{\longrightarrow}
        (a_t, h_{t+1})
    \,, $$

and the environment performs
$$
    (s_t, a_t)
        \overset{\mathrm{Env}}{\longrightarrow}
        (s_{t+1}, x_{t+1}, r_{t+1}, d_{t+1})
    \,. $$

Rollout collection relies on the following API of the actor:
* `.reset(j, hx)` reset the recurrent state of the j-th environment in the batch (if applicable)
  * `hx` contains tensors with shape `n_lstm_layers x batch x hidden`, or is an empty tuple

* `.step(obs, act, rew, fin, hx)` get the next action $a_t$, the recurrent state $h_{t+1}$, and
the extra info in response to $x_t$, $a_{t-1}$, $r_t$, $d_t$, and $h_t$, respectively.
  * extra info dict should include `value`
  * MUST allocate new `hx` if the recurrent state is updated
  * MUST NOT change the inputs in-place

* `.value(obs, act, rew, fin, hx)` compute the value-function estimate $
      v(s_t) \approx \mathbb{E} G_t
  $, where $G_t = \sum_{j\geq t} \gamma^{j-t} r_{j+1}$ is the discounted return.
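
The update semantics and the `.step` signature above can be sketched with a toy interaction loop. Everything here is an assumption made for illustration: `CountdownEnv`, `CountingActor`, `run_episode`, and the placeholder values fed in for $a_{-1}$, $r_0$, and $d_0$ at the first step are not taken from `rlplay`.

```python
class CountdownEnv:
    """Toy env: the state counts down from `n`, the episode ends at zero."""

    def __init__(self, n=3):
        self.n = n

    def reset(self):
        self.s = self.n
        return float(self.s)  # x_0

    def step(self, action):
        # (s_t, a_t) -> (s_{t+1}, x_{t+1}, r_{t+1}, d_{t+1})
        self.s -= 1
        return float(self.s), float(action), self.s <= 0


class CountingActor:
    """Toy actor whose recurrent state h_t counts the steps taken."""

    def step(self, obs, act, rew, fin, hx):
        # (x_t, a_{t-1}, r_t, d_t, h_t) -> (a_t, h_{t+1})
        a_t = 1 if obs > 1 else 0
        return a_t, hx + 1  # a new `hx`, the inputs are left untouched


def run_episode(env, actor):
    # placeholder a_{-1}, r_0, d_0 and h_0 (assumed values)
    obs, act, rew, fin, hx = env.reset(), 0, 0.0, False, 0
    trace = []
    while not fin:
        act, hx = actor.step(obs, act, rew, fin, hx)
        obs, rew, fin = env.step(act)
        trace.append((act, rew, fin))

    return trace, hx


trace, hx = run_episode(CountdownEnv(3), CountingActor())
```

Note how the actor at step `t` sees the action it took at `t-1` together with the reward and termination flag that this action produced.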

In [None]:
from rlplay.engine.base import BaseActorModule

All these methods call `.forward(obs, act, rew, fin, hx)`, which should return three things:
1. `actions` prescribed actions in the environment, with data of shape `n_steps x batch x ...`
2. `hx` data with shape `n_steps x batch x ...`
3. `info` dict with extra `n_steps x batch x ...` data
  * `value` -- the value function estimates
  * `logits` -- the policy logits (if applicable)

In [None]:
from rlplay.utils.common import multinomial

class nonRecurrentPolicyWrapper(BaseActorModule):
    """Example wrapper for a non-recurrent policy.
    
    Details
    -------
    This example assumes flat `Discrete(n)` action space, and
    simple non-structured observation space, e.g. a python scalar
    or a `numpy.array`.
    """

    def __init__(self, policy, *, epsilon=0.1):
        super().__init__()
        self.policy, self.epsilon = policy, epsilon

    def forward(self, obs, act=None, rew=None, fin=None, *, hx=None):
        # Everything is  [T x B x ...]
        logits, hx = self.policy(obs, act, rew), ()

        # value must not have any trailing dims, i.e. T x B
        value = logits.new_zeros(fin.shape)

        # XXX eps-greedy?
        if self.training:
            unif = torch.tensor(1. / logits.shape[-1])

            prob = logits.detach().exp()
            prob.mul_(1 - self.epsilon)
            prob.add_(unif, alpha=self.epsilon)

            actions = multinomial(prob)

        else:
            actions = logits.argmax(dim=-1)

        return actions, hx, dict(value=value, logits=logits)
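
The exploration branch in `forward` mixes the policy with a uniform distribution before sampling. A minimal pure-python sketch of that mixture is given below, assuming that the `logits` are in fact log-probabilities (e.g. the output of a log-softmax). The helper `mix_with_uniform` is a hypothetical name, and `random.choices` stands in for `multinomial`.

```python
import math
import random


def mix_with_uniform(log_probs, epsilon=0.1):
    """Return (1 - eps) * pi + eps * uniform for a single distribution."""
    n = len(log_probs)
    return [(1 - epsilon) * math.exp(lp) + epsilon / n for lp in log_probs]


# a policy that strongly prefers action 0
log_probs = [math.log(0.9), math.log(0.1)]
probs = mix_with_uniform(log_probs, epsilon=0.1)

# sample one action from the mixture
action = random.choices(range(len(probs)), weights=probs, k=1)[0]
```

With `epsilon=0.1` every action keeps a probability of at least `0.1 / n`, so the behaviour policy never becomes fully deterministic during training.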

<br>

### Rollout collection (same-process)

Collect rollouts within the current process

In [None]:
from rlplay.engine.rollout import same

The parameters have the following meaning
```python
n_envs = 16     # the number of envs in the batch
n_steps = 51    # the length of each rollout fragment
sticky = False  # whether to stop interacting if an env resets mid-fragment
device = None   # specifies the device to put the actor's inputs onto
```


`rollout()` returns an iterator, which does the following, roughly.

Prepare the run-time context for the specified `actor` and the environments
```python
# spawn multiple envs
envs = [env_factory() for _ in range(n_envs)]

# initialize a buffer for one rollout fragment (optionally pinned)
buffer = prepare(envs[0], actor, n_steps, len(envs),
                 pinned=pinned, device=device)

# the running context for the actor and the envs
ctx, fragment = startup(envs, actor, buffer, pinned=pinned)
```

Now within the infinite loop it does the following
```python
# collect the fragment
collect(envs, actor, fragment, ctx, sticky=sticky, device=device)

# fragment.pyt -- torch tensors, fragment.npy -- numpy arrays (aliased)
# copy fragment.pyt onto `device`, and yield it to the user
```

The user has to manually limit the number of iterations using, for example,

```python
it = same.rollout(...)

for b, batch in zip(range(100), it):
    # train on batch
    pass

it.close()
```
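
The pattern above can be tried without any environments. The sketch below uses a hypothetical stand-in generator, `endless_batches`, in place of `same.rollout(...)`, and shows `itertools.islice` as an alternative to `zip(range(...), it)`.

```python
import itertools

closed = []


def endless_batches():
    """Stand-in for a rollout iterator: yields forever, cleans up on close."""
    try:
        for j in itertools.count():
            yield {"index": j}

    finally:
        # runs when `.close()` is called on the suspended generator
        closed.append(True)


it = endless_batches()

# take exactly three batches from the endless iterator
seen = [batch["index"] for batch in itertools.islice(it, 3)]

# trigger the clean-up code in the `finally` clause
it.close()
```

When using `zip`, the bounded `range` should come first: `zip` pulls from its arguments in order, so `zip(it, range(100))` would fetch and silently discard one extra batch before stopping.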

<br>

### Rollout collection (single-process)

Single-actor rollout sampler running in a parallel process (double-buffered).

In [None]:
from rlplay.engine.rollout import single

Under the hood the function creates **two** rollout fragment buffers, maintains
a reference to the specified `actor`, makes a shared copy of it (on the host), and
then spawns one worker process.

The worker, in turn, makes its own local copy of the actor on the specified device,
initializes the environments and the running context. During collection it alternates
between the buffers, into which it records the rollout fragments it collects. Except
for double buffering, the logic is identical to `rollout`.

The local copies of the actor are **automatically updated** from the maintained reference.

```python
it = single.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of a rollout fragment
    n_envs,               # the number of independent environments in the batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)

    device=None,          # the device on which to collect rollouts (the local actor is moved
                          #  onto this device)
)

# ...

it.close()
```
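
The double-buffered hand-over described above can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual implementation: a thread replaces the worker process, plain lists replace host-shared tensors, and the names `worker`, `empty`, and `ready` are hypothetical. The point it illustrates is that only buffer indices travel between the two sides.

```python
import queue
import threading


def worker(buffers, empty, ready, n_fragments):
    for step in range(n_fragments):
        index = empty.get()  # wait for a free buffer
        # "collect" a fragment by overwriting the buffer in-place
        buffers[index][:] = [step] * len(buffers[index])
        ready.put(index)  # hand over the index, not the data

    ready.put(None)  # signal the end of collection


# two preallocated fragment buffers, both initially free
buffers = [[0] * 4, [0] * 4]
empty, ready = queue.Queue(), queue.Queue()
for index in range(len(buffers)):
    empty.put(index)

thread = threading.Thread(target=worker, args=(buffers, empty, ready, 5))
thread.start()

received = []
while True:
    index = ready.get()
    if index is None:
        break

    received.append(list(buffers[index]))  # copy out before releasing
    empty.put(index)  # release the buffer back to the worker

thread.join()
```

While the consumer works on one buffer, the worker is free to fill the other, which is the benefit of keeping two buffers instead of one.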

<br>

### Rollout collection (multi-process)

A more load-balanced multi-actor multi-process sampler

In [None]:
from rlplay.engine.rollout import multi

This version of the rollout collector allocates several buffers and spawns
many parallel workers. Each worker creates its own local copy of the actor,
instantiates `n_envs` local environments and allocates a running context for
all of them. The rollout collection in each worker is **hardcoded to run on
the host device**.

```python
it = multi.rollout(
    factory,              # the environment factory
    actor,                # the actor reference, used to update the local actors

    n_steps,              # the duration of each rollout fragment

    n_actors,             # the number of parallel actors
    n_per_actor,          # the number of independent environments run in each actor
    n_buffers,            # the size of the pool of buffers, into which rollout
                          #  fragments are collected. Should not be less than `n_actors`.
    n_per_batch,          # the number of fragments collated into a batch

    sticky=False,         # do we freeze terminated environments until the end of the rollout?
                          #  required if we wish to leverage cudnn's fast RNN implementations,
                          #  instead of manually stepping through the RNN core.

    pinned=False,

    close=True,           # should we `.close()` the environments when cleaning up?
                          #  some envs are very particular about this, e.g. nle

    device=None,          # the device onto which to move the rollout batches

    start_method='fork',  # `fork` in notebooks, `spawn` in linux/macos and if we interchange
                          #  cuda tensors between processes (we DO NOT do that: we exchange indices
                          #  to host-shared tensors)
)

# ...

it.close()
```
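
To see how `n_per_actor` and `n_per_batch` might interact, here is a small sketch that collates fragments into one batch. It assumes that fragments are laid out as `T x B` and are concatenated along the batch dimension; the source does not spell this out, so treat the layout and the helper `collate` as assumptions.

```python
def collate(fragments):
    """Concatenate [T x B_i] fragments along the batch (second) dimension."""
    n_steps = len(fragments[0])
    return [
        [x for frag in fragments for x in frag[t]]
        for t in range(n_steps)
    ]


n_steps, n_per_actor, n_per_batch = 3, 4, 2

# each entry is tagged with (fragment, step, env) for easy inspection
fragments = [
    [[(k, t, j) for j in range(n_per_actor)] for t in range(n_steps)]
    for k in range(n_per_batch)
]

batch = collate(fragments)
shape = len(batch), len(batch[0])  # T x (n_per_batch * n_per_actor)
```

Under this assumption the learner would see `n_per_batch * n_per_actor` environments per batch, regardless of how many actors are running.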

<br>

### Evaluation (same-process)

In order to evaluate an actor in a batch of environments, one can use `evaluate`.

In [None]:
from rlplay.engine.base import evaluate

The function *does not* collect the rollout data, except for the rewards.
Below is the intended use case.
* **NB** this is run in the same process, hence blocks until completion, which
might take considerable time (esp. if `n_steps` is unbounded)

In [None]:
def test(
    factory, actor, n_envs=4,
    *, n_steps=None, close=True, render=False, device=None
):
    # spawn a batch of environments
    envs = [factory() for _ in range(n_envs)]

    try:
        while True:
            rewards = evaluate(envs, actor, n_steps=n_steps,
                               render=render, device=device)

            # get the accumulated rewards (gamma=1)
            yield sum(rewards)

    finally:
        if close:
            for e in envs:
                e.close()
    

<br>

## CartPole with REINFORCE or A2C

The REINFORCE policy-gradient algorithm

In [None]:
import torch.nn.functional as F
from rlplay.engine.returns import pyt_returns

def reinforce(batch, module, *, gamma=0.99, C_entropy=1e-2):
    # the actor responds to `x_t`, `a_{t-1}`, `r_t`, `d_t` with `a_t`
    #  (obs[t], act[t-1], rew[t-1], fin[t-1]) -->> act[t]
    #  (o[1:-1], a[:-1], r[:-1], d[:-1]), a[1:]
    _, _, info = module(
        batch.state.obs[1:-1],
        batch.state.act[:-1],
        batch.state.rew[:-1],
        batch.state.fin[:-1], hx=batch.hx)

    # REINFORCE
    # compute returns: G_t = r_{t+1} + \gamma G_{t+1}, and
    #  v(s_t) \approx \mathbb{E}_{\pi_{\geq t}} G_t(a_t, s_{t+1}, a_{t+1}, ... \mid s_t).
    # XXX GAE?
    G_t = pyt_returns(batch.state.rew, batch.state.fin, gamma=gamma,
                      bootstrap=batch.bootstrap[0])[1:]

    # \pi is the target policy, mu is the behaviour policy
    logits, actions = info['logits'], batch.state.act[1:].unsqueeze(-1)
    log_pi = logits.gather(-1, actions).squeeze(-1)

    log_mu = batch.actor['logits'].gather(-1, actions).squeeze(-1)

    # the importance weights
    rho = log_mu.sub_(log_pi.detach()).neg_().exp_().clamp_(max=1)

    # reinforce grads G_t \nabla \log \pi(a_t\mid s_t)
    policy_score = log_pi.mul(G_t - G_t.mean(dim=0)).mul(rho).mean()

    # maximize policy entropy:
    #   H(\pi(•\mid s)) = - \sum_a \pi(a\mid s) \log \pi(a\mid s)
    f_min = torch.finfo(logits.dtype).min
    entropy = logits.exp().mul(logits.clamp(min=f_min))\
                    .sum(dim=-1).neg().mean()

    # weighted sum of the policy score and entropy
    # \ell := - \frac1T \sum_t G_t \log \pi(a_t \mid s_t)
    #         + C \mathbb{H} \pi(\cdot \mid s_t)
    objective = policy_score + C_entropy * entropy

    # optimize: negate the objective, since the optimizer minimizes
    return objective.neg().mean(), dict(returns=G_t, entropy=float(entropy),
                                        policy_score=float(policy_score))
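
The recursion $G_t = r_{t+1} + \gamma G_{t+1}$ from the comments above can be unrolled by hand. The sketch below is a pure-python illustration and not the code of `pyt_returns`: the function name `discounted_returns` and the indexing convention (`rew[t]` and `fin[t]` hold $r_{t+1}$ and $d_{t+1}$) are assumptions made for this example.

```python
def discounted_returns(rew, fin, gamma, bootstrap):
    """Backward recursion G_t = r_{t+1} + gamma * (1 - d_{t+1}) * G_{t+1}.

    `bootstrap` is the value estimate of the state after the last step.
    """
    G, out = bootstrap, []
    for r, d in zip(reversed(rew), reversed(fin)):
        # a terminal transition cuts off the future return
        G = r + gamma * (0.0 if d else G)
        out.append(G)

    return out[::-1]


returns = discounted_returns(
    rew=[1.0, 1.0, 1.0],
    fin=[False, True, False],
    gamma=0.5,
    bootstrap=4.0,
)
```

The termination flag in the middle stops the bootstrap value from leaking into the first episode: the last return is `1 + 0.5 * 4`, while the one before it is just the reward `1`.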

The advantage actor-critic (A2C) algorithm

In [None]:
def a2c(batch, module, *, gamma=0.99, C_entropy=1e-2, C_value=0.25):
    # the actor responds to `x_t`, `a_{t-1}`, `r_t`, `d_t` with `a_t`
    #  (obs[t], act[t-1], rew[t-1], fin[t-1]) -->> act[t]
    #  (o[1:-1], a[:-1], r[:-1], d[:-1]), a[1:]
    _, _, info = module(
        batch.state.obs[1:-1],
        batch.state.act[:-1],
        batch.state.rew[:-1],
        batch.state.fin[:-1], hx=batch.hx)

    # Advantage Actor-Critic
    G_t = pyt_returns(batch.state.rew, batch.state.fin,
                      gamma=gamma, bootstrap=batch.bootstrap[0])[1:]

    # \pi is the target policy, mu is the behaviour policy
    logits, actions = info['logits'], batch.state.act[1:].unsqueeze(-1)
    log_pi = logits.gather(-1, actions).squeeze(-1)

    log_mu = batch.actor['logits'].gather(-1, actions).squeeze(-1)

    # the importance weights
    rho = log_mu.sub_(log_pi.detach()).neg_().exp_().clamp_(max=1)

    # compute the critic's loss
    #  \frac1{NT} \sum_{jt} (G_t(\tau_j) - v(s_t(\tau_j)))^2
    adv = G_t - info['value']
    value_score = F.mse_loss(info['value'], G_t, reduction='mean').neg()

    # compute the policy loss
    #  - \frac1N\sum_j (G_j - v(s_j)) \log \pi(a_j \mid s_j)
    policy_score = log_pi.mul(adv.detach()).mul(rho).mean()

    # maximize policy entropy:
    #   H(\pi(•\mid s)) = - \sum_a \pi(a\mid s) \log \pi(a\mid s)
    f_min = torch.finfo(logits.dtype).min
    entropy = logits.exp().mul(logits.clamp(min=f_min))\
                    .sum(dim=-1).neg().mean()

    # weighted sum of the policy and value score and the entropy
    objective = policy_score + C_entropy * entropy + C_value * value_score

    # optimize: negate the objective, since the optimizer minimizes
    return objective.neg().mean(), dict(
        returns=G_t, value_score=float(value_score),
        entropy=float(entropy), policy_score=float(policy_score))
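
Two ingredients shared by both losses are the clipped importance weights $\rho = \min\{1, \pi(a \mid s) / \mu(a \mid s)\}$ and the policy entropy. The scalar sketch below mirrors them in pure python; the function names are hypothetical, and the inputs are assumed to be log-probabilities.

```python
import math


def clipped_importance_weight(log_pi, log_mu, max_weight=1.0):
    """rho = min(exp(log_pi - log_mu), max_weight)."""
    return min(math.exp(log_pi - log_mu), max_weight)


def entropy(log_probs):
    """H(pi) = - sum_a pi(a) log pi(a), with 0 * log 0 taken as 0."""
    return -sum(math.exp(lp) * lp for lp in log_probs if lp > -math.inf)


# the target policy likes the action more than the behaviour policy: clipped
rho_up = clipped_importance_weight(math.log(0.5), math.log(0.25))

# the target policy likes the action less: the ratio is kept as is
rho_down = clipped_importance_weight(math.log(0.25), math.log(0.5))

h_uniform = entropy([math.log(0.5), math.log(0.5)])
h_greedy = entropy([0.0, -math.inf])
```

The tensor code above handles zero-probability actions differently: it clamps the log-probabilities from below at the smallest finite float, so that `0 * (-inf)` never turns into a `nan`.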

The policy of the actor

The environment factory

In [None]:
import time
import gym
import rlplay.utils.integration.gym
from rlplay.zoo.env import NarrowPath


class FP32Observation(gym.ObservationWrapper):
    def observation(self, observation):
        return observation.astype(numpy.float32)

class OneHotObservation(gym.ObservationWrapper):
    def observation(self, observation):
        return numpy.eye(1, self.env.observation_space.n,
                         k=observation, dtype=numpy.float32)[0]

def factory():
    return FP32Observation(gym.make("CartPole-v0").unwrapped)
#     return gym.make("Taxi-v3").unwrapped
    # return OneHotObservation(NarrowPath())

Initialize the learner

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


class CartPoleActor(BaseActorModule):
    def __init__(self, epsilon=0.1, lstm=False):
        super().__init__()
        self.epsilon, self.lstm = epsilon, lstm

        self.features = torch.nn.ModuleDict(dict(
            obs=torch.nn.Sequential(
                torch.nn.Linear(4, 64),
                torch.nn.ReLU(),
            ),
            act=torch.nn.Embedding(2, 4),
            rew=torch.nn.Sequential(
                torch.nn.Linear(1, 4),
                torch.nn.ReLU(),
            ),
        ))

        n_features = 64 + 4 + 4
        if not self.lstm:
            self.core = torch.nn.Sequential(
                torch.nn.Linear(n_features, 64),
                torch.nn.ReLU(),
            )

        else:
            self.core = torch.nn.GRU(n_features, 64, 1)

        self.policy = torch.nn.Sequential(
            torch.nn.Linear(64, 2),
            torch.nn.LogSoftmax(dim=-1),
        )

        self.baseline = torch.nn.Sequential(
            torch.nn.Linear(64, 1),
        )

    def forward(self, obs, act, rew, fin, *, hx=None):
        # Everything is  [T x B x ...]
        input = torch.cat([
            self.features['obs'](obs),
            self.features['act'](act),
            self.features['rew'](rew.unsqueeze(-1)),
        ], dim=-1)
        
        if not self.lstm:
            output, hx = self.core(input), ()

        else:
            # sequence padding (MUST have sampling with `sticky=True`)
            n_steps, n_env, *_ = fin.shape
            if n_steps > 1:
                # we assume sticky=True
                lengths = 1 + (~fin[1:]).sum(0)
                input = pack_padded_sequence(input, lengths, enforce_sorted=False)

            output, hx = self.core(input, hx)
            if n_steps > 1:
                output, lens = pad_packed_sequence(
                    output, batch_first=False, total_length=n_steps)

        # value must not have any trailing dims, i.e. T x B
        logits = self.policy(output)
        value = self.baseline(output).squeeze(-1)

        # XXX eps-greedy?
        if self.training:
            unif = torch.tensor(1. / logits.shape[-1])

            prob = logits.detach().exp()
            prob.mul_(1 - self.epsilon)
            prob.add_(unif, alpha=self.epsilon)

            actions = multinomial(prob)

        else:
            actions = logits.argmax(dim=-1)

        return actions, hx, dict(value=value, logits=logits)
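
The expression `1 + (~fin[1:]).sum(0)` in the recurrent branch turns termination flags into sequence lengths for packing. A pure-python sketch of the same computation is shown below; the helper `sequence_lengths` is hypothetical, and `fin` is assumed to be a `T x B` array of flags in which, with `sticky=True`, an environment keeps its flag raised from termination until the end of the fragment.

```python
def sequence_lengths(fin):
    """Count 1 + the number of non-terminal flags in steps 1..T-1 per env."""
    n_env = len(fin[0])
    return [
        1 + sum(not row[b] for row in fin[1:])
        for b in range(n_env)
    ]


# T=4 steps, B=3 environments
fin = [
    [False, True, False],  # the flags at t=0 are not counted
    [False, False, True],  # env 2 terminates right away
    [True, False, True],   # env 0 terminates here and stays frozen
    [True, False, True],
]

lengths = sequence_lengths(fin)
```

The count equals the length of the valid prefix only if the raised flags form a contiguous tail of the fragment, which is why the recurrent branch requires sampling with `sticky=True`.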

In [None]:
# pyt_gae(batch.state.rew, batch.state.fin, batch.actor['value'], gamma=0.99, bootstrap=batch.bootstrap[0])

In [None]:
# learner, sticky = nonRecurrentPolicyWrapper(policy()), False
learner = CartPoleActor(lstm=False)
sticky = learner.lstm

learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

# prepare the optimizer for the learner
optim = torch.optim.Adam(learner.parameters(), lr=1e-3)

Handy procedure to evaluate the actor

In [None]:
import gym
import rlplay.utils.integration.gym  # hotfix for gym's poor viz (spawns gl threads!)

from rlplay.engine.base import evaluate

Load a better trained agent

Initialize the sampler

In [None]:
import matplotlib.pyplot as plt

# T, B = 120, 20
# T, B = 120, 4
T, B = 51, 4

Pick one collector

In [None]:
# generator of rollout batches
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=8,
    n_per_actor=B,
    n_buffers=16,
    n_per_batch=2,
    sticky=sticky,  # so that we can leverage cudnn's fast RNN implementations
    pinned=False,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

Implement your favourite training method

In [None]:
import tqdm

gamma = 0.99
losses, rewards = [], []

# generator of evaluation rewards
test_it = test(factory, learner, n_envs=4, n_steps=500)

# the training loop
for epoch in tqdm.tqdm(range(100)):
    for j, batch in zip(range(100), batchit):

        optim.zero_grad()
        loss, info = a2c(batch, learner, gamma=gamma)
        loss.backward()
        optim.step()

        losses.append(tuple(map(float, (
            info['policy_score'], info['value_score'], info['entropy'],
        ))))

    rewards.append(next(test_it))

    if rewards[-1].min() > 900:
        break

# close the generators
batchit.close()
test_it.close()

<br>

In [None]:
plt.plot(losses)

In [None]:
rewards = numpy.stack(rewards, axis=0)

In [None]:
rewards

In [None]:
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
plt.plot(numpy.mean(rewards, axis=-1))
plt.plot(numpy.median(rewards, axis=-1))
plt.plot(numpy.min(rewards, axis=-1))
plt.plot(numpy.std(rewards, axis=-1))
# plt.plot(m+s * 1.96)
# plt.plot(m-s * 1.96)

In [None]:
with factory() as env:
    learner.eval()
    print(sum(evaluate([
        env
    ], learner, render=True)))

<br>

In [None]:
assert False

<br>

<br>

<br>

In [None]:
import matplotlib.pyplot as plt

p_l, v_l, ent = zip(*losses)

plt.plot(p_l)
plt.plot(ent)

In [None]:
plt.plot(v_l)

Run in the environment

In [None]:
with factory() as env:
    plt.plot([
        sum(evaluate([env], learner, render=False))
        for _ in range(200)
    ])

<br>

In [None]:
assert False

In [None]:
class Bar:
    def __init__(self, parent):
        self.parent = parent
        self._range = range(self.parent.n)
        self._it = iter(self._range)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)

class Foo:
    def __init__(self, n=10):
        self.n = n

    def __iter__(self):
        return Bar(self)


In [None]:
list(Foo())

In [None]:
class Bar:
    def __init__(self, parent):
        self.parent = parent

    def __iter__(self):
        yield from range(self.parent.n)

class Foo:
    def __init__(self, n=10):
        self.n = n

    def __iter__(self):
        return iter(Bar(self))


<br>