# Let's try a non-hierarchical RL agent

In [None]:
import time
import gym
import nle

import matplotlib.pyplot as plt

In [None]:
del gym.Wrapper.__getattr__

Import other useful modules

In [None]:
import plyr
import torch
import numpy as np

from torch import nn
from torch.nn import functional as F

Select the device

In [None]:
# device_ = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device_ = torch.device('cpu')

Adam with high `weight_decay` may push many parameters' values
into the denormalized fp range, which is ultra slow on CPU (but not
as bad on GPU); see the answer and a reply from *njuffa* in
this [stackoverflow](https://stackoverflow.com/questions/36781881) thread.

In [None]:
torch.set_flush_denormal(True)

Load modified minihack tasks, which punish death by `-1` reward
* see the base class [MiniHack](https://github.com/facebookresearch/minihack/blob/65fc16f0f321b00552ca37db8e5f850cbd369ae5/minihack/base.py#L131-L132) subclassed by `MiniHackNavigation` and `MiniHackRoom`

In [None]:
import nle_toolbox.utils.env.minihack

* [MiniHack-HideNSeek-Big-v0](https://minihack.readthedocs.io/en/latest/envs/navigation/hidenseek.html)
> ... the agent is spawned in a big room full of trees and clouds \[, which\] block the line of sight of the player\[, \] and a random monster (chosen to be more powerful than the agent). The agent, monsters and spells can pass through clouds unobstructed \[but\] cannot pass through trees. The goals is to make use of the environment features, avoid being seen by the monster and quickly run towards the goal.

* [MiniHack-WoD-Hard-v0](https://minihack.readthedocs.io/en/latest/envs/skills/wod.html)
> ... require mastering the usage of the wand of death (WoD). Zapping a WoD it in any direction fires a death ray which instantly kills almost any monster it hits. ... the WoD needs to be found first, only then the agent should enter the corridor with a monster (who is awake and hostile this time), kill it, and go to the staircase.

* [MiniHack-LavaCross-v0](https://minihack.readthedocs.io/en/latest/envs/skills/lava_cross.html)
> The agent can accomplish this by either levitating over it (via a potion of levitation or levitation boots) or freezing it (by zapping the wand of cold or playing the frost horn).

* [MiniHack-Memento-F4-v0](https://minihack.readthedocs.io/en/latest/envs/navigation/memento.html)
> The agent is presented with a prompt (in the form of a sleeping monster of a specific type), and then navigates along a corridor. At the end of the corridor the agent reaches a fork, and must choose a direction. One direction leads to a grid bug, which if killed terminates the episode with +1 reward. All other directions lead to failure through a invisible trap that terminates the episode when activated. The correct path is determined by the cue seen at the beginning of the episode.

We hide the NLE under several layers of wrappers. From the core to the shell:
1. `ReplayToFile` handles seeding and logs the taken actions and seed into a file for later inspection and replay.

2. `NLEObservationPatches` patches tty-screens, botched by the cr-lf misconfiguration of the NLE's tty terminal emulator and NetHack's displays (lf only).

3. `NLEFeatureExtractor` adds extra features generated on-the-fly from the current NLE's observation.
  * an ego-centric view of the specified radius into the dungeon map (`vicinity`)
  * percentage strength (`NLE_BL_STR125`), converted to a `100`-base score, used by the game to compute extra strength bonuses.
    * The strength stat in AD&D 2ed, upon which the mechanics of NetHack is based, comes in two ints: strength and percentage. The latter is applicable to **warrior classes** with **natural str** 18 and denotes `exceptional strength`, which confers extra chance-to-hit, damage, and chance to force locks or doors.


4. ~`RecentHistory` keeps a brief log of actions taken in the environment (partially duplicates the functionality of the `Replay` wrapper).~

5. `Chassis` handles skippable gui events that do not require a decision, such as collecting menu pages unless an interaction is required, fetching consecutive topline or log messages.

6. `ObservationDictFilter` allows only the specified keys of the observation dict to get through.

7. `ActionMasker` computes the mask of actions that are **forbidden** in the current game state (_gui_ or _play_)
  * the mask is communicated through the observation dicts under the key `action_mask`

8. `RecentMessageLog` keeps a log of a specified number of recent messages fetched by the upstream `Chassis` wrapper.
  * the log is reported in `message_log` field of the observation dict

In [None]:
from nle_toolbox.bot.chassis import get_wrapper

from nle_toolbox.utils.replay import ReplayToFile, Replay

from nle_toolbox.utils.env.wrappers import NLEObservationPatches
from nle_toolbox.utils.env.wrappers import NLEFeatureExtractor
from nle_toolbox.utils.env.wrappers import RecentHistory

from nle_toolbox.bot.chassis import Chassis, ActionMasker, RecentMessageLog

from nle_toolbox.utils.env.wrappers import ObservationDictFilter

The factory for collecting random exploration rollouts

In [None]:
import minihack
from nle import nethack
# from nle_toolbox.utils import seeding

# All keys of `nethack.OBSERVATION_DESC` except the service fields
# 'program_state' and 'internal'.
# NOTE: the keys go through a set difference, hence the order of the
#  resulting tuple is the set's iteration order, not the declaration order.
DEFAULT_OBSERVATION_KEYS = tuple(
    frozenset(nethack.OBSERVATION_DESC.keys())
    - frozenset(('program_state', 'internal'))
)

def factory(
    seed=None,
    *,
    config,
    folder=None,
    sticky=False,
):
    """Build a fully wrapped environment for collecting rollouts.

    Parameters
    ----------
    seed
        The value passed to `env.seed` of the replay wrapper.
    config
        A mapping with at least the keys `'id'` (the env id given to
        `gym.make`) and `'vicinity'` (the `k` of `NLEFeatureExtractor`).
    folder
        If `None` the env is wrapped in `Replay`, otherwise in `ReplayToFile`
        with this folder and `save_on='done'`.
    sticky
        Forwarded to the replay wrapper.

    Returns
    -------
    The env wrapped, from the innermost to the outermost, in: the replay
    wrapper, `NLEObservationPatches`, `Chassis`, `NLEFeatureExtractor`,
    `ObservationDictFilter`, `ActionMasker`, and `RecentMessageLog`.
    """
    # create the env instance and get the action mapping
    env = gym.make(
        config['id'],
        # remove service fields 'internal' and 'program_state'
        observation_keys=DEFAULT_OBSERVATION_KEYS,
    )

    # XXX ascii (char) to action
    ctoa = {chr(a): j for j, a in enumerate(env.unwrapped.actions)}
    # action index to ascii (char); used only by the disabled `RecentHistory`
    atoc = tuple(map(chr, env.unwrapped.actions))

    # provide seeding capabilities and full action tracking
    if folder is None:
        env = Replay(env, sticky=sticky)

    else:
        env = ReplayToFile(env, sticky=sticky, folder=folder, save_on='done')
    env.seed(seed)

    # patch bugged tty output
    env = NLEObservationPatches(env)

    # # log recent actions
    # env = RecentHistory(
    #     env,
    #     n_recent=128,
    #     map=lambda a: atoc[a],  # XXX atoc IS NOT a dict!
    # )

    # skippable gui abstraction layer. Bypassed if the action
    #  space does not bind a SPACE action.
    # XXX we can also use \015 (ENTER) aside from \040 (SPACE).
    # `space` is the index of SPACE, else of ENTER, else `None`
    space = ctoa.get(' ', ctoa.get('\015'))
    env = Chassis(env, space=space, split=False)

    # a feature extractor to potentially reduce the runtime complexity
    # * ego-centric view, and properly handled "exceptional strength" stat (str125)
    env = NLEFeatureExtractor(env, k=config['vicinity'])

    # filter unused observation keys
    # XXX this wrapper should be applied before any container-type
    #  modifications of the NLE's observation space.
    env = ObservationDictFilter(
        env,
        ## the map, bottom line stats and inventory
        'glyphs',
        # 'chars', 'colors', 'specials',
        'blstats',
        'inv_glyphs',
        # 'inv_strs',
        'inv_letters',
        # 'inv_oclasses',

        ## used for in-notebook rendering
        'tty_chars', 'tty_colors', 'tty_cursor',

        ## used by the GUI abstraction layer (Chassis)
        # 'message', 'misc',

        ## other fields and fields related to internal state
        # 'screen_descriptions',  # 'internal', 'program_state',

        ## extra features produced by the upstream wrappers
        'vicinity',
    )

    # compute an action mask based on the current NLE mode: gui or play
    env = ActionMasker(env)

    # track the recent messages from NetHack (necessary because the game
    #  may send multi-part messages)
    # XXX the longest modal message log does not exceed 21 lines
    env = RecentMessageLog(env, n_recent=8)  # extra 8 x 256 bytes

    return env

A renderer for this **factory**

     y  k  u  
      \ | /   
    h - . - l 
      / | \   
     b  j  n  

In [None]:
import pprint as pp
from time import sleep

from nle_toolbox.utils.env.render import render as tty_render
from IPython.display import clear_output

def ipynb_render(obs, clear=True, fps=None):
    """Render the observation `obs` in the notebook's output cell.

    Nothing is drawn if `fps` is `None`. Otherwise the cell output is
    optionally cleared, the tty rendering of `obs` is printed, and, if
    `fps` is positive, the call blocks for `fps` seconds.

    NOTE: despite its name `fps` is passed verbatim to `sleep`, i.e. it acts
    as the delay in seconds after each frame, not as a frame rate.

    Always returns `True`.
    """
    # rendering is disabled altogether
    if fps is None:
        return True

    if clear:
        clear_output(wait=True)

    print(tty_render(**obs))

    # throttle the playback
    if fps > 0:
        sleep(fps)

    return True

<br>

Pick the environment and get its action space parameters

In [None]:
# navigation, 8 compass directions
# env_id = 'MiniHack-Room-Ultimate-15x15-v1'
# env_id = 'MiniHack-CorridorBattle-Dark-v1'
# env_id = 'MiniHack-HideNSeek-Big-v1'

# navigation with 12 actions (8 compass, wait, enter, esc, space)
# env_id = 'MiniHack-Room-Ultimate-15x15-MoreActions-v0'
env_id = 'MiniHack-CorridorBattle-Dark-MoreActions-v0'
# env_id = 'MiniHack-HideNSeek-Big-MoreActions-v0'

# memory and navigation
# env_id = 'MiniHack-Memento-F4-v1'
# env_id = 'MiniHack-Memento-Short-F2-v0'

# skill acquisition (advanced) 85 actions
# env_id = 'MiniHack-WoD-Hard-v0'
# env_id = 'MiniHack-LavaCross-v0'

# Full game (ultimate) 121 actions
# env_id = 'NetHackChallenge-v0'

The essential parameters

In [None]:
import pprint as pp
# briefly instantiate the selected env just to inspect its action set
#  and to get the size of its action space
with gym.make(env_id) as env:
    pp.pprint(env.unwrapped.actions)
    n_actions = env.action_space.n
    # n_actions = 12  # 85, 121  # XXX make this an automatic setting dependent on the env

# sizes shared by the agent and the motivator recipes below
embedding_dim = 16
intermediate_size = 256
# passed as `k` to the recipes and as `vicinity` to the env factory's config
n_vicinity = 3

No intrinsic motivation

In [None]:
recipe_null = {'cls': None}

recipe_null

Intrinsic motivation via Random Network distillation

In [None]:
from nle_toolbox.zoo.models.motivation import RNDModule, RNDNetwork

recipe_rnd = RNDModule.default_recipe(
    RNDNetwork.default_recipe(
        embedding_dim=embedding_dim,
        intermediate_size=intermediate_size,
        sizes=(
            intermediate_size,
        ),
        k=n_vicinity,
        bls=('hp', 'hunger', 'condition',),
        act={
            'num_embeddings': n_actions,
            'embedding_dim': embedding_dim,
        },
    ),
    root=False,
)

recipe_rnd

Intrinsic motivation based on impact driven exploration.

In [None]:
from nle_toolbox.zoo.models.motivation import RIDEModule, RIDEEmbedding

recipe_ride = RIDEModule.default_recipe(
    n_actions=n_actions,
    embed=RIDEEmbedding.default_recipe(
        embedding_dim=embedding_dim,
        intermediate_size=intermediate_size,
        k=n_vicinity,
        bls=('hp', 'hunger', 'condition',),
        h0=True,
    ),
    sizes=(256,),
    bilinear=False,
    flip=False,
)

recipe_ride

The recipe of the basic agent

In [None]:
from nle_toolbox.zoo.models.basic import NLENeuralAgent

recipe_agent_basic = NLENeuralAgent.default_recipe(
    n_actions=n_actions,
    embedding_dim=embedding_dim,
    intermediate_size=intermediate_size,
    act_embedding_dim=None,  # default to `embedding_dim`
    fin_embedding_dim=0,  # disable `fin` flag embedding
#     core='lstm',
    core='linear',
    # it appears that more layers makes the agent learn better!
    num_layers=1,
    k=n_vicinity,
    bls=('hp', 'hunger', 'condition',),
    learn_tau=False,
)

recipe_agent_basic

A recipe for Highway Transformer

In [None]:
from nle_toolbox.zoo.models.transformer import NLEHITNeuralAgent

# highway xformer
recipe_agent_hit = NLEHITNeuralAgent.default_recipe(
    n_actions=n_actions,
    embedding_dim=embedding_dim,
    n_cls=8,
    n_io=8,  # \geq 4 is great, but \leq 2 doesn't work
    intermediate_size=64,
    num_attention_heads=4,
    head_size=16,
    num_layers=1,  # 2 is gud, but can we use one layer?
    k=n_vicinity,
    bls=('hp', 'hunger', 'condition',),
    learn_tau=False,
)

recipe_agent_hit

Assemble the full recipe: agent and the motivator

In [None]:
recipe = {
    'agent': recipe_agent_basic,
#     'agent': recipe_agent_hit,
#     'motivator': recipe_rnd,
    'motivator': recipe_ride,
#     'motivator': recipe_null,
}

NLEAgent = NLENeuralAgent
# NLEAgent = NLEHITNeuralAgent

tags = (
    'basic',  # 'hit',
#     'rnd',
    'ride',
)

intrinsic/extrinsic settings for the advantages, policy grads, critics, etc.
and parameters of the fragment collectors, optimizers and training.

In [None]:
# TODO explain why a certain parameter is used
#  and where it came from with the exact citation
d_hyper_act = dict(
    # weight of the ext/int rewards in the gae mix for the polgrad
    f_alpha={'ext': 1., 'int': 0.5,},

    # extrinsic/intrinsic reward PV discount
    f_gamma={'ext': 0.99, 'int': 0.9,},
    # XXX should we encourage more persistent exploration than exploitation?

    # the GAE discount: interpolate between TD(0) and TD(\infty)
    # XXX good for RND, but what about other methods?
    f_lambda={'ext': 0.96, 'int': 0.96,},
    # XXX there are suggestions that a lambda close to one makes
    #  the projected td(lambda) into a contraction, so we use a high
    #  value of 0.96

    # the coefficients in the loss for the both `ext` and
    #  `int` critic terms, entropy, and policy grads
    # XXX `-ve` maximizes, `+ve` minimizes
    f_coefs={
        'polgrad': -1.,
        'entropy': -0.01,
        'critic': {'ext': 0.5, 'int': 0.5,},
    },

    # the number of off-policy updates for GAE-A2C + impala or PPO+GAE
    n_off_policy=3,
    n_batch_size=None,  # the batch size for each pass (None -- full batch)
    s_off_alg='ppo',

    # the share of the runtime state's hx's grad
    #  to be passed to `h0` between bptt fragments
    f_h0_lerp=0.0,

    f_lr=1e-3,

    # the length of the truncated-bptt rollout for the actor/agent and motivator
    n_length=100,

    # gradient ell-2 norm clipping
    f_grad_norm=5.0,  # float('inf'),
)

The hyper parameters of the motivation modules

In [None]:
d_hyper_mot = dict(
    f_lr=1e-3,

    # generic motivator hyper parameters
    n_length=20,
    
    # the value to clip the motivator's rewards from above
    #  (rews are currently non-negative)
    f_rew_max=None,

    # special hyper parameters
    n_batches=1,  # the number of backprop passes thru a fragment
    n_batch_size=None,  # the batch size for each pass (None -- full batch)
    b_detach=False,

    f_coefs={
        'fwd': 10,
        'inv': 0.1,
    },

    # gradient ell-2 norm clipping
    # XXX clipping motivator's grad seems to adversely impact the performance
    f_grad_norm=float('inf'),
)

The fragmented a2c parameters

In [None]:
import wandb

wandb.init(
    mode='disabled',
    project='nle-toolbox',
    job_type='sandbox',
    config=dict(
        vicinity=n_vicinity,

        # the number of envs to run simultaneously
        n_batch=16,

        # the total number of steps allotted to training (summed across all envs)
#         n_total=2_592_000 * 6,
        n_total=16000,

        act=d_hyper_act,
        mot=d_hyper_mot,

        # also track the recipes
        recipe=recipe,
        checkpoint=None,

        # the env id
        id=env_id,

        # duplicate the fragment lengths for ease of use in the dashboard
        _n_act_length=d_hyper_act['n_length'],
        _n_mot_length=d_hyper_mot['n_length'],
    ),
    tags=tags,
);

* `linear` with small TxB batch (64x5) is more noisy, and produces less `stable` policy, but converges fastest
* `linear` with larger batch (16x100) converges a bit slower, but produces more stable policy
* `lstm` with 16x100 converges slowest

<br>

### Redesigning the building blocks of the NLE feature extractor

A module, which delays its structured input `T x ...` by the specified number of steps.

In [None]:
from typing import Any

class DelayBy(nn.Module):
    """Delay a structured input by `n` steps along the leading dim.

    The module is stateless by itself: the pre-history is passed in and
    returned explicitly as `hx`, in the manner of recurrent layers.

    Parameters
    ----------
    n : int, default=1
        The number of steps to delay the input by. Must be positive.

    Raises
    ------
    ValueError
        If `n` is not positive.
    """
    def __init__(self, n: int = 1):
        # explicit validation, since `assert` is stripped under `python -O`
        if not n > 0:
            raise ValueError(f'`n` must be a positive integer, got {n!r}.')

        super().__init__()
        self.n = n

    def forward(self, input: Any, hx: Any = None) -> tuple[Any, Any]:
        """Delay `input` using the pre-history `hx`.

        Parameters
        ----------
        input
            A nested container of tensors, processed leaf-by-leaf via `plyr`.
            Each leaf is sliced and concatenated along dim 0.
        hx
            The state returned by the previous call (the last `n` steps of
            the spliced sequence). If `None`, an all-zero pre-history of `n`
            steps is made, each step shaped like the last step of `input`.

        Returns
        -------
        A pair `(output, hx)`: the spliced sequence without its last `n`
        steps, and its last `n` steps as the new state.
        """
        # initialize the state: setup the pre-history
        if hx is None:
            # a single zero step per leaf, shaped like `input[-1:]`
            hx = plyr.apply(lambda t: torch.zeros_like(t[-1:]), input)
            if self.n > 1:
                # repeat the zero step `n` times along dim 0
                hx = plyr.tuply(torch.cat, *(hx,) * self.n, dim=0)

        # splice the current input and the pre-history, then split
        spliced = plyr.apply(torch.cat, hx, input, _star=False, dim=0)
        return (
            plyr.apply(lambda t, n=self.n: t[:-n], spliced),
            plyr.apply(lambda t, n=self.n: t[-n:], spliced),
        )

    def extra_repr(self):
        """Show the delay in the module's repr."""
        return f'n={self.n}'

A clunky tool to split the parameters into biases and weights.

Unlike `.bias` parameters, which **should not** be decayed, not every `.weight` parameter
**should be** decayed. For example, learnt positional encodings can be seen as bias terms
to the "linear" operation effected by an `nn.Embedding`. Ordinary token embeddings
should not be decayed either, since they effectively serve as the input
data representation.

`nn.LayerNorm` layers perform within-layer normalization and then re-scale and translate
the result, meaning that their weights should be regularized towards unit scale and not towards zero.

Hence:
* all biases, and weights of `nn.LayerNorm` and `nn.Embedding` should not be regularized
by weight decay


In [None]:
from nle_toolbox.zoo.models.transformer import HiT

# XXX clean this up
def split_parameters(module):
    """Split the parameters of `module` into decayed and non-decayed lists.

    Only the parameters whose name starts with `weight` and which are owned
    directly by an `nn.Linear`, `nn.LSTM` or `nn.Bilinear` submodule are put
    in the first list. Everything else ends up in the second list: biases,
    `nn.LayerNorm` and `nn.Embedding` weights, the `posemb`, `cls` and `iox`
    parameters of `HiT`, and any parameter that is not recognized.

    Returns
    -------
    A pair `(decay, no_decay)` of lists of parameters.

    Raises
    ------
    AssertionError
        If the same parameter is encountered more than once, e.g. when it
        is shared between submodules.
    """
    # the layers, the dense weights of which are subject to weight decay
    dense = nn.Linear, nn.LSTM, nn.Bilinear

    visited = set()
    groups = {True: [], False: []}
    for _, submodule in module.named_modules():
        # look at the parameters owned directly by the submodule
        own = submodule.named_parameters(prefix='', recurse=False)
        for name, par in own:
            assert par not in visited
            visited.add(par)

            # a name that begins with `weight` cannot begin with `bias`,
            #  so biases always fall in the non-decayed group
            is_decayed = name.startswith('weight') and isinstance(submodule, dense)
            groups[is_decayed].append(par)

    return groups[True], groups[False]

Build an agent from the recipe and reset the bias terms in its recurrent core

In [None]:
from torch.nn import init
from nle_toolbox.utils.nn import rnn_reset_bias

agent = NLEAgent(**wandb.config.recipe['agent'])
agent.to(device_)

agent.apply(rnn_reset_bias);

Init agent's optimizer

In [None]:
# separate the agent's parameters into those subject to weight decay and the rest
decay, no_decay = split_parameters(agent)

# AdamW doesn't do what you expect it to do, Ivan! (although it
#  correctly decouples the objective's grad and the ell-2 weight reg).
#  See https://arxiv.org/abs/1711.05101.pdf
# XXX the first group uses the optimizer-wide `weight_decay=0.001`,
#  while the second group overrides it with zero.
agent.optim = torch.optim.AdamW([
    dict(params=decay),
    dict(params=no_decay, weight_decay=0.),
], lr=wandb.config['act']['f_lr'], eps=1e-5, weight_decay=0.001)

Load an earlier checkpoint.

In [None]:
if wandb.config.checkpoint is not None:
    # remap the stored tensors onto the selected device, so that a checkpoint
    #  saved on a GPU can also be restored on a CPU-only machine
    # XXX `torch.load` unpickles arbitrary objects: load trusted files only!
    ckpt = torch.load(wandb.config.checkpoint, map_location=device_)

    # report the missing and unexpected keys, if any
    print(agent.load_state_dict(ckpt['agent']))

Get the specified motivator (RIDE or RND) and prepare its optimizer.

In [None]:
from nle_toolbox.utils.nn import ModuleDict

# work on a copy, since the `cls` key is popped from the recipe
recipe_m = wandb.config.recipe['motivator'].copy()
cls_motivator = recipe_m.pop('cls')

# `mot` stays `None` when the recipe specifies no motivator class
mot = None
if cls_motivator is not None:
    # look the class up by its `str` representation, and build
    #  the motivator from the remaining items of the recipe
    mot = {
        str(c): c for c in (RIDEModule, RNDModule)
    }[cls_motivator](**recipe_m)
    mot.to(device_)

    # the same param grouping as for the agent: the first group uses the
    #  optimizer-wide `weight_decay=0.01`, the second one is not decayed
    decay, no_decay = split_parameters(mot)
    mot.optim = torch.optim.AdamW([
        dict(params=decay),
        dict(params=no_decay, weight_decay=0.),
    ], lr=wandb.config['mot']['f_lr'], eps=1e-5, weight_decay=0.01)

In [None]:
agent

In [None]:
mot

<br>

### Let's train an A2C agent in a capsule

An object to extract full episodes from their trajectory fragments.

In [None]:
from nle_toolbox.utils.rl.tools import EpisodeExtractor

A capsule for a learner agent and a launcher for it.

In [None]:
from nle_toolbox.utils.rl.capsule import launch, capsule

Compute the policy gradient surrogate, the entropy and other loss components

In [None]:
from nle_toolbox.utils.rl.engine import pyt_polgrad
from nle_toolbox.utils.rl.engine import pyt_entropy
from nle_toolbox.utils.rl.engine import pyt_critic

We shall use GAE in policy gradients and returns for the critic.

In [None]:
from nle_toolbox.utils.rl.returns import pyt_ret_gae

Get the weighted sum of the leaves in one nested container with
weights from a second container.

In [None]:
import operator as op

def reduce(values, weight=None):
    """Sum the leaves of the nested container `values`.

    If `weight` is given, then each leaf of `values` is first multiplied by
    the matching leaf of `weight`, which makes the result a weighted sum.
    """
    # pair up the leaves with their weights, if there are any
    weighted = values
    if weight is not None:
        weighted = plyr.apply(op.mul, values, weight)

    # gather the leaves into a flat list by traversing the container
    leaves = []
    plyr.apply(leaves.append, weighted)

    return sum(leaves)

A function to compute the targets (GAE, returns) for policy grads and critic loss.

In [None]:
@torch.no_grad()
def pg_targets(rew, val, /, gam, lam, *, fin):
    r"""Get the GAE and the returns for each reward stream in `rew`.

    The mappings `rew`, `val`, `gam`, and `lam` are looked up with the keys
    of `rew`, while `fin` is shared by all streams. The result is a pair
    `(gae, ret)` of dicts with the same keys as `rew`.

    Details
    -------
    The arguments `rew`, `fin`, and `val` are $r_t$, $d_t$ and $v(s_t)$,
    respectively! The td-error terms in GAE depend on $r_{t+1}$, $d_{t+1}$
    and on both $v(s_{t+1})$ and $v(s_t)$, hence on `rew[1:]`, `fin[1:]`,
    `val[1:]` and `val[:-1]`. `val[-1]` is the value-to-go estimate for
    the last state in the related trajectory fragment.
    """
    # each call yields a `(returns, advantages)` pair for its stream
    targets = {
        key: pyt_ret_gae(
            rew[key][1:], fin[1:], val[key],
            gam=gam[key], lam=lam[key],
        ) for key in rew
    }

    # regroup the pairs into two dicts
    ret = {key: returns for key, (returns, _) in targets.items()}
    gae = {key: advantages for key, (_, advantages) in targets.items()}
    return gae, ret

A function to produce masks for loss dropout, to simulate random batches.

In [None]:
from nle_toolbox.utils.rl.engine import dropout_mask

The Advantage (GAE) Actor-Critic loss on a fragment.
* this one is on-policy, thus it expects the data in `vp` to be diff-able.

In [None]:
def a2c_gae(input, vp, *, gam, lam, alpha, mask=None):
    """Get the terms of the GAE-A2C loss w. intrinsic motivation on a fragment.

    Reads `rew`, `fin` and `act` from `input`, and `val` and `pol` from `vp`.
    Returns a dict with the keys `polgrad`, `entropy`, and `critic`.
    """
    # (gae) the advantages and the returns of every reward stream
    # XXX r_{t+1}, v_t, v{t+1} -->> A_t, R_t; `rew[t]`, `fin[t]`, and
    # `val[t]` must be $r_t$, $d_t$ and $v(s_t)$, respectively, where
    # `v_t` is computed based on the historical data $\cdot_s$ with `s < t`!
    gae, ret = pg_targets(input.rew, vp.val, gam, lam, fin=input.fin)

    # (gae + motivation) mix the advantages with the weights in `alpha`,
    #  and block the grads through the mix
    adv = reduce(gae, alpha).detach()

    # (sys) policy grad surrogate `sg(G_t) \log \pi_t(a_t)`
    polgrad = plyr.apply(pyt_polgrad, vp.pol, input.act, adv=adv, mask=mask)

    # (sys) entropy of the policy
    entropy = plyr.apply(pyt_entropy, vp.pol, mask=mask)

    # (sys) intrinsic/extrinsic critic loss
    critic = plyr.apply(pyt_critic, vp.val, ret, mask=mask)

    return dict(polgrad=polgrad, entropy=entropy, critic=critic)

The PPO + GAE.

Recall that we have
$$
\mathbb{E}_{a \sim \pi}
    f_a \nabla\log\pi_a
    = \mathbb{E}_{a \sim \mu}
        \frac{\pi_a}{\mu_a}
        f_a \nabla\log\pi_a
    = \nabla \mathbb{E}_{a \sim \mu}
        \frac{\pi_a}{\mu_a}
        \operatorname{sg}(f_a)
    = \nabla \mathbb{E}_{a \sim \pi}
        \operatorname{sg}(f_a)
    \,. $$

In [None]:
def ppo_gae(input, vp, myu, *, gam, lam, alpha, eps=0.2, mask=None):
    """Get the terms of the PPO loss with GAE advantages on a fragment.

    `vp` holds the output of the current policy, and `myu` that of
    the behaviour policy, which generated the actions in `input.act`.
    Returns a dict with the keys `polgrad`, `entropy`, and `critic`.
    """
    # (gae) the advantages and the returns of every reward stream
    gae, ret = pg_targets(input.rew, vp.val, gam, lam, fin=input.fin)

    # (gae + motivation) mix the advantages with the weights in `alpha`
    adv = reduce(gae, alpha).detach()

    # (ppo) the log-likelihood ratio of the taken actions. It must be
    #  diff-able w.r.t. the current policy `vp.pol`.
    logp_new = pyt_logpact(vp.pol, input.act)
    logp_old = pyt_logpact(myu.pol, input.act)
    lik = (logp_new - logp_old).exp()

    # (ppo) importance weighted PPO loss `\ell(\frac{\pi_t(a_t)}{\mu_t(a_t)}, A_t)`
    # XXX PPO disables grad feedback outside the $1 \pm \varepsilon$ trust
    #  region of the likelihood, depending on the sign of the advantage
    #  estimate, either undoes overly probable disadvantageous actions or
    #  reinforces confidence in profitable actions within the region.
    unclipped = adv * lik
    clipped = adv * lik.clamp(1. - eps, 1. + eps)  # XXX can use .expm1
    polgrad = torch.minimum(unclipped, clipped)

    # zero the dropped elements, and rescale by the kept share of the mask
    if mask is not None:
        kept = float(mask.sum()) / mask.numel()
        polgrad = polgrad.mul(mask).div(kept)

    # (sys) entropy of the policy
    entropy = plyr.apply(pyt_entropy, vp.pol, mask=mask)

    # (sys) intrinsic/extrinsic critic loss
    critic = plyr.apply(pyt_critic, vp.val, ret, mask=mask)

    return {'polgrad': polgrad.sum(), 'entropy': entropy, 'critic': critic}

A helper for computing V-trace targets, that takes care of chipping
off the initial `rew` and `fin`.

In [None]:
from nle_toolbox.utils.rl.returns import pyt_td_target, pyt_vtrace
from nle_toolbox.utils.rl.returns import trailing_broadcast

def pyt_vtrace_helper(rew, val, gam, *, fin, rho, r_bar, c_bar):
    """Call `pyt_vtrace` with the initial `rew` and `fin` chipped off."""
    rew_next, fin_next = rew[1:], fin[1:]
    return pyt_vtrace(
        rew_next, fin_next, val, rho,
        gam=gam, r_bar=r_bar, c_bar=c_bar,
    )

def pyt_td_target_helper(rew, vtrace, val, gam, *, fin, rho):
    """Get the td-errors of `vtrace` against `val[:-1]`, weighted by `rho`.

    The initial `rew` and `fin` are chipped off before calling `pyt_td_target`.
    """
    # [O(T B F)] get the importance-weighted td(0) error advantages
    # see sec. "v-trace actor-critic algo" (p. 4) in Espeholt et al. (2018)
    target = pyt_td_target(rew[1:], fin[1:], vtrace, gam=gam)
    td_error = target.sub(val[:-1])

    # `rho` is broadcast against `rew` to weigh the errors
    weight = trailing_broadcast(rho, rew)
    return td_error.mul(weight)

A procedure to get the likelihood of actions under a given policy sequence

In [None]:
from nle_toolbox.utils.rl.engine import pyt_logpact

def impala(input, vp, myu, *, gam, alpha, r_bar, c_bar, mask=None):
    """Get the terms of the IMPALA (V-trace actor-critic) loss on a fragment.

    `vp` holds the output of the current policy (values `val` and policy
    `pol`), and `myu` that of the behaviour policy, which generated
    the actions in `input.act`. `gam` and `alpha` are passed per reward
    stream, while `r_bar` and `c_bar` are forwarded to `pyt_vtrace`;
    `r_bar` also caps the importance weights of the advantages.

    Returns a pair: a dict with the keys `polgrad`, `entropy`, and `critic`,
    and the unclipped importance weights `rho.exp()`.
    """
    # (impala) For the advantages targets, IWs and critic targets we
    #  assume `vp` is diff-able, while `myu` isn't. Both are unstructured.
    val_, pol_ = plyr.apply(torch.Tensor.detach, vp)

    # (impala) get the importance weights
    # XXX `act[t]` is $a_{t-1}$, pol[t] is $\pi_t$ and $a_t \sim \pi_t$
    # NOTE `rho` is the LOG of the likelihood ratio, computed from the detached policy
    rho = pyt_logpact(pol_, input.act) - pyt_logpact(myu.pol, input.act)

    # (impala) get the V-trace targets $v_t$, t=0..N.
    # XXX Bad reading of the original paper (Espeholt et al.; 2018)) could
    #  wrongfully suggest that IMPALA uses values from the behavioural policy
    #  `myu`, rather than `vp` from the current policy.
    #  The V-trace targets, defined by eq. (1) in sec. 4.1 (p. 3, ibid), are
    #  in fact the result of applying the V-trace operator R (app. A.1, ibid)
    #  to SOME state-value function V. Furthermore, the paper never specifies
    #  whether the value estimates are bundled with the policy. Given this
    #  and the simultaneous on/off-policy use of V-trace, it should be understood
    #  that the section `V-trace A-C` (sec 4.2, ibid) uses the value estimates
    #  related to the current policy. Third-party implementations, e.g. RIDE's
    #  codebase, corroborate this.
    # XXX `rew[t], fin[t]` ($r_t, f_t$) PRECEDE `val[t]` ($v_t$), hence
    # we drop the first values from `rew` and `fin`.
    vtrace = plyr.apply(pyt_vtrace_helper, input.rew, val_, gam,
                        fin=input.fin, rho=rho, r_bar=r_bar, c_bar=c_bar)

    # (impala) get the importance-weighted td(0)-error advantages.
    # XXX unclear if we backprop thru `vp.val` here if the `val`
    # and `pol` networks have shared parameters.
    # NOTE the detached `val_` is used here, and the weights are capped at `r_bar`
    adv = plyr.apply(
        pyt_td_target_helper, input.rew, vtrace, val_, gam,
        fin=input.fin, rho=rho.exp().clamp_(max=r_bar))

    # (motivation) get the weighted sum of advantages
    adv = reduce(adv, alpha).detach()
    # the critic's targets are the V-trace targets without the last step
    vtarget = plyr.apply(lambda t: t[:-1], vtrace)

    return {
        # (sys) policy grad surrogate `sg(A_t) \log \pi_t(a_t)`
        'polgrad': pyt_polgrad(vp.pol, input.act, adv=adv, mask=mask),
        # (sys) entropy of the policy
        'entropy': plyr.apply(pyt_entropy, vp.pol, mask=mask),
        # (sys) intrinsic/extrinsic critic loss
        'critic': plyr.apply(pyt_critic, vp.val, vtarget, mask=mask),
    }, rho.exp()

Rewarding impact-driven exploration
$$
\begin{equation}
    x_t, x_{t+1}
        \overset{\mathrm{RIDE}}{\longrightarrow}
            {\color{orange}{r^I_{t+1}}}
            = \bigl\|\phi(x_{t+1}) - \phi(x_t)\bigr\|
        \,.
\end{equation}
$$

$$
L_\mathrm{inv}
    = - \frac1{T-1} \sum_{t=1}^{T-1} \log p(a_{t-1} \mid x_t, x_{t-1})
  \,, $$
with
$$
p(a_{t-1} \mid x_t, x_{t-1})
    \propto f(\phi(x_t), \phi(x_{t-1}))
    \,. $$

$$
L_\mathrm{fwd}
    = - \frac1{T-1} \sum_{t=1}^{T-1}
        \log p(x_t \mid x_{t-1}, a_{t-1})
    \,, $$
with
$$
p(x_t \mid x_{t-1}, a_{t-1})
    \propto g(\phi(x_{t-1}), a_{t-1})
    \,. $$
* $x_t$ -- the current observation (we allow partial observability)
* $a_t$ -- the action to be taken in the `env`

In [None]:
def ride(module, emb0, emb1, act, *, b_detach=False, mask=None):
    """Get the inverse and forward dynamics losses of RIDE on a fragment.

    Parameters
    ----------
    module
        An object with the callables `inv` (gets the pair `(emb0, emb1)` and
        returns action logits), `act` (embeds `act`), and `fwd` (gets the pair
        `(emb0, module.act(act))` and returns a prediction of `emb1`).
    emb0, emb1
        The embeddings $\\xi_{t-1}$ and $\\xi_t$, respectively.
    act
        The integer actions $a_{t-1}$. The first two dims of `act` and of
        the logits are flattened together for the cross-entropy.
    b_detach : bool, default=False
        Whether to treat `emb1` as a fixed target in the forward loss.
    mask
        An optional mask for the losses with the first step dropped. The kept
        terms are rescaled by the inverse of the kept share of the mask.

    Returns
    -------
    A dict with the summed losses under the keys `fwd` and `inv`.
    """
    # (ride) predict \xi_{t-1}, \xi_t -->> \pi(a_{t-1}) for t=1..T with
    #  the cross-entropy loss.
    # XXX we backprop into $\phi$ thru BOTH $\xi_{t-1}$ and $\xi_t$,
    #  `emb0[t]` and `emb1[t]`, respectively.
    out_inv = module.inv((emb0, emb1))
    loss_inv = F.cross_entropy(
        out_inv.flatten(0, 1), act.flatten(0, 1), reduction='none'
    ).reshape_as(act)

    # (ride) regress \xi_{t-1}, a_{t-1} -->> \xi_t for t=1..T w. ell-2 loss
    # XXX we could treat $\xi_t$ as a fixed target, and NOT backprop thru it!
    out_fwd = module.fwd((emb0, module.act(act)))
    loss_fwd = F.mse_loss(
        out_fwd, (emb1.detach() if b_detach else emb1), reduction='none',
    ).sum(-1)  # reduce over the embedding dim (output vector)

    # (sys) skip the first embedding (t=0) since it has been processed in
    #  the previous fragment. Besides, the agent has already acted upon the
    #  data in `input[T]` (verify).
    loss_inv = loss_inv[1:]
    loss_fwd = loss_fwd[1:]
    if mask is not None:
        # rescale by the kept share of the mask. If the mask drops every
        #  term, then the masked losses are all zeros and any non-zero scale
        #  would do, while a zero scale would yield 0/0 = nan.
        n_kept = float(mask.sum())
        scale = n_kept / mask.numel() if n_kept > 0 else 1.0
        loss_inv = loss_inv.mul(mask).div(scale)
        loss_fwd = loss_fwd.mul(mask).div(scale)

    return {'fwd': loss_fwd.sum(), 'inv': loss_inv.sum()}

Progress bar update and termination condition checker.

In [None]:
import tqdm

def progress(bar, n):
    """Advance the progress `bar` to `n`, and tell if it is still incomplete.

    Returns `True` while `n` is strictly less than the total of the bar.
    """
    # the bar is updated by increments, so feed it the difference
    delta = n - bar.n
    bar.update(delta)

    unfinished = n < bar.total
    return unfinished

Decorators for debugging.

In [None]:
from functools import wraps

def capture(fn, to):
    """Make `fn` also dump the log info it returns into the dict `to`.

    `fn` is expected to return a pair, the first element of which is a dict
    with the log info. The wrapped function returns the result of `fn` as is.
    If `to` is `None`, then `fn` is returned unwrapped.
    """
    # nowhere to capture to: leave the function intact
    if to is None:
        return fn

    def _wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)

        # the result must be a pair with the log info coming first
        nfo, _ = result
        to.update(nfo)

        return result

    return wraps(fn)(_wrapper)

def sniff(fn, to):
    """Make `fn` log the arguments of its every call into the list `to`.

    A pair `(args, kwargs)` is appended to `to` before `fn` is called.
    If `to` is `None`, then `fn` is returned unwrapped.
    """
    # nowhere to record to: leave the function intact
    if to is None:
        return fn

    def _wrapper(*args, **kwargs):
        # record first, so that the failed calls are logged as well
        call = args, kwargs
        to.append(call)

        return fn(*args, **kwargs)

    return wraps(fn)(_wrapper)

A service function to get diagnostic stats and exploration metrics from an episode.

In [None]:
import math
from nle.nethack import NLE_BL_SCORE
from nle_toolbox.utils.env.defs import MAX_ENTITY, GLYPH_CMAP_OFF, symbol
from nle.env.base import NLE

from collections import namedtuple
Episode = namedtuple('Episode', 'input,output,info')

def ep_stats(ep, *, S_stone=symbol.S_stone + GLYPH_CMAP_OFF):
    """Compute diagnostic stats and exploration metrics of an episode strand.

    Parameters
    ----------
    ep : Episode
        The strand to analyze. Reads `ep.input.fin`, `ep.input.rew`, the
        `blstats`, `vicinity` and `glyphs` entries of `ep.input.obs`, and
        the optional `ep.info`.
    S_stone : int, keyword-only
        The glyph id which is NOT counted as covered area.

    Returns
    -------
    dict
        An empty dict if the strand has no records besides the trailing
        `int(fin[-1])` ones. Otherwise a dict with the keys `duration`,
        `return`, `score`, `n_unique`, `diversity` (entropy in bits),
        `coverage`, `effectiveness`, and `b_death` (NaN if `ep.info`
        is empty).
    """
    assert isinstance(ep, Episode)

    met = {}
    # determine the offset for unfinished episodes
    # NOTE(review): the offset is one when the last `fin` flag is set, in
    #  which case the last record is excluded below -- confirm semantics
    offset = int(ep.input.fin[-1])
    if len(ep.input.fin) <= offset:
        return met

    # convert to numpy `npy = plyr.apply(np.asarray, ep)`
    # episode duration and total score
    n_length = len(ep.input.fin) - offset

    # score the episode (return = \sum_j r_{j+1} = sum r[j], j=1..N-1)
    f_return = plyr.apply(lambda r: r[1:].sum(), ep.input.rew)
    f_score = ep.input.obs['blstats'][-1-offset, NLE_BL_SCORE]

    # count the number of unique glyphs seen during the episode
    vic = ep.input.obs['vicinity'][:(-1 if offset > 0 else None)]
    cnt = torch.bincount(vic.flatten(), minlength=MAX_ENTITY + 1)
    n_unique = cnt.gt(0).sum()

    # The empirical entropy is a fine proxy for the diversity, since
    # it measures the amount of information content the seen glyphs.
    proba = cnt.div(cnt.sum())
    # kl-div against zero log-probs is \sum_k p_k \log p_k, hence the `.neg`
    f_entropy = F.kl_div(proba.new_zeros(()), proba, reduction='sum').neg()

    # coverage and action effectiveness
    gly = ep.input.obs['glyphs'][:(-1 if offset > 0 else None)]
    # XXX we exclude the terminal obs, because it is actually the init
    #  obs from the next episode
    # the share of non-stone glyphs in each record (mean over the last two dims)
    non_stone = (gly != S_stone).float().mean((-2, -1))
    # coverage is the ratio of the largest to the smallest non-stone share
    # NOTE(review): the ratio is not finite if some record has no non-stone
    #  glyphs at all -- confirm this cannot happen
    f_cov = non_stone.max() / non_stone.min()
    # effectiveness sums the shares of glyphs that changed between
    #  consecutive records
    f_eff = sum((g0 != g1).sum() / g1.numel() for g0, g1 in zip(gly, gly[1:]))
    
    # inspect the last info dict of the episode
    b_death = np.nan
    if ep.info:
        nfo = plyr.apply(lambda x: x[-1].item(), ep.info)

        # indicate if the episode ended in agent's death
        b_death = float(nfo['end_status'] == NLE.StepStatus.DEATH)

    return {
        'duration': n_length,
        'return': plyr.apply(float, f_return),
        'score': int(f_score),
        'n_unique': int(n_unique),
        'diversity': float(f_entropy) / math.log(2),
        'coverage': float(f_cov),
        'effectiveness': float(f_eff),
        'b_death': b_death,
    }

Aggregate the diagnostic stats across several episodes

In [None]:
def ep_aggregate(episodes):
    # filter out length one episodes
    metrics = list(filter(bool, map(ep_stats, episodes)))
    if not metrics:
        return {}

    # share of episodes that ended in agent's death
    f_deaths = np.mean([m['b_death'] for m in metrics])
    metrics = plyr.apply(np.median, *metrics, _star=False)

    metrics['b_death'] = f_deaths
    return metrics

<br>

Create the vectorized env and the agent

In [None]:
from copy import deepcopy
from nle_toolbox.utils.rl.engine import SerialVecEnv

# take an independent copy of the run's config as a plain dict
config = deepcopy(dict(wandb.config))

# create `n_batch` envs with `factory`, which is passed the copied config
env = SerialVecEnv(
    factory,
    n_envs=wandb.config.n_batch,
    kwargs=dict(config=config),
)

How to:
* organize the intrinsic motivator module?
  * ims are like regular actors, except their `actions` are rewards!
```python
    class SelfSupervisedRewards(nn.Module):
        def forward(self, obs, act=None, rew=None, fin=None, *, hx=None):
            return rew, (), hx
```
  * let's make intrinsic motivation modules be truly __exogenous__ wrt the agents
* set up the optimizers for various submodules?

In [None]:
from torch.nn.utils import clip_grad_norm_

class BaseLearner(nn.Module):
    """A learner, which updates the wrapped agent on trajectory fragments.

    Parameters
    ----------
    agent : nn.Module
        The actor-critic being trained. Must expose its optimizer as `.optim`
        and an `.initial_hx` attribute, both of which are used in `.learn`.
    cfg : dict, keyword-only
        The settings, read in `.learn`: `f_gamma`, `f_lambda`, `f_alpha`,
        `f_coefs`, `n_off_policy`, `f_grad_norm`, `n_batch_size`, `s_off_alg`
        and `f_h0_lerp`.
    """
    def __init__(self, agent, *, cfg: dict):
        super().__init__()
        self.agent = agent
        self.cfg = cfg
        # the extractor of episode strands from consecutive fragments
        self.epx = EpisodeExtractor()

    def forward(self, input, *, hx=None):
        """Run the agent on the fields of `input` with the recurrent state `hx`."""
        return self.agent(**input._asdict(), hx=hx)

    def learn(self, input, vp=None, *, gx=None, hx=None, nfo=None):
        """Update the agent on a fragment and recompute the recurrent state.

        Parameters
        ----------
        input : namedtuple
            The fragment data, passed to the agent through `.forward`.
        vp : optional
            The agent's output for `input`. Recomputed from `gx` if None.
        gx : optional
            The recurrent state at the start of the fragment.
        hx : optional
            The recurrent state at the end of the fragment. If not None, it
            is recomputed with the updated agent.
        nfo : sequence
            The info dicts, collated into the extracted episodes.
            NOTE(review): it is unpacked with `*nfo`, hence the default None
            would raise -- presumably always supplied by the caller.

        Returns
        -------
        out : dict
            The values of the loss, the grad norm and the loss terms under
            `loss/`, and the aggregated episode stats under `metrics/`.
        hx : optional
            The recurrent state to be used for the next fragment.
        """
        cfg = self.cfg
        if vp is None:
            _, vp, _ = self(input, hx=gx)

        # (sys) do GAE-A2C learning w. intrinsic motivation `vp` is (v_t, \pi_t)
        # ATTN the first pass is full batch?
        terms = a2c_gae(
            input,
            vp,
            gam=cfg['f_gamma'],
            lam=cfg['f_lambda'],
            alpha=cfg['f_alpha'],
            mask=None,
        )
        loss = reduce(terms, cfg['f_coefs'])

        # (sys) extra batch steps thru the current fragment with IMPALA
        # the detached on-policy output serves as the behaviour data `myu`
        myu = plyr.apply(torch.Tensor.detach, vp)
        for _ in range(cfg['n_off_policy']):
            # (sys) backprop through the agent
            # each iteration first applies the PREVIOUSLY computed loss
            self.agent.optim.zero_grad(True)
            loss.backward()
            grad = clip_grad_norm_(self.parameters(), cfg['f_grad_norm'])
            self.agent.optim.step()

            # (batching) get the batch mask for the current fragment
            mask = dropout_mask(input, k=cfg['n_batch_size'])

            # (off-policy) additional pass thru the fragment
            _, vp, _ = self(input, hx=gx)
            if cfg['s_off_alg'] == 'impala':
                terms, _ = impala(input, vp, myu, gam=cfg['f_gamma'],
                                  alpha=cfg['f_alpha'], r_bar=1., c_bar=1.,
                                  mask=mask)

            elif cfg['s_off_alg'] == 'ppo':
                terms = ppo_gae(input, vp, myu, gam=cfg['f_gamma'],
                                lam=cfg['f_lambda'], alpha=cfg['f_alpha'],
                                eps=0.2, mask=mask)

            else:
                raise ValueError(f"Unknown algorithm `{cfg['s_off_alg']}`.")

            loss = reduce(terms, cfg['f_coefs'])

        # (sys) backprop through the agent
        # this applies the loss from the last pass (on- or off-policy)
        self.agent.optim.zero_grad(True)
        loss.backward()
        grad = clip_grad_norm_(self.parameters(), cfg['f_grad_norm'])
        self.agent.optim.step()

        # (sys) extract episode strands (drop the last record, due to overlap)
        # XXX force a clone, since `input` might be overwritten, and epx slices!
        input = plyr.apply(lambda x: x[:-1].clone(), input)
        episodes = self.epx.extract(input.fin, Episode(
            input,
            (),
            plyr.apply(torch.as_tensor, *nfo, _star=False),
        ))

        # (sys) recompute the recurrent state `hx` AFTER the update over the
        # proper part of the fragment (t=0..N-1)
        if hx is not None:
            # `hx` is $h_N$ from $h_{t+1}, y_t = F(z_t, h_t; w)$, t=0..N-1.
            # One of update's side effects is that the recurrent runtime state
            #  in `hx` became stale, i.e. it no longer corresponds to the final
            #  state had the updated policy been run on the same fragment the
            #  second time. We would recompute `hx` over the entire historical
            #  trajectory, had we stored it whole. The next best option is to
            #  make a second pass over the just collected fragment $
            #      (z_t)_{t=0}^{N-1}
            #  $ and use $h'_N$ as the new `hx`, where $
            #      h'_{t+1}, y_t = F(z_t, h'_t; w')
            #  $ with $h'_0 = h_0$ given by `gx`.
            with torch.no_grad():
                _, _, hx = self(input, hx=gx)

            # DO NOT backprop through `hx` from the next fragment (truncated bptt)!
            hx = plyr.apply(torch.Tensor.detach, hx)
            
            # pass some grad feedback from the next fragment to `h0`
            # the new `hx` is interpolated towards `h0` with weight `f_h0_lerp`
            h0 = self.agent.initial_hx
            f_h0_lerp = cfg['f_h0_lerp']
            if h0 is not None and f_h0_lerp > 0:
                hx = plyr.apply(torch.lerp, hx, h0, weight=f_h0_lerp)

        # the batch-size normalization (TxB) is carried out on the logger's side
        terms = plyr.apply(float, terms)
        out = {'loss/loss': float(loss), 'loss/grad': float(grad)}
        out.update({'loss/' + k: v for k, v in terms.items()})

        # report diagnostic stats for completed episodes
        out.update({'metrics/' + k: v for k, v in ep_aggregate(episodes).items()})

        return out, hx

RND motivator

In [None]:
import math

class RNDMotivator(nn.Module):
    """An intrinsic motivator based on Random Network Distillation.

    Parameters
    ----------
    rnd : nn.Module
        The module which returns the triple `(rew, out, hx)` when called
        with the fields of the input and `hx`. Must expose its optimizer
        as `.optim`.
    cfg : dict, keyword-only
        The settings: `f_rew_max` is read in `.forward`, while `n_batches`,
        `n_batch_size` and `f_grad_norm` are read in `.learn`.
    """
    def __init__(self, rnd, *, cfg: dict):
        super().__init__()
        self.rnd = rnd
        self.cfg = cfg

    def forward(self, input, *, hx=None):
        """Reward observational novelty based on Random Network Distillation.
        """
        # get the non-diffable intrinsic rewards
        rew, out, hx = self.rnd(**input._asdict(), hx=hx)

        # clamp it IN-PLACE from above, and return the rewards and outputs
        f_rew_max = self.cfg['f_rew_max']
        if f_rew_max is not None:
            rew.clamp_(max=f_rew_max)
        return rew, out, hx

    def learn(self, input, error=None, *, gx=None, hx=None, nfo=None):
        """Make `n_batches` updates of the RND module on the fragment.

        Parameters
        ----------
        input : namedtuple
            The fragment data, passed to the module through `.forward`.
        error : optional
            The second output of `.forward` for `input`, which is summed
            into the loss. Used in the first update only, and recomputed
            from `gx` whenever it is None.
        gx : optional
            The recurrent state at the start of the fragment.
        hx, nfo : optional
            Accepted for interface compatibility and ignored.

        Returns
        -------
        out : dict
            The last loss under `loss/rnd` and the last grad norm under
            `loss/mot.grad`. Both are NaN if no update was made.
        hx : None
            No recurrent state is returned.
        """
        # NaN placeholders keep the report well-defined even if `n_batches`
        #  is zero and the loop below is never entered
        loss, grad = float('nan'), float('nan')
        for _ in range(self.cfg['n_batches']):
            # We've got the following `input`: `.act[t]` is $r^I_{t-1}$
            #  `.obs.agent[t]` is $a_{t-1}$ taken in the env to get $x_t$,
            #  which is stored in `.obs.obs[t]`.
            if error is None:
                _, error, _ = self(input, hx=gx)

            # (sys) skip the first item since it has been processed in a
            # previous fragment
            error = error[1:]

            # (batching) get the batch mask for the current fragment
            mask = dropout_mask(input, k=self.cfg['n_batch_size'])
            if mask is not None:
                # rescale by the share of kept elements to offset the masking
                scale = float(mask.sum()) / mask.numel()
                error = error.mul(mask).div(scale)

            loss = error.sum()
            self.rnd.optim.zero_grad(True)
            loss.backward()
            grad = clip_grad_norm_(self.parameters(), self.cfg['f_grad_norm'])
            self.rnd.optim.step()

            # invalidate the errors, since the module has just been updated
            error = None

        return {
            'loss/rnd': float(loss),
            'loss/mot.grad': float(grad),
        }, None

Intrinsic motivation using RIDE

In [None]:
class RIDEMotivator(nn.Module):
    """An intrinsic motivator, which wraps a RIDE module.

    The wrapped `ride` module must return a triple `(rew, emb, hx)` and
    expose its optimizer as `.optim`. The settings in `cfg` are `f_rew_max`
    (read in `.forward`), `n_batches`, `n_batch_size`, `b_detach`, `f_coefs`
    and `f_grad_norm` (read in `.learn`).
    """
    def __init__(self, ride, *, cfg: dict):
        super().__init__()
        self.cfg = cfg
        self.ride = ride

    def forward(self, input, *, hx=None):
        """Impact-driven rewards.
        """
        # the rewards, the embeddings, and the updated runtime state
        rew, emb, hx = self.ride(**input._asdict(), hx=hx)

        # optionally cap the rewards from above IN-PLACE
        cap = self.cfg['f_rew_max']
        if cap is None:
            return rew, emb, hx

        rew.clamp_(max=cap)
        return rew, emb, hx

    def learn(self, input, emb=None, *, gx=None, hx=None, nfo=None):
        """Update the RIDE module on the fragment and recompute its state.

        Makes `n_batches` updates. The given `emb` is used for the first
        one only, afterwards the embeddings are recomputed from `gx`.
        Returns the dict with the last loss terms under `loss/ride` and the
        last grad norm under `loss/mot.grad` (an empty dict and NaN, if no
        update was made), and the new runtime state `hx`.
        """
        assert isinstance(input.obs, MotivatorObs)

        cfg = self.cfg
        terms, grad = {}, float('nan')
        for _ in range(cfg['n_batches']):
            # `input.obs.agent` are the actions the agent has ALREADY taken
            #  in the env. The last one is the action for which the capsule
            #  computes the reward right after `.learn` returns, while the
            #  rewards issued so far are in `input.act`.
            if emb is None:
                _, emb, _ = self(input, hx=gx)

            # (batching) the batch mask for the current fragment
            mask = dropout_mask(input, k=cfg['n_batch_size'])

            # `.agent[t]` is $a_{t-1}$, and `emb[t]` is $(z_{t-1}, z_t)$,
            #  with $z_t = \phi(x_t)$, t=0..T
            terms = ride(
                self.ride,
                *emb,
                input.obs.agent,
                b_detach=cfg['b_detach'],
                mask=mask,
            )
            loss = reduce(terms, cfg['f_coefs'])

            # one clipped gradient step on the RIDE module
            optim = self.ride.optim
            optim.zero_grad(True)
            loss.backward()
            grad = clip_grad_norm_(self.parameters(), cfg['f_grad_norm'])
            optim.step()

            # the embeddings are stale after the step
            emb = None

        # (ride) refresh `hx`, so that the initial embedding is correct.
        #  RIDE is non-recurrent and embeds each observation independently,
        #  hence it suffices to process just the last observation of the
        #  proper fragment rather than the whole of it.
        with torch.no_grad():
            # XXX clone, because `input` might be updated in-place
            tail = plyr.apply(lambda x: x[-2:-1].clone(), input)
            _, _, hx = self(tail, hx=None)

            # no backprop through `hx` from the next fragment (truncated bptt)
            hx = plyr.apply(torch.Tensor.detach, hx)

        report = {
            'loss/ride': plyr.apply(float, terms),
            'loss/mot.grad': float(grad),
        }
        return report, hx

problem:
* the motivator needs access to $a_t$ -- the action, which caused the transition $x_t \to x_{t+1}$ in the env. Therefore we cannot simply change the action space of the motivator to the rewards, unless we augment
its observation space by the action during init.

Create the learner and motivator instances.

In [None]:
# the learner for the agent is configured by the `act` section of the config
learner = BaseLearner(agent, cfg=wandb.config.act)

# the motivator is optional: pick the wrapper class, which matches
#  `cls_motivator`, and configure it by the `mot` section of the config
motivator = None
if mot is not None:
    # the wrappers are keyed by the string form of the wrapped module's class
    Motivator = {
        str(RNDModule): RNDMotivator,
        str(RIDEModule): RIDEMotivator,
    }[cls_motivator]
    motivator = Motivator(mot, cfg=wandb.config.mot)

At time $t \geq 0$ we do the following for each $
    j=0, 1, .., J-1
$ in sequence: at $t \to t+1$
$$
\omega^j_t, \bigl(
    u^{< j}_{t+1},
    u^{\geq j}_t
\bigr), \bigl(
    r^{< j}_{t+1},  % can see all rewards for historical observables
    r^{\geq j}_t
\bigr)
    \overset{L_j}{\longrightarrow}
        \omega^j_{t+1},
        u^j_{t+1},
        r^j_{t+1}
    \,, $$
where $u^{< j}_{t+1}$ are __observables__ issued by blocks prior to $j$
during their $t \to t+1$ transition, $u^{\geq j}_t$ --- the historical
time $t$ observables of all blocks after and including the $j$-th.

An __observable__ $u^j_{t+1}$ is anything that block $j$ may generate:
an action $a_{t+1}$ to be fed into a subsequent block, a partial observation
$x_{t+1}$, which can be used by the next block, or a reward $r_{t+1}$, which
the next block receives for its historical __observable__.
<!--  -->
Despite being observables, we keep the rewards separate from $u^\cdot_t$, because
of their special role in the RL objective. The reward $r^{j-1}_{t+1}$ is issued
during $t \to t+1$ by the block $j-1$ for the historical observable $u^j_t$ produced
by block $j$ at $t-1 \to t$ transition.

The joint system is initialized at $t=-1$ with special values (default or random)
for $u^j_{-1}, r^j_{-1}$ and $\omega^j_{-1}$, that signify initialization.

For example, a classical binary env-agent interaction system has the following stepping:
$$
\begin{align}
    \underbrace{\omega_t}_{\omega^0_t}, \bigl(
        \emptyset,
        \bigl(
            \underbrace{\emptyset}_{u^0_t},
            \underbrace{a_t}_{u^1_t}
        \bigr)
    \bigr), \bigl(
        \emptyset,
        \bigl(
            \underbrace{\emptyset}_{r^0_t},
            \underbrace{\emptyset}_{r^1_t}
        \bigr)
    \bigr)
        &\overset{\text{ENV}}{\longrightarrow}
            \underbrace{\omega_{t+1}}_{\omega^0_{t+1}},
            \underbrace{x_{t+1}}_{u^0_{t+1}},
            \underbrace{r_{t+1}}_{r^0_{t+1}}
        \,, \\
    \underbrace{h_t}_{\omega^1_t}, \bigl(
        \underbrace{x_{t+1}}_{u^0_{t+1}},
        \underbrace{a_t}_{u^1_t}
    \bigr), \bigl(
        \underbrace{r_{t+1}}_{r^0_{t+1}},
        \underbrace{\emptyset}_{r^1_t}
    \bigr)
        &\overset{\text{ACT}}{\longrightarrow}
            \underbrace{h_{t+1}}_{\omega^1_{t+1}},
            \underbrace{a_{t+1}}_{u^1_{t+1}},
            \underbrace{\emptyset}_{r^1_{t+1}}
        \,,
\end{align}
$$
since **env** does not use its own observables it emitted in the past ($u^0_t$),
and reacts only to the observables produced by the actor, which are, in this
case, *actions* ($u^1_t = a_t$). Similarly, the **env** completely ignores all
rewards ($r^j_t$), unlike the actor **act**, which uses the env's reward to
reinforce its historical actions ($r^0_{t+1}$ for the action $a_t$). As for
the observables, the actor makes use of the recent and historical observables
emitted by *env* ($u^0_s$ for $s \leq t+1$) and its own historical action
($u^1_t = a_t$) in order to produce the next reactive action ($u^1_{t+1} = a_{t+1}$).

learning process

In [None]:
from nle_toolbox.utils.rl.engine import Input, prepare
from nle_toolbox.utils.rl.capsule import capsule, buffered, launch
from nle_toolbox.zoo.models.motivation import MotivatorObs

# the dict, into which `capture` puts the info returned by the `.learn` calls
log = {}
# the lists for `sniff` to record the `.learn` call arguments into. Since both
#  are None, `sniff` returns the `.learn` methods unwrapped.
data_agnt = None
data_motv = None

# the capsule of the agent's learner
caps = buffered(
    learner,
    capture(sniff(learner.learn, data_agnt), log),
    length=wandb.config.act['n_length'],
    device=device_
)

# the capsule of the motivator, if there is one
if motivator is not None:
    motv = buffered(
        motivator,
        capture(sniff(motivator.learn, data_motv), log),
        length=wandb.config.mot['n_length'],
        device=device_,
    )

# the total number of steps made in all envs
n_steps = 0
with tqdm.tqdm(
    initial=n_steps,
    total=wandb.config.n_total,
    ncols=70,
    disable=False,
) as bar:
    # the local runtime state for the main exec flow
    # env is reset, and act is sampled from the env
    #   \emptyset -->> \omega_0, x_0, r^E_0, f_0
    npy, pyt = prepare(env, rew=0., fin=True)

    # specify the motivator and get its first reward action
    #   \xi_0, ((x_0, a_{-1}), g_{-1}, r^E_0, f_0) -->> \xi_1, g_0
    # the intrinsic rewards stay at zero, if there is no motivator
    npy_mot_rew = np.zeros_like(npy.rew)
    if motivator is not None:
        # the motivator observes the agent's action along with the env's obs
        obs = MotivatorObs(npy.act, npy.obs)
        # the motivator's own initial `action` is a NaN-filled reward
        act = np.full_like(npy_mot_rew, np.nan)
        npy_mot_rew = launch(motv, Input(obs, act, npy.rew, npy.fin))

    # recv the first action of the encapsulated agent
    #   h_0, (x_0, a_{-1}, (r^E_0, g_0), f_0) -->> h_1, a_0
    rew = {'ext': npy.rew, 'int': npy_mot_rew}
    act = launch(caps, Input(npy.obs, npy.act, rew, npy.fin))

    # step thru the environment
    #   \omega_0, a_0 -->> \omega_1, x_1, r^E_1, f_1
    obs, rew, fin, nfo = env.step(act)
    n_steps += len(env)

    # the main loop
    while progress(bar, n_steps):
        # compute intrinsic motivation (self-supervised rewards)
        #   \xi_{t-1}, a_{t-1}, x_{t-1}, x_t -->> \xi_t, r^I_t
        # without a motivator `npy_mot_rew` keeps its initial zero value
        if motivator is not None:
            npy_mot_rew = motv.send((MotivatorObs(act, obs), rew, fin, nfo))

        # issue the rewards r^E_t and r^I_t for the `t-1 -->> t` transition
        rew = {'ext': rew, 'int': npy_mot_rew}

        # decide which capsule to route the data to and get the next action
        #   h_{t-1}, x_t -->> h_t, a_t
        act = caps.send((obs, rew, fin, nfo))

        # (log) the training progress
        # `log` is non-empty only after a `.learn` call has been captured
        if log:
            wandb.log({'n_steps': n_steps, **log}, commit=True)
            log.clear()

        # step thru the environment
        #   \omega_{t-1}, a_{t-1} -->> \omega_t, x_t, r^E_t, f_t
        obs, rew, fin, nfo = env.step(act)
        # plyr.apply(np.copyto, npy, Input(obs, act, rew, fin))
        n_steps += len(env)

<br>

In [None]:
import os
from nle_toolbox.utils.io import mkstemp

# make sure that the folder for the checkpoints exists
target = os.path.abspath('./checkpoints')
os.makedirs(target, exist_ok=True)

# name the checkpoint after the run and the number of steps made so far
checkpoint = mkstemp('.pt', f'ckpt__{wandb.run.name}__{n_steps}__', dir=target)
# save the states of the agent, the motivator, and their optimizers along
#  with the config. The motivator's entries are empty dicts, if it is absent.
torch.save({
    'cls': str(NLEAgent),
    'agent': agent.state_dict(),
    'agent.optim': agent.optim.state_dict(),
    'mot': mot.state_dict() if mot is not None else {},
    'mot.optim': mot.optim.state_dict() if mot is not None else {},
    'config': dict(wandb.config),
}, checkpoint)
print(checkpoint)

In [None]:
# finalize the wandb run
wandb.finish(quiet=True)

<br>

We finally got something:
* reducing the dim of the RND output vector from 64 to 16 was, perhaps,
the most significant step in the direction of the agent learning at least
something.
* the next thing was removing the build embedding and using `condition`,
`hunger`, and `vitals`
* the effect of positive `f_h0_lerp` has not been investigated
  * reducing it from `.05` to `0.0` seems to **adversely** impact learning


Next idea to try: like token + position in BERT and GPT, let's exploit additive embeddings.
* in glyphs: entity + group embedding -- entities semantics is modulated by its group (relevant to MON, PET, RIDE, STATUE), also add special `ego`-embedding at the centre of vicinity.
Let $g_{uv} \in \mathbb{G}$ be glyph at position $u, v$ relative to the `hero` (bls-x, -y coords).
$$
f_{uv}
    = w_E\bigl[\operatorname{entity}(g_{uv})\bigr]
    + w_G\bigl[\operatorname{group}(g_{uv})\bigr]
    + 1_{(0,0)}{(u, v)} w_{\mathrm{ego}}
  \,, $$
where $w_\cdot$ are game map-related embeddings
* in blstats: modulate an `ego`-embedding by the `vitals`, `condition` and other stats' embeddings.
  * should the `ego`-embedding be shared between map and state?

<br>

A ranked buffer for episode rollouts.

* what do we do with the missing `hx`? clone episodes in full?
  * `you wake up on a cold stone floor in the middle of a vast chamber without a shred of memory of how you got here. What do you do?` maybe it is OK to take contiguous fragments of a long episode and start with a wiped out memory.
* do we clone just the actions, or also the value function?

In [None]:
from heapq import heapreplace, heappushpop, heappush, heappop

from dataclasses import dataclass, field
from typing import Any

class RankedBuffer:
    """A bounded buffer, which retains the highest-ranking items.

    The items are kept on a min-heap, so that `buffer[0]` is always the
    lowest-ranking one. It is the first to be evicted once the buffer is
    at capacity.

    Parameters
    ----------
    capacity : int
        The maximal number of items to retain.
    """
    @dataclass(order=True, frozen=True, repr=False)
    class RankedItem:
        # the rank is the only field used for ordering the items
        rank: float
        # the payload is excluded from comparisons
        item: Any = field(compare=False)

    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity

    def push(self, rk, it):
        """Add the item `it` with the rank `rk`.

        Returns None while the buffer is below capacity, otherwise returns
        the evicted `RankedItem`, which is the lowest-ranking one among
        the stored and the pushed items.
        """
        item = self.RankedItem(rk, it)
        # push the current item
        if len(self.buffer) < self.capacity:
            return heappush(self.buffer, item)
        # ... pop the lowest-ranking one, if we exceed capacity
        return heappushpop(self.buffer, item)

    def extend(self, pairs):
        """Push every `(rank, item)` pair and return the last push's result."""
        last = None
        for rk, it in pairs:
            last = self.push(rk, it)
        return last

    def __bool__(self):
        return bool(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]

    def __repr__(self):
        return type(self).__name__ + f"({len(self.buffer)}/{self.capacity})"

    def __iter__(self):
        # the pairs are yielded in the heap order, which is NOT sorted
        return ((el.rank, el.item) for el in self.buffer)

    def sample(
        self,
        n_samples=8,
        n_steps=64,
        *,
        rng=None,
    ):
        """Draw a batch of contiguous fragments from the stored episodes.

        Parameters
        ----------
        n_samples : int
            The number of fragments to draw (with replacement).
        n_steps : int
            The length of each fragment.
        rng : numpy.random.Generator, optional, keyword-only
            The source of randomness. A fresh default generator is created
            if None.

        Returns
        -------
        The fragments stacked along dim one.

        Raises
        ------
        ValueError
            If no stored episode has at least `n_steps` steps.
        """
        # a new generator per call instead of one shared by all calls and
        #  instances, which a default argument would have been
        if rng is None:
            rng = np.random.default_rng()

        # determine the sufficiently long episodes
        eligible = []
        for j, ep in enumerate(self.buffer):
            input = ep.item.input
            # the duration excludes the last record, if its `fin` flag is set
            dur_ = len(input.fin) - int(input.fin[-1])
            if dur_ >= n_steps:
                eligible.append((j, dur_))

        # fail with a clear message rather than let `rng.choice` complain
        #  about an empty population
        if not eligible:
            raise ValueError(
                f"No episode in the buffer has at least {n_steps} steps."
            )

        # sample starting strands from episodes
        chunks = []
        for i in rng.choice(len(eligible), size=n_samples):
            k, dur_ = eligible[i]
            # a random start such that the fragment fits into the episode
            j = rng.integers(dur_ - n_steps + 1)

            chunk = plyr.apply(lambda t: t[j:j + n_steps],
                               self.buffer[k].item)
            chunks.append(chunk)

        return plyr.apply(torch.stack, *chunks, _star=False, dim=1)

A common routine to plot the computed rollout metrics

In [None]:
def plot_ep_metrics(metrics):
    """Draw the histograms of the episode metrics on a 2x2 grid of axes.

    `metrics` is a sequence of per-episode dicts. They are collated into
    lists, and at most four of the metrics are plotted, one per axis.
    The `score` and `return` histograms use the log-scale. Returns the
    figure and the array of axes.
    """
    fig, axes = plt.subplots(2, 2, figsize=(7, 3), dpi=300)

    collated = plyr.apply(list, *metrics, _star=False)
    for panel, name in zip(axes.flat, collated):
        use_log = name in ('score', 'return')
        panel.hist(collated[name], label=name, log=use_log, bins=20)
        panel.set_title(name)

    fig.tight_layout()
    return fig, axes

Rank the episode and put it into a buffer

In [None]:
def add_episodes(buf, iterable, *, C_cov=0.1):
    """Rank the episodes and push them into the buffer `buf`.

    The rank of an episode is its return plus its coverage per step of
    duration, weighted by `C_cov`. Episodes with empty stats are skipped.
    """
    for episode in iterable:
        stats = ep_stats(episode)
        if stats:
            bonus = C_cov * stats['coverage'] / stats['duration']
            buf.push(stats['return'] + bonus, episode)

One step in the joint differentiable rollout collection and a procedure to collect a differentiable rollout

In [None]:
from nle_toolbox.utils.rl.engine import step

def collect(env, agent, npyt, hx, *, n_steps, visualize=None, fps=0.01, device=None):
    """Generate the results of `n_steps` consecutive steps of the agent.

    Each result of `step` is yielded as is, while its second element is
    fed as `hx` into the next step. If `visualize` is not None, then the
    observation it indexes is rendered before every step.
    """
    # (sys) a view into the observation arrays of the env to be rendered
    view = None
    if visualize is not None:
        view = plyr.apply(plyr.getitem, npyt.npy.obs, index=visualize)

    do_render = view is not None
    for _ in range(n_steps):
        if do_render:
            ipynb_render(view, clear=True, fps=fps)

        # (sys) get $(x_t, a_{t-1}, r_t, d_t), v_t, \pi_t$
        result = step(env, agent, npyt, hx, device=device)
        _, hx, _ = result
        yield result

We may want to evaluate in a different env.

In [None]:
from inspect import signature, _empty
from nle_toolbox.utils.dicttools import override

# use the default of the `config` parameter of `factory` as the base, and
#  fall back to the training `config`, if the parameter has no default
def_config = signature(factory).parameters['config'].default
if def_config is _empty:
    def_config = config

# the overrides for the evaluation env are empty, unless the `id` is enabled
eval_config = override(
    def_config,
    dict(
#         id='MiniHack-CorridorBattle-Dark-MoreRats-v0',
    ),
)

Prepare a strand extractor and a buffer for ranking episodes, and ready the vectorized env and the runtime context.

In [None]:
# the strand extractor, and the buffer for up to 128 ranked episodes
epx, buf = EpisodeExtractor(), RankedBuffer(128)
# the stats of the completed episodes
metrics = []

# NOTE this rebinds `env` to a new vectorized env of 16 evaluation envs
env = SerialVecEnv(factory, n_envs=16, kwargs=dict(config=eval_config))
# the runtime context, and the initial recurrent state of the agent
npyt, hx = prepare(env, rew=0., fin=True), None

A visualized evaluation run.

In [None]:
# the total step budget of the evaluation
n_total = 65536 * 4
# `visualize` is the index of the env to render (None disables rendering)
n_steps, visualize = 0, None
with torch.no_grad(), tqdm.tqdm(
    initial=n_steps, total=n_total, ncols=80,
    disable=visualize is not None,
) as bar:
    # the info dict carried over from the previous fragment
    nfo_ = {}
    while progress(bar, n_steps):
        # (sys) collect a fragment of the episode time `t` afterstates, t=0..N-1
        fragment, hxx, nfo = zip(*collect(
            env, agent, npyt, hx, n_steps=128,
            visualize=visualize, fps=0.05,
            device=device_
        ))
        # XXX `fragment` is ((x_t, a_{t-1}, r_t, d_t), v_t, \mu_t), t=0..N-1

        # (sys) retain running state `hx`, but detach its grads (truncated bptt)
        # ATTN do not update `npyt` and `hx`!
        if hxx[-1] is not None:
            hx = plyr.apply(torch.Tensor.detach, hxx[-1])
        
        # (sys) shift and collate the info dicts
        #    `d0, (d1, ..., dn)` -->> `(d0, ..., d{n-1}), dn`
        # on the first pass `nfo_` is empty, so `nfo[0]` is used in its place
        *nfo, nfo_ = nfo_ or nfo[0], *nfo
        nfo = plyr.apply(torch.as_tensor, *nfo, _star=False)

        # (sys) repack the fragment data
        # XXX note, `.act[t]` is $a_{t-1}$, but the other `*[t]` are $*_t$,
        #  e.g. `.rew[t]` is $r_t$, and `pol[t]` is `$\pi_t$.
        input, output = plyr.apply(torch.cat, *fragment, _star=False)

        # (sys) increment the step count
        n_steps += input.fin.numel()

        # (sys) extract episode strands with log-probs of the taken actions
        episodes = epx.extract(input.fin, Episode(input, output, nfo))
        add_episodes(buf, episodes, C_cov=0.5)

        # (evl) compute the metrics of the completed episodes
        for ep in episodes:
            met = ep_stats(ep)
            if met:
                metrics.append(met)

    # (sys) extract the residual episode strands
    # they are ranked and buffered, but not included in `metrics`
    unfinished = epx.finish()
    add_episodes(buf, unfinished, C_cov=0.5)

plot_ep_metrics(metrics);

The stats of the episodes in the buffer.

In [None]:
from matplotlib import pyplot as plt

# the stats of the buffered episodes, and the fps setting for the renderer
out, fps = [], None
for rk, ep in buf:
    # move the episode's input data to numpy
    npy = plyr.apply(np.asarray, plyr.apply(torch.Tensor.cpu, ep.input))
    off = int(npy.fin[-1])  # offset for unfinished strands
    # render every observation of the episode, except for the offset
    for t in range(len(npy.fin) - off):
        obs = plyr.apply(plyr.getitem, npy.obs, index=t)
        ipynb_render(obs, fps=fps)

    # keep the non-empty stats only
    met = ep_stats(ep)
    if met:
        out.append(met)

plot_ep_metrics(out);

In [None]:
# NOTE(review): presumably a deliberate stop, which prevents the notebook
#  from running past this cell -- confirm
assert False

<br>

Render a replay of the given episode from the agent's ego-centric view.

In [None]:
from nle_toolbox.utils.env.draw import draw

# the actions of the first env in the vectorized env
ENV_ACTIONS = env.envs[0].unwrapped.actions
for rk, ep in buf:
    # the number of records to replay excludes the last one, if it is terminal
    n_length = len(ep.input.fin) - int(ep.input.fin[-1])
    npy = plyr.apply(lambda x: x.cpu().numpy(), ep)
    # draw each step of the episode on a new figure in place of the last one
    # NOTE(review): `clear_output` and `sleep` are not imported in the visible
    #  part of the notebook -- presumably imported in an earlier cell, verify
    for t in range(n_length):
        fig = plt.figure(figsize=(6, 6), dpi=120)
        artists = draw(fig, npy, t, actions=ENV_ACTIONS)
        clear_output(True)
        plt.tight_layout()
        plt.show()
        plt.close()
        sleep(0.01)

<br>

In [None]:
# the norms of the current grads of the motivator's parameters
{k: p.grad.flatten().norm() for k, p in mot.named_parameters() if p.grad is not None}

In [None]:
# the norms of the current grads of the agent's parameters
{k: p.grad.flatten().norm() for k, p in agent.named_parameters() if p.grad is not None}

In [None]:
# NOTE(review): presumably a deliberate stop, which prevents the notebook
#  from running past this cell -- confirm
assert False

<br>

<br>

Get ready to clone the successful episodes in the ranked buffer

In [None]:
# the number of fragments in a batch, and the length of each fragment
n_samples, n_steps = 8, 64
# the clipping constants, used only by the disabled v-trace code in the
#  cloning loop
r_bar, c_bar = 1.01, 1.1
rng = np.random.default_rng()

# the history of the losses of the cloning updates
losses = []

Behaviour cloning
* let's have a look at [Self-Imitation Learning](https://proceedings.mlr.press/v80/oh18b.html)
> ... off-policy actor-critic algorithm that learns to reproduce the agent’s past good decisions.
... that exploiting past good experiences can indirectly drive deep exploration.

In [None]:
# make fifty cloning updates on fragments sampled from the ranked buffer
for k in tqdm.tqdm(range(50), ncols=70):
    # sample a batch of trajectory fragments
    input, myu, _ = out = buf.sample(n_samples, n_steps, rng=rng)

    # recompute the policy and value-to-go estimates for the episode
    # XXX amnesia training: forget the hx
    _, vp, _ = agent(input.obs, input.act, input.rew, hx=None, fin=input.fin)
    # XXX this is not EXACTLY identical to `fin=ep_.fin`, which is guaranteed
    #  to contain a reset `fin[0]` and possibly a `fin[-1]` (not in case when
    #  the episode is unfinished). We ignore pol[-1] $\pi_{T}$ and val[-1]
    #  $v(s_{T})$, both of which pertain to the next episode. `fin` affects
    #  only the recurrent state anyway and 1) we set the initial to `None`,
    #  and 2) do not ever use the

    # the objective to maximize is computed with unit advantages, thus
    #  the loss to minimize is its negative
    L_loglik = pyt_polgrad(vp.pol, input.act, adv=1.)
    ell = - L_loglik

    # get the v-trace target for the critic and the advantages to pol-grad
    # XXX here `.fin[-1]` properly blocks the last state-value backup
#     ret, _, rho = pyt_impala(
#         input.rew, input.fin, input.act,
#         val['ext'], pol, myuval['ext'], myupol,
#         gam=f_gamma['ext'], r_bar=r_bar, c_bar=c_bar
#     )
#     L_critic = pyt_critic(val['ext'], ret)

#     ell = (L_critic * (C_critic['ext'] / 2) - L_loglik)

    # a plain gradient step on the agent without grad clipping
    agent.optim.zero_grad()
    ell.backward()
    agent.optim.step()

#     losses.append((float(L_loglik), float(L_critic)))
    losses.append((float(L_loglik),))

In [None]:
# transpose the recorded one-element tuples into a single array
L_loglik, = map(np.array, zip(*losses))

# plot the negated values on the log-scale
plt.semilogy(-L_loglik, c='C1')

In [None]:
# NOTE(review): presumably a deliberate stop, which prevents the notebook
#  from running past this cell -- confirm
assert False

<br>

In [None]:
from minihack.envs import register
from nle import nethack
from nle.env.tasks import NetHackScore

class WalkingNetHack(NetHackScore):
    """The NetHackScore task with compass moves as the default actions.

    Unless `actions` is given explicitly, the action set consists of the
    members of `nethack.CompassDirection`. `observation_keys` is
    a mandatory keyword argument.
    """
    def __init__(self, *args, observation_keys, **kwargs):
        # an explicitly given `actions` takes precedence over the default
        actions = kwargs.pop("actions", tuple(nethack.CompassDirection))
        super().__init__(
            *args,
            **kwargs,
            actions=actions,
            observation_keys=observation_keys,
        )

# register the task with compass-only moves under a new env id
register(
    id="NetHackChallengeMovingOnly-v0",
    entry_point=WalkingNetHack,
)

from functools import partial
# the env factory with the `id` bound to the newly registered task
fac = partial(factory, id='NetHackChallengeMovingOnly-v0')



<br>

In [None]:
# inspect the input-to-hidden weights of the agent's core
with torch.no_grad():
    # `(W_ii|W_if|W_ig|W_io)`
    # take the weight columns from 128 onwards, split them into groups of
    #  five, and compute the ell-2 norm of each group
    # NOTE(review): presumably the first 128 inputs are not the blstats
    #  embeddings, and each stat has a five-dim embedding -- confirm
    blsw = agent.core.weight_ih_l0[:, 128:]
    blsw = blsw.reshape(len(blsw), -1, 5)
    norms = blsw.norm(p=2, dim=-1)

# label the groups in order (`zip` stops at the shorter of the two)
norms = dict(zip([
    'hunger', 'status', 'hp', 'mp',
    'str', 'dex', 'con', 'int', 'wis', 'cha',
    'strprc', 'AC', 'encumberance',
], norms.T))

# the boundaries of the four shaded row spans, and their alternating colors
breaks = 0, 128, 256, 384, 512
c_pair = "C0", "C1"
fig, axes = plt.subplots(2,2, figsize=(7, 7,), dpi=300, sharey=True, sharex=True)
# NOTE there are four axes only, hence just the first four groups are plotted
for nom, ax in zip(norms, axes.flat):
    ax.semilogy(norms[nom], label=nom)
    for j, (a, b) in enumerate(zip(breaks[1:], breaks)):
        ax.axvspan(a, b, color=c_pair[j&1], alpha=0.05, zorder=-10)

    ax.legend(fontsize='xx-small')

<br>

<br>