# `rlplay`-ing around with the Double-Dee QN

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

import gym

See example.ipynb for the overview of `rlplay`

<br>

## Tabular Taxi with Q-learning

### The environment

A version of the Taxi Environment that disassembles the observation into pieces

In [None]:
from gym.spaces import Dict, Discrete


class StructuredTaxiEnv(gym.ObservationWrapper):
    """Observation wrapper that splits Taxi's flat state index into its parts.

    The wrapped environment emits a single integer, which is decoded here as
    `((r * 5 + c) * 5 + p) * 4 + d` into a dict with keys `r` (row), `c`
    (column), `p` (passenger location) and `d` (destination).

    The declared `observation_space` uses exactly the keys of the dict
    returned by `.observation`, so that the produced observations are members
    of the declared space.
    """

    def __init__(self, env):
        super().__init__(env)
        # the keys must coincide with those produced by `.observation` below
        self.observation_space = Dict(dict(
            r=Discrete(5),
            c=Discrete(5),
            p=Discrete(5),
            d=Discrete(4),
        ))

    def observation(self, observation):
        """Decode the flat state index into the `r`, `c`, `p`, `d` components."""
        # observation = ((r * 5 + c) * 5 + p) * 4 + d \in 5 x 5 x 5 x 4
        n, d = divmod(observation, 4)
        n, p = divmod(n, 5)
        r, c = divmod(n, 5)
        return dict(r=r, c=c, p=p, d=d)

A wrapper that makes Taxi's `.render` compatible with `core_evaluate`

In [None]:
class RenderPatch(gym.Wrapper):
    """Wrapper whose `.render` substitutes `True` for a `None` result."""

    def render(self, mode='human'):
        """Render the wrapped env; return `True` if it returned nothing."""
        output = self.env.render(mode)
        return True if output is None else output

The environment factory

In [None]:
from gym.envs.toy_text import TaxiEnv
from functools import partial


def factory(struct=False):
    """Create a render-patched `Taxi-v3` environment.

    When `struct` is truthy the environment is additionally wrapped in
    `StructuredTaxiEnv` before `RenderPatch` is applied.
    """
    base = gym.make("Taxi-v3")
    return RenderPatch(StructuredTaxiEnv(base) if struct else base)


# evaluation uses the same factory with its default arguments (`struct=False`)
factory_eval = partial(factory)

<br>

### the Actor

A procedure and a layer, which converts the input integer data into its
little endian binary representation as float $\{0, 1\}^m$ vectors.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers as fixed-width little-endian binary vectors.

    A new trailing dimension of size `n_bits` is appended to `input`; its
    `j`-th element is one if bit `j` of the integer is set and zero otherwise.
    The result is cast to `dtype`.
    """
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # one single-bit mask per output position, created on the input's device
    masks = torch.tensor([2 ** k for k in range(n_bits)], device=input.device)

    # test every bit of every integer against its mask
    is_set = torch.bitwise_and(input.unsqueeze(-1), masks) != 0

    # upcast the boolean bit-flags to the requested dtype
    return is_set.to(dtype)


class OneHotBits(torch.nn.Module):
    """Layer form of `onehotbits` with a fixed bit width and output dtype."""

    def __init__(self, n_bits=63, dtype=torch.float):
        assert 1 <= n_bits < 64
        super().__init__()
        self.n_bits = n_bits
        self.dtype = dtype

    def forward(self, input):
        """Apply `onehotbits` to `input` with the stored settings."""
        return onehotbits(input, self.n_bits, self.dtype)

A special module dictionary, which applies itself to the input dict of tensors

In [None]:
from typing import Optional, Mapping
from torch.nn import Module, ModuleDict as BaseModuleDict


class ModuleDict(BaseModuleDict):
    """A ModuleDict that applies itself to input dicts.

    Each contained module is called on the input entry under the same key,
    and the results are concatenated along dimension `dim`.
    """

    def __init__(
        self,
        modules: Optional[Mapping[str, Module]] = None,
        dim: Optional[int] = -1,
    ) -> None:
        super().__init__(modules)
        self.dim = dim

    def forward(self, input):
        # walk the modules in the container's own iteration order, so that the
        #  layout of the output does not depend on the ordering of `input`
        pieces = []
        for key, layer in self.items():
            pieces.append(layer(input[key]))

        return torch.cat(pieces, dim=self.dim)

A simple tabular q-learner actor

In [None]:
from rlplay.engine import BaseActorModule


class CartPoleActor(BaseActorModule):
    """Tabular epsilon-greedy Q-actor backed by an embedding table.

    Despite the name, the default sizes (500 states, 6 actions) match the
    Taxi-v3 environment used in this notebook.

    Parameters
    ----------
    duelling : bool, default=False
        If set, each table row stores a state-value and per-action advantages,
        which are combined as `q = adv + (val - mean(adv))`.

    epsilon : float, default=0.1
        Exploration probability, kept in a buffer so that it can be updated
        in-place and propagated to clones of the module.

    n_states : int, default=500
        The number of rows in the table, i.e. distinct state indices.

    n_actions : int, default=6
        The number of discrete actions.
    """

    def __init__(self, duelling=False, epsilon=0.1, *, n_states=500, n_actions=6):
        super().__init__()
        self.duelling = duelling

        # the table is actually an embedding of the state index; the duelling
        #  variant has one extra column for the state-value
        n_dim = n_actions + (1 if duelling else 0)
        self.table = torch.nn.Embedding(n_states, n_dim)

        # for updating the exploration epsilon in the clones
        self.register_buffer('epsilon', torch.tensor(epsilon))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Look up q-values for integer state indices `obs` and pick actions.

        Only `obs` is used; `act`, `rew`, `fin`, `hx`, `stepno` and `virtual`
        are accepted and ignored. Returns the actions, an empty recurrent
        state `()`, and a dict with `q` (the q-values) and `value` (the
        max q-value, or the state-value column in the duelling case).
        """
        qv, hx = self.table(obs), ()
        if self.duelling:
            # first column is the state-value, the rest are the advantages
            val, adv = torch.split_with_sizes(
                qv, [1, qv.shape[-1] - 1], dim=-1)
            qv = adv + (val - adv.mean(dim=-1, keepdim=True))

            val, actions = val.squeeze(-1), qv.max(dim=-1).indices

        else:
            val, actions = qv.max(dim=-1)

        if self.training:
            # epsilon-greedy: keep the greedy action with probability
            #  1 - epsilon, otherwise draw a uniformly random one. The random
            #  tensors are created on the device of the q-values, so that
            #  `.where` also works when the module is not on the cpu.
            *head, n_actions = qv.shape
            device = qv.device
            actions = actions.where(
                torch.rand(head, device=device).gt(self.epsilon),
                torch.randint(n_actions, size=head, device=device))

        return actions, hx, dict(q=qv, value=val)

<br>

#### D-DQN loss

Service functions for the algorithms

In [None]:
from rlplay.engine.utils.plyr import apply, suply, xgetitem


def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects.

    Returns a pair `(curr, next)`, where `curr` is `state` indexed with
    `[:-shift]` and `next` is `state` indexed with `[shift:]`; the indexing
    is applied to the nested structure through `suply`.
    """
    # `xgetitem` is used in order to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    head = slice(None, -shift)

    # XXX `next[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    tail = slice(shift, None)

    return (
        suply(xgetitem, state, index=head),
        suply(xgetitem, state, index=tail),
    )

Double DQN loss for contiguous trajectory fragments.

* `.state[t+1].rew` -- $r_{t+1}$
* `.state[t+1].fin` -- $d_{t+1}$
* `.state[t+1].act` -- $a_t$ which caused $
    (s_t, z_t, a_t) \longrightarrow (s_{t+1}, z_{t+1}, r_{t+1}, d_{t+1})
$
* `_, _, fragment.actor[t] = actor(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta_{\text{old}})
$ -- used for rollout collection
* `_, _, info_module[t] = module(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta)
$ -- the current q-function, producing $(h_t)_{t=0}^{T+1}$
* `_, _, info_target[t] = target(.state[t])` -- $
    q(z_t, h_t, \cdot; \theta_-)
$ -- the q-target with $(h^-_t)_{t=0}^{T+1}$

The current q-network minimizes the $\mathrm{TD}(0)$-error, which is

$$
\delta_t(\theta)
    = \bigl(
        r_{t+1}
        + \gamma 1_{\{\neg d_{t+1}\}} v^*(z_{t+1})
    \bigr) - q(z_t, h_t, a_t; \theta)
    \,, $$

where the approximate state value estimate $
    v^*(z_{t+1})
$ is one of
* `Q-learning`: $
    \max_a q(z_{t+1}, h_{t+1}, a; \theta)
$
* `DQN`: $
    \max_a q(z_{t+1}, h_{t+1}, a; \theta_-)
$
* `double DQN`: $
    q(z_{t+1}, h_{t+1}, \hat{a}_{t+1}; \theta_-)
$ for $
    \hat{a}_{t+1} = \arg \max_a q(z_{t+1}, h_{t+1}, a; \theta)
$

One of the works related to DQN learning of recurrent agents is [Kapturowski et al. (2018)](https://openreview.net/forum?id=r1lyTjAqYX), who propose to use a burn-in period in
the contiguous trajectory fragment in order to compensate for the representation drift,
due to different RNN parameters $\theta_-$ and $\theta$.

There is no clear-cut evidence suggesting that the hidden recurrent sequences $h_t$
and $h^-_t$ yield significantly different results.

In [None]:
import torch.nn.functional as F


# @torch.enable_grad()
def ddq_learn(fragment, module, *, gamma=0.95, target=None, double=False):
    r"""Compute the Double-DQN loss over a _contiguous_ fragment of a trajectory.

    Parameters
    ----------
    fragment : object
        Trajectory fragment with attributes `.state` (which provides `.obs`,
        `.act`, `.rew`, `.fin`, and `.stepno`) and `.hx` (the initial
        recurrent state, passed to both `module` and `target`).

    module : callable
        The current Q-network $\theta$. Called as `module(obs, act, rew, fin,
        hx=..., stepno=...)` and expected to return a triple, the last
        element of which holds the q-values under the key `'q'`.

    gamma : float, default=0.95
        The discount factor applied to the bootstrapped next state value.

    target : callable, optional
        The target Q-network $\theta^-$ with the same call signature as
        `module`. If `None`, the next state value is bootstrapped from
        `module` itself (classic Q-learning).

    double : bool, default=False
        Use the Double DQN value estimate. Has an effect only if `target`
        is not `None`.

    Returns
    -------
    loss : tensor
        The sum of squared TD-errors over the fragment. The TD-target is
        computed under `torch.no_grad()`, hence gradients flow only through
        $q(z_t, a_t; \theta)$.

    Details
    -------
    In Q-learning the action value function minimizes the TD-error

    $$
        r_{t+1}
            + \gamma 1_{\neg d_{t+1}} v^*(z_{t+1})
            - q(z_t, a_t; \theta)
        \,, $$

    w.r.t. Q-network parameters $\theta$ where $z_t$ is the actionable state,
    $r_{t+1}$ is the reward for $s_t \to s_{t+1}$ transition. The value of
    $z_t$ includes the current observation $x_t$ and the recurrent state $h_t$,
    the last action $a_{t-1}$, the last reward $r_t$, and termination flag
    $d_t$.

    In the classic Q-learning there is no target network and the next state
    optimal state value function is bootstrapped using the current Q-network
    (`module`):

    $$
        v^*(z_{t+1})
            \approx \max_a q(z_{t+1}, a; \theta)
        \,. $$

    The DQN method, proposed by

        [Mnih et al. (2013)](https://arxiv.org/abs/1312.5602),

    uses a secondary Q-network (`target`) to estimate the value of the next
    state:

    $$
        v^*(z_{t+1})
            \approx \max_a q(z_{t+1}, a; \theta^-)
        \,, $$

    where $\theta^-$ are the frozen parameters of the target Q-network. The
    Double DQN algorithm of

        [van Hasselt et al. (2015)](https://arxiv.org/abs/1509.06461)

    unravels the $\max$ operator as
    $
        \max_k u_k \equiv u_{\arg \max_k u_k}
    $
    and replaces the outer $u$ with the Q-values of the target Q-network, while
    computing the inner $u$ (inside the $\arg\max$) with the current Q-network.
    Specifically, the Double DQN value estimate is

    $$
        v^*(z_{t+1})
            \approx q(z_{t+1}, \hat{a}_{t+1}; \theta^-)
            \,,
            \hat{a}_{t+1}
                = \arg \max_a q(z_{t+1}, a; \theta)
        \,, $$

    with $\hat{a}_{t+1}$ being the action taken by the current Q-network
    $\theta$ at $z_{t+1}$.

    Recurrent DQN
    -------------
    The key problem with the recurrent state $h_t$ in $z_t$ is its representation
    drift: the endogenous states used for collecting trajectory data during the
    rollout are produced by an actor with stale parameters $\theta_{\text{old}}$,
    and thus might have high discrepancy with the recurrent state produced by
    the current Q-network $\theta$ or the target $\theta^-$. To mitigate this

        [Kapturowski et al. (2018)](https://openreview.net/forum?id=r1lyTjAqYX)

    proposed to spend a slice `burnin` of the recorded trajectory on
    aligning the recurrent representation. Specifically, starting with $h_0$
    (contained in `fragment.hx`) they propose to launch two sequences $h_t$
    and $h^-_t$ from the same $h^-_0 = h_0$ using $q(\cdot; \theta)$ and
    $q(\cdot; \theta^-)$, respectively.

    NOTE: this function has no `burnin` argument and does not exclude any
    part of the fragment from the loss; it only launches both `module` and
    `target` from the same `fragment.hx`.
    """

    trajectory, hx = fragment.state, fragment.hx
    obs, act, rew, fin = trajectory.obs, trajectory.act, trajectory.rew, trajectory.fin

    # get $q(z_t, h_t, \cdot; \theta)$ for all t=0..T
    _, _, info_module = module(
        obs, act, rew, fin, hx=hx, stepno=trajectory.stepno)

    # get the next state `state[t+1]` $z_{t+1}$ to access $a_t$, $r_{t+1}$
    #  and $d_{t+1}$ (drops the first record of the fragment)
    state_next = suply(xgetitem, trajectory, index=slice(1, None))

    # $\hat{A}_t$, the module's response to current and next state,
    #  contains the q-values. `curr` is $q(z_t, h_t, \cdot; \theta)$
    #  and `next` is $q(z_{t+1}, h_{t+1}, \cdot; \theta)$.
    info_module_curr, info_module_next = timeshift(info_module)

    # get $q(z_t, h_t, a_t; \theta)$ for all t=0..T-1
    q_replay = info_module_curr['q'].gather(-1, state_next.act.unsqueeze(-1))

    # get $\hat{v}_{t+1}(z_{t+1}) = ...$ (no gradients flow through the target)
    with torch.no_grad():
        if target is None:
            # get $... = \max_a q(z_{t+1}, h_{t+1}, a; \theta)$
            q_value = info_module_next['q'].max(dim=-1, keepdim=True).values

        else:
            _, _, info_target = target(
                obs, act, rew, fin, hx=hx, stepno=trajectory.stepno)

            info_target_next = suply(xgetitem, info_target, index=slice(1, None))
            if not double:
                # get $... = \max_a q(z_{t+1}, h^-_{t+1}, a; \theta^-)$
                q_value = info_target_next['q'].max(dim=-1, keepdim=True).values

            else:
                # get $\hat{a}_{t+1} = \arg \max_a q(z_{t+1}, h_{t+1}, a; \theta)$
                hat_act = info_module_next['q'].max(dim=-1).indices.unsqueeze(-1)

                # get $... = q(z_{t+1}, h^-_{t+1}, \hat{a}_{t+1}; \theta^-)$
                q_value = info_target_next['q'].gather(-1, hat_act)

        # get $r_{t+1} + \gamma 1_{\neg d_{t+1}} \hat{v}_{t+1}(z_{t+1})$ using
        #  inplace ops: zero the bootstrap value wherever `fin` is set, then
        #  discount and add the reward. `q_value` is a fresh tensor produced
        #  by `.max` or `.gather`, so this does not overwrite network outputs.
        q_value.masked_fill_(state_next.fin.unsqueeze(-1), 0.)
        q_value.mul_(gamma).add_(state_next.rew.unsqueeze(-1))

    # td-error ell-2 loss, summed (not averaged) over the whole fragment
    return F.mse_loss(q_replay, q_value, reduction='sum')

<br>

prepare the optimizer for the learner

In [None]:
# the discount factor, passed to `ddq_learn` and reused in the td-error analysis
gamma = 0.6
# switches: frozen target network / double-DQN value estimate / duelling table
use_target, use_double, use_duelling = False, False, False

# `target` does not work for some reason
# `duelling` fails for both target and double, and sorta for ordinary q

Initialize the learner

In [None]:
# start fully exploratory (epsilon=1.); the training loop decays epsilon
learner = CartPoleActor(duelling=use_duelling, epsilon=1.)

# training mode enables the epsilon-greedy branch of the actor's `forward`
learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

# plain SGD over the table entries
optim = torch.optim.SGD(learner.parameters(), lr=1e-1)

Initialize the sampler

In [None]:
# `T` is passed to the rollout as `n_steps`, and `B` as `n_per_actor`
T, B = 25, 8

In [None]:
from rlplay.engine.rollout import multi

# a generator of trajectory fragment batches, each of which is consumed by
#  `ddq_learn` in the training loop
# NOTE(review): the exact semantics of `n_buffers`, `n_per_batch`, `sticky`,
#  `pinned`, `clone` and `close` are defined by `rlplay` -- confirm there
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=8,
    n_per_actor=B,
    n_buffers=16,
    n_per_batch=2,
    sticky=False,
    pinned=False,
    clone=True,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

A generator of evaluation rewards

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

# a generator of evaluation results; the training loop pulls one item from it
#  with `next` at the end of every epoch
test_it = evaluate(factory_eval, learner, n_envs=4, n_steps=500,
                   clone=False, device=device_, start_method='fork')

Implement your favourite training method

In [None]:
import tqdm
import copy
from math import log, exp
from torch.nn.utils import clip_grad_norm_

# limit torch to a single intra-op thread (presumably to avoid contention
#  with the rollout and evaluation worker processes -- confirm)
torch.set_num_threads(1)

# the training loop
losses, rewards = [], []
# exp(decay) = 2^{-1/50}, i.e. epsilon is halved every 50 epochs
decay = -log(2) / 50  # exploration epsilon halflife
for epoch in tqdm.tqdm(range(400)):
    # freeze the target for q: a fresh deep copy of the learner is taken at
    #  the start of every epoch and stays fixed during the inner loop
    target = copy.deepcopy(learner) if use_target else None

    # at most 100 batches from the rollout generator per epoch
    for j, batch in zip(range(100), batchit):
        loss = ddq_learn(batch, learner, target=target,
                         gamma=gamma, double=use_double)


        optim.zero_grad()
        loss.backward()
        # `grad` is the value returned by `clip_grad_norm_`, kept for plotting
        grad = clip_grad_norm_(learner.parameters(), max_norm=1e3)
        optim.step()

        losses.append(dict(
            loss=float(loss), grad=float(grad),
        ))

    # decay the exploration rate in-place, but keep it within [0.1, 1.0]
    learner.epsilon.mul_(exp(decay)).clip_(0.1, 1.0)

    # fetch the evaluation results lagged by one inner loop!
    rewards.append(next(test_it))

In [None]:
# close the rollout and the evaluation generators once the training is over
batchit.close()
test_it.close()

<br>

In [None]:
def collate(records):
    """Collate an iterable of dicts into a dict of lists.

    Every value is appended to the list stored under its key, in the order
    the records are iterated. The records are expected to be identically
    keyed, in which case all lists have the same length; a key missing from
    some record simply yields a shorter list. An empty input gives `{}`.
    """
    out = {}
    for record in records:
        for key, value in record.items():
            out.setdefault(key, []).append(value)

    return out


# per-key numpy arrays of the values logged in `losses` (`loss` and `grad`)
data = {k: numpy.array(v) for k, v in collate(losses).items()}

In [None]:
# the loss of every training step on a log-scale
plt.semilogy(data['loss'])

In [None]:
# the value returned by `clip_grad_norm_` at every training step, log-scale
plt.semilogy(data['grad'])

In [None]:
# stack the per-epoch evaluation results along a new leading axis
# NOTE: this rebinds `rewards` from a list to an array, so the training loop
#  cell cannot simply be continued after this cell has been run
rewards = numpy.stack(rewards, axis=0)

In [None]:
# display the stacked evaluation results
rewards

In [None]:
# median and standard deviation over the last axis; these are only used by
#  the commented-out band plot in the next cell
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
# summary statistics of the evaluation results over the last axis, one point
#  per epoch: mean, median, minimum and standard deviation
fi, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

ax.plot(numpy.mean(rewards, axis=-1))
ax.plot(numpy.median(rewards, axis=-1))
ax.plot(numpy.min(rewards, axis=-1))
ax.plot(numpy.std(rewards, axis=-1))
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)

plt.show()

<br>

The ultimate evaluation run

In [None]:
from rlplay.engine import core

# a single rendered evaluation run in a freshly created environment
with factory_eval() as env:
    # eval mode switches off the epsilon-greedy branch of the actor's `forward`
    learner.eval()
    # NOTE(review): `n_steps=1e2` is a float -- confirm `core.evaluate` accepts it
    eval_rewards, info = core.evaluate([
        env
    ], learner, render=True, n_steps=1e2, device=device_)

# the total reward collected during the run
print(sum(eval_rewards))

<br>

Let's analyze the performance

In [None]:
# the gap between the q-values and `value`, shown for index 0 of the second
#  axis (presumably the only evaluated environment -- confirm)
plt.plot((info['q'] - numpy.expand_dims(info['value'], -1))[:, 0])

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# one-step td-target and td-error computed from the recorded `value`
# NOTE(review): unlike `ddq_learn`, no termination mask is applied to the
#  bootstrap term here -- accurate only if no episode ends within the run
td_target = eval_rewards + gamma * info['value'][1:]
td_error = td_target - info['value'][:-1]
# td_error = npy_deltas(
#     eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool), info['value'][:-1],
#     gamma=gamma, bootstrap=info['value'][-1])

# the magnitude of the td-error relative to the td-target, on a log-scale
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.semilogy(abs(td_error) / abs(td_target))
ax.set_title('relative td(1)-error');

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# plt.plot(
#     npy_returns(eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool),
#                 gamma=gamma, bootstrap=info['value'][-1]))
# the `value` estimates recorded during the evaluation run
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(info['value']);

<br>

In [None]:
# always raises AssertionError: presumably a deliberate stop, so that
#  running all cells does not reach the scratch cells below -- confirm
assert False

<br>

In [None]:
# stepno = batch.state.stepno
# a synthetic range of step counters 0..255 for probing the `stepno` features
stepno = torch.arange(256)

In [None]:
# NOTE(review): `CartPoleActor` as defined in this notebook declares only
#  `.table` and `.epsilon`; the `.features` and `.core` attributes used below
#  look like leftovers from a different actor -- confirm before running
with torch.no_grad():
    out = learner.features['stepno'](stepno)

    # push the relu-ed features through the last 8 input columns of `core[1]`
    out = F.linear(F.relu(out), learner.core[1].weight[:, -8:],
                       bias=learner.core[1].bias)
#     out = F.linear(F.relu(out), learner.core.weight_ih_l0[:, -8:],
#                        bias=learner.core.bias_ih_l0)
#     out = F.relu(out)

In [None]:
# one small panel per column of `out` (at most 8 x 8 = 64 panels are drawn,
#  since `zip` stops at the shorter of the two sequences)
fig, axes = plt.subplots(8, 8, figsize=(8, 8), dpi=200,
                         sharex=True, sharey=True)

for j, ax in zip(range(out.shape[1]), axes.flat):
    ax.plot(out[:, j], lw=1)

fig.tight_layout(pad=0, h_pad=0, w_pad=0)

In [None]:
# magnitudes of the last 8 input columns of the `core[1]` weight, transposed
# NOTE(review): `learner.core` is not declared by `CartPoleActor` in this file
with torch.no_grad():
    plt.imshow(abs(learner.core[1].weight[:, -8:]).T)

In [None]:
# NOTE(review): `learner.features` is not declared by `CartPoleActor` in this
#  file; presumably element 1 of the `stepno` feature stack is a linear layer
lin = learner.features.stepno[1]

In [None]:
# magnitudes of the weight matrix of `lin`
with torch.no_grad():
    plt.imshow(abs(lin.weight))

<br>