# Let's try a non-hierarchical RL agent

In [None]:
import time
import gym
import nle

import matplotlib.pyplot as plt

In [None]:
# Drop `gym.Wrapper.__getattr__`.
# NOTE(review): presumably this stops wrappers from implicitly forwarding
#  attribute look-ups to the env they wrap -- confirm the intent.
# The guard keeps the cell idempotent: a bare `del` raises `AttributeError`
# when the cell is re-run in the same kernel, or when the installed gym
# does not define this hook on `Wrapper` itself.
if '__getattr__' in vars(gym.Wrapper):
    del gym.Wrapper.__getattr__

Import other useful modules

In [None]:
import plyr
import torch
import numpy as np

from torch import nn
from torch.nn import functional as F

<br>

Register envs that punish death with a `-1` reward.
* see the base class [MiniHack](https://github.com/facebookresearch/minihack/blob/65fc16f0f321b00552ca37db8e5f850cbd369ae5/minihack/base.py#L131-L132) subclassed by `MiniHackNavigation` and `MiniHackRoom`

In [None]:
from minihack.envs import register

# Register `-v1` variants of several MiniHack tasks. All of them share the
# same terminal rewards: `+1` for a win and `-1` for a loss.
for env_id, entry_point in (
    (
        "MiniHack-Room-Ultimate-15x15-v1",
        "minihack.envs.room:MiniHackRoom15x15Ultimate",
    ),
    (
        "MiniHack-CorridorBattle-Dark-v1",
        "minihack.envs.fightcorridor:MiniHackFightCorridorDark",
    ),
    (
        "MiniHack-HideNSeek-Big-v1",
        "minihack.envs.hidenseek:MiniHackHideAndSeekBig",
    ),
    (
        "MiniHack-Memento-F4-v1",
        "minihack.envs.memento:MiniHackMementoF4",
    ),
):
    register(
        id=env_id,
        entry_point=entry_point,
        # a fresh kwargs dict per registration, so the specs share no state
        kwargs={
            "reward_win": +1,  # default
            "reward_lose": -1,  # used to be 0. <<-- reward for death
        },
    )

Agent's and motivator's arch

In [None]:
from nle_toolbox.zoo.models.basic import NLENeuralAgent

# Recipe for the baseline recurrent agent with an LSTM core.
recipe = {
    'agent': NLENeuralAgent.default_recipe(
        n_actions=8,  # 85, 121  # XXX make this an automatic setting dependent on the env
        embedding_dim=16,
        intermediate_size=256,
        act_embedding_dim=None,  # default to `embedding_dim`
        fin_embedding_dim=0,  # disable `fin` flag embedding
        core='lstm',

        # it appears that more layers makes the agent learn better!
        num_layers=2,
        # hardcoded! see `factory()`, which builds `NLEFeatureExtractor` with `k=3`
        k=3,
        # NOTE(review): presumably the bottom-line stats fed to the encoder -- confirm
        bls=(
            'hp',
            'hunger',
            'condition',
        ),
    ),
}

# The RND motivator reuses the observation encoder spec of the agent (`obs`),
# and `sizes` lists the widths of the dense layers stacked on top of it.
recipe['motivator'] = {
    'obs': recipe['agent']['obs'],
    'sizes': [16],
}

# keep a dedicated handle, because the name `recipe` is rebound later on
recipe_bas = recipe

A recipe for Highway Transformer

In [None]:
from nle_toolbox.zoo.models.transformer import NLEHITNeuralAgent

# Recipe for the Highway Transformer (HiT) agent.
recipe = {
    'agent': NLEHITNeuralAgent.default_recipe(
        n_actions=8,  # 121  # XXX make this an automatic setting dependent on the env
        embedding_dim=16,
        n_cls=8,
        n_io=8,  # \geq 4 is great, but \leq 2 doesn't work
        intermediate_size=64,
        num_attention_heads=4,
        head_size=16,

        num_layers=1,  # 2 is gud, but can we use one layer?
        # hardcoded! see `factory()`, which builds `NLEFeatureExtractor` with `k=3`
        k=3,
        # NOTE(review): presumably the bottom-line stats fed to the encoder -- confirm
        bls=(
            'hp',
            'hunger',
            'condition',
        ),
    ),
}

# The RND motivator gets its own observation encoder spec, which shares only
# `n_context` and `embedding_dim` with the agent. `sizes` lists the widths of
# the dense layers stacked on top of the encoder.
recipe['motivator'] = {
    'obs': {
        'n_context': recipe['agent']['n_context'],
        'embedding_dim': recipe['agent']['embedding_dim'],
        'intermediate_size': 256,
        'dropout': 0.0,
    },
    'sizes': [16],
}

# keep a dedicated handle, because the name `recipe` is rebound later on
recipe_hit = recipe

Pick the agent to try

In [None]:
# A manual switch between the two architectures: flip the constant in
# the condition to train the baseline agent instead.
if True:
    # highway xformer
    recipe, NLEAgent = recipe_hit, NLEHITNeuralAgent
    # the trailing comma makes `tags` a one-element tuple
    tags = 'HiT',

else:
    # the baseline agent, with no extra tags
    recipe, NLEAgent = recipe_bas, NLENeuralAgent
    tags = ()

The fragmented a2c parameters

In [None]:
import wandb

# Start the experiment tracking run. The `config` below is the single source
# of the hyper-parameters, which are read back from `wandb.config` later on.
# The trailing `;` suppresses the output of the cell.
wandb.init(
    project='nle-toolbox-capsule',
    job_type='nethack',
    # mode='disabled',
    config=dict(
        # the weight of the intrinsic advantages in the gae mix for the polgrads
        f_alpha=0.5,
        f_rnd_loss_dropout=0.0,
        f_rnd_lr=1e-3,

        # extrinsic/intrinsic reward PV discount
        f_gamma={'ext': 0.99, 'int': 0.9},

        # the GAE discount
        f_lambda={'ext': 0.96, 'int': 0.96},

        # the share of the runtime state's hx's grad to be passed to h0
        f_h0_lerp=0.00,

        # critic (both ext and int) and entropy weights in the loss
        C_pg=1.,
        C_critic={'ext': 0.5, 'int': 0.5},
        C_entropy=0.01,
        f_a2c_lr=1e-3,

        # the number of off-policy updates
        n_off_policy=0,

        # gradient ell-2 norm clipping
        f_grad_norm=float('inf'),

        # the truncated-bptt length rollout
        n_fragment_length=100,
        n_rnd_fragment_length=20,

        # the number of envs run simultaneously
        n_batch=16,

        # the total number of steps allotted to training (summed across all envs)
        n_total=2_592_000 * 1,

        # also track the recipes
        recipe=recipe,
        
        # the env id (which should've been here looong time ago)
        # navigation
#         id='MiniHack-Room-Ultimate-15x15-v1',
#         id='MiniHack-CorridorBattle-Dark-v1',
#         id='MiniHack-HideNSeek-Big-v1',
#         id='MiniHack-Memento-F4-v1',
        id='MiniHack-Memento-Short-F2-v0',

        # skill acquisition (advanced), 85 actions
#         id='MiniHack-WoD-Hard-v0',
#         id='MiniHack-LavaCross-v0',  # 85 actions

        # Full game (ultimate) 121 actions
#         id='NetHackChallenge-v0',

        # whether to warm-start the agent from an earlier checkpoint
        b_load_ckpt=False,
    ),
    tags=[
        'hx-fix',
        # the architecture-specific tags set when picking the agent
        *tags,
    ],
);

* `linear` with small TxB batch (64x5) is more noisy, and produces less `stable` policy, but converges fastest
* `linear` with a larger batch (16x100) converges a bit slower, but produces a more stable policy
* `lstm` with 16x100 converges slowest

* [MiniHack-HideNSeek-Big-v0](https://minihack.readthedocs.io/en/latest/envs/navigation/hidenseek.html)
> ... the agent is spawned in a big room full of trees and clouds \[, which\] block the line of sight of the player\[, \] and a random monster (chosen to be more powerful than the agent). The agent, monsters and spells can pass through clouds unobstructed \[but\] cannot pass through trees. The goals is to make use of the environment features, avoid being seen by the monster and quickly run towards the goal.

* [MiniHack-WoD-Hard-v0](https://minihack.readthedocs.io/en/latest/envs/skills/wod.html)
> ... require mastering the usage of the wand of death (WoD). Zapping a WoD it in any direction fires a death ray which instantly kills almost any monster it hits. ... the WoD needs to be found first, only then the agent should enter the corridor with a monster (who is awake and hostile this time), kill it, and go to the staircase.

* [MiniHack-LavaCross-v0](https://minihack.readthedocs.io/en/latest/envs/skills/lava_cross.html)
> The agent can accomplish this by either levitating over it (via a potion of levitation or levitation boots) or freezing it (by zapping the wand of cold or playing the frost horn).

* [MiniHack-Memento-F4-v0](https://minihack.readthedocs.io/en/latest/envs/navigation/memento.html)
> The agent is presented with a prompt (in the form of a sleeping monster of a specific type), and then navigates along a corridor. At the end of the corridor the agent reaches a fork, and must choose a direction. One direction leads to a grid bug, which if killed terminates the episode with +1 reward. All other directions lead to failure through a invisible trap that terminates the episode when activated. The correct path is determined by the cue seen at the beginning of the episode.

<br>

We hide the NLE under several layers of wrappers. From the core to the shell:
1. `ReplayToFile` handles seeding and logs the taken actions and seed into a file for later inspection and replay.

2. `NLEObservationPatches` patches tty-screens botched by the mismatch between the cr-lf configuration of the NLE's tty terminal emulator and NetHack's display output (lf only).

3. `NLEFeatureExtractor` adds extra features generated on-the-fly from the current NLE's observation.
  * an ego-centric view of the specified radius into the dungeon map (`vicinity`)
  * percentage strength (`NLE_BL_STR125`), converted to a `100`-base score, used by the game to compute extra strength bonuses.
    * The strength stat in AD&D 2ed, upon which the mechanics of NetHack is based, comes in two ints: strength and percentage. The latter is applicable to **warrior classes** with **natural str** 18 and denotes `exceptional strength`, which confers extra chance-to-hit, damage, and chance to force locks or doors.


4. `RecentHistory` keeps a brief log of actions taken in the environment (partially duplicates the functionality of the `Replay` wrapper).

5. `Chassis` handles skippable gui events that do not require a decision, such as collecting menu pages unless an interaction is required, fetching consecutive topline or log messages.

6. `ObservationDictFilter` allows only the specified keys of the observation dict to get through.

7. `ActionMasker` computes the mask of actions that are **forbidden** in the current game state (_gui_ or _play_)
  * the mask is communicated through the observation dicts under the key `action_mask`

In [None]:
from nle_toolbox.bot.chassis import get_wrapper

from nle_toolbox.utils.replay import ReplayToFile, Replay

from nle_toolbox.utils.env.wrappers import NLEObservationPatches
from nle_toolbox.utils.env.wrappers import NLEFeatureExtractor
from nle_toolbox.utils.env.wrappers import RecentHistory

from nle_toolbox.bot.chassis import Chassis, ActionMasker

from nle_toolbox.utils.env.wrappers import ObservationDictFilter

The factory for collecting random exploration rollouts

In [None]:
# from nle_toolbox.utils import seeding
import minihack

def factory(seed=None, folder=None, sticky=False, id=str(wandb.config.id)):
    """Build one fully wrapped environment for the given env `id`.

    Parameters
    ----------
    seed :
        Passed as is to the `.seed` method of the replay wrapper.
    folder :
        If `None`, the actions are tracked by `Replay`, otherwise
        `ReplayToFile` is used with this folder and `save_on='done'`.
    sticky :
        Forwarded to the replay wrapper.
    id : str
        The id of the env to `gym.make`. Defaults to the id stored in
        the wandb config at the moment this function was defined.

    Returns
    -------
    The env wrapped, from the core to the shell, in: a replay wrapper,
    `NLEObservationPatches`, `RecentHistory`, `Chassis`,
    `NLEFeatureExtractor`, `ObservationDictFilter`, and `ActionMasker`.
    """
    from nle.nethack import ACTIONS  # not used below, kept as in the original

    # the observation keys requested from the core env
    raw_keys = (
        'glyphs',
        'chars',
        'colors',
        'specials',
        'blstats',
        'message',
        'inv_glyphs',
        'inv_strs',
        'inv_letters',
        'inv_oclasses',
        'tty_chars',
        'tty_colors',
        'tty_cursor',
        'misc',
        'screen_descriptions',
    )
    env = gym.make(id, observation_keys=raw_keys)

    # lookups between the index of an action and its character
    ctoa = {chr(code): index for index, code in enumerate(env.unwrapped.actions)}
    atoc = tuple(chr(code) for code in env.unwrapped.actions)

    # provide seeding capabilities and full action tracking
    if folder is not None:
        env = ReplayToFile(env, sticky=sticky, folder=folder, save_on='done')

    else:
        env = Replay(env, sticky=sticky)
    env.seed(seed)

    # patch bugged tty output
    env = NLEObservationPatches(env)

    # log recent actions as characters
    # XXX `atoc` IS NOT a dict, hence it is indexed by the action's position
    env = RecentHistory(env, n_recent=128, map=lambda a: atoc[a])

    # skippable gui abstraction layer. Bypassed if the action
    # space does not bind a SPACE action (`.get` yields None then).
    env = Chassis(env, space=ctoa.get(' '), split=False)

    # a feature extractor to potentially reduce the runtime complexity
    # * ego-centric view, and properly handled exceptional strength stat
    env = NLEFeatureExtractor(env, k=3)

    # filter unused observation keys
    # XXX this wrapper should be applied before any container-type
    #  modifications of the NLE's observation space.
    kept_keys = (
        # the map, bottom line stats and inventory
        # ('chars', 'colors', 'specials', 'inv_strs', and 'inv_oclasses'
        #  are not let through)
        'glyphs',
        'blstats',
        'inv_glyphs',
        'inv_letters',

        # used for in-notebook rendering
        'tty_chars',
        'tty_colors',
        'tty_cursor',

        # 'message' and 'misc', used by the GUI abstraction layer (Chassis),
        # are not let through either

        # extra features produced by the upstream wrappers
        'vicinity',
    )
    env = ObservationDictFilter(env, *kept_keys)

    # compute an action mask based on the current NLE mode: gui or play
    return ActionMasker(env)

A renderer for this **factory**

In [None]:
import pprint as pp
from time import sleep

from nle_toolbox.utils.env.render import render as tty_render
from IPython.display import clear_output

def ipynb_render(obs, clear=True, fps=None):
    """Print the tty rendering of `obs` into the notebook cell.

    Nothing is rendered when `fps` is None. Otherwise the cell output is
    cleared first (if `clear` is set), `tty_render(**obs)` is printed, and,
    for a positive `fps`, the call sleeps afterwards.

    NOTE(review): despite its name, `fps` is passed to `sleep` verbatim,
    i.e. it acts as a delay in seconds, not as a frame rate -- confirm.

    Returns
    -------
    Always `True`.
    """
    # rendering is disabled altogether
    if fps is None:
        return True

    if clear:
        clear_output(wait=True)

    print(tty_render(**obs))

    # pause between the frames
    if fps > 0:
        sleep(fps)

    return True

<br>

### Redesigning the building blocks of the NLE feature extractor

The design of a simple observation encoder:
* embed glyphs' entities and their groups additively,
* employ the `ego` embedding: learnable offset to the centre of the vicinity (an ego-centric view into the map)
* join with embeddings of health, hunger and condition from the botl (the bottom line stats)

In [None]:
from nle_toolbox.zoo.models.basic import SimpleEncoder

Intrinsic motivation via Random Network distillation

In [None]:
from typing import List

class RNDModule(nn.Module):
    """A feature network for Random Network Distillation.

    A `SimpleEncoder`, built from the `obs` spec, is followed by a stack
    of `GELU -> Linear` pairs. The first linear layer takes
    `obs['intermediate_size']` inputs, and the output widths of the
    consecutive layers are given by `sizes`.

    The call signature mimics the agent's one, but everything except
    `obs` is ignored, and the returned recurrent state is always `None`.
    """

    def __init__(self, obs, *, sizes: List[int]):
        super().__init__()

        # the widths of all features, from the encoder's output onwards
        widths = [obs['intermediate_size'], *sizes]

        stack = [SimpleEncoder(**obs)]
        for n_inputs, n_outputs in zip(widths[:-1], widths[1:]):
            stack.append(nn.GELU())
            stack.append(nn.Linear(n_inputs, n_outputs, bias=True))

        self.encoder = nn.Sequential(*stack)

    def forward(self, obs, act=None, rew=None, fin=None, *, hx=None):
        # `act`, `rew`, `fin` and `hx` are accepted, but never used
        features = self.encoder(obs)
        return features, None

A clunky tool to split the parameters into biases and weights.

Unlike `.bias` parameters, which **should not** be decayed, not every `.weight` parameter
**should be** decayed. For example, learnt positional encodings can be seen as bias terms
to the "linear" operation effected by an `nn.Embedding`, at the same time ordinary token
embeddings should also not be decayed either, since they effectively serve as the input
data representation.

`nn.LayerNorm` layers perform within-layer normalization and then re-scale and translate
it, meaning that their weight should be regularized to unit scales and not zeros.

Hence:
* all biases, and weights of `nn.LayerNorm` and `nn.Embedding` should not be regularized
by weight decay


In [None]:
from nle_toolbox.zoo.models.transformer import HiT

# XXX clean this up
def split_parameters(module):
    """Split the parameters of `module` into decayed and non-decayed lists.

    Only the parameters whose name starts with `weight` and which are owned
    directly by an `nn.Linear` or an `nn.LSTM` are decayed. Everything else
    is exempt: the biases, the weights of `nn.LayerNorm` and `nn.Embedding`,
    the `posemb`, `cls` and `iox` parameters of `HiT`, and any parameter
    that is not recognized.

    Returns
    -------
    decay, no_decay : list, list
        The parameters in the order they are met in `.named_modules()`.

    Raises
    ------
    AssertionError
        If the same parameter is owned by more than one submodule.
    """
    # the layers with dense linear weights, which are always decayed
    dense = nn.Linear, nn.LSTM

    decay, no_decay, seen = [], [], set()
    for _, owner in module.named_modules():
        # consider only the parameters owned by `owner` itself
        for name, par in owner.named_parameters(prefix='', recurse=False):
            assert par not in seen
            seen.add(par)

            # XXX an unrecognized parameter could raise a TypeError here,
            #  instead of being silently exempted from the decay
            is_dense_weight = name.startswith('weight') and isinstance(owner, dense)
            (decay if is_dense_weight else no_decay).append(par)

    return decay, no_decay

Adam with high `weight_decay` may push the values of many parameters
into the denormalized fp range, which is ultra slow on CPU (but not
as bad on GPU), see the answer and a reply from *njuffa* to
this [stackoverflow](https://stackoverflow.com/questions/36781881) question.

In [None]:
# Flush denormal floats to zero, so that heavily decayed parameters do not
# trigger the slow denormal arithmetic on CPU.
torch.set_flush_denormal(True)

Build an agent and the RND motivator from the recipe

In [None]:
from nle_toolbox.utils.nn import ModuleDict

# the agent is built from the recipe as tracked in the wandb config
agent = NLEAgent(**wandb.config.recipe['agent'])

# the RND pair: both networks are built from the same motivator recipe
rnd = ModuleDict(dict(
    # the target network has its gradients disabled and is kept in eval mode
    target=RNDModule(**wandb.config.recipe['motivator']).requires_grad_(False).eval(),
    # the trainable online network
    online=RNDModule(**wandb.config.recipe['motivator']).train(),
))

Reset the bias terms in the recurrent core

In [None]:
from torch.nn import init
from nle_toolbox.utils.nn import rnn_reset_bias

# apply `rnn_reset_bias` throughout the agent (the trailing `;` suppresses
# the output of the cell)
agent.apply(rnn_reset_bias);

Init agent's optimizer

In [None]:
# only the `decay` group is subject to the weight decay
decay, no_decay = split_parameters(agent)

# AdamW doesn't do what you expect it to do, Ivan! (although it
#  correctly decouples the objective's grad and the ell-2 weight reg).
#  See https://arxiv.org/abs/1711.05101.pdf
# The optimizer is attached to the agent as a plain attribute. The first
# param group takes the `weight_decay` from the defaults given at the end,
# while the second one overrides it with zero.
agent.optim = torch.optim.AdamW([
    dict(params=decay),
    dict(params=no_decay, weight_decay=0.),
], lr=wandb.config.f_a2c_lr, eps=1e-5, weight_decay=0.001)

Load an earlier checkpoint.

In [None]:
# Optionally warm-start the agent from an earlier checkpoint.
# NOTE(review): the path is absolute and machine-specific, so this branch
#  cannot work elsewhere -- consider making it a config entry.
# NOTE(review): `torch.load` unpickles the file, hence only trusted
#  checkpoints should ever be loaded here.
if wandb.config.b_load_ckpt:
    checkpoint = """/Users/ivannazarov/Github/repos_with_rl/nle_toolbox/"""\
                 """doc/checkpoints/ckpt__250571vf__31104000__2kfy_rb5.pt"""
    ckpt = torch.load(checkpoint)
    # only the agent's weights are restored (the optimizer state is not);
    # the printed result reports the missing and unexpected keys
    print(agent.load_state_dict(ckpt['agent']))

An optimizer for the RND

In [None]:
# only the online network is trained: the target has its grads disabled
decay, no_decay = split_parameters(rnd.online)
# the optimizer is attached to the `rnd` container as a plain attribute;
# the `no_decay` group overrides the default weight decay with zero
rnd.optim = torch.optim.AdamW([
    dict(params=decay),
    dict(params=no_decay, weight_decay=0.),
], lr=wandb.config.f_rnd_lr, eps=1e-5, weight_decay=0.01)

In [None]:
# display the structure of the agent
agent

In [None]:
# display the structure of the RND target and online networks
rnd

<br>

### Let's train an A2C agent in a capsule

An object to extract full episodes from their trajectory fragments.

In [None]:
from nle_toolbox.utils.rl.tools import EpisodeExtractor

A capsule for a learner agent and a launcher for it.

In [None]:
from nle_toolbox.utils.rl.capsule import launch, capsule

Compute the policy gradient surrogate, the entropy and other loss components

In [None]:
from nle_toolbox.utils.rl.engine import pyt_polgrad
from nle_toolbox.utils.rl.engine import pyt_entropy
from nle_toolbox.utils.rl.engine import pyt_critic

We shall use GAE in policy gradients and returns for the critic.

In [None]:
from nle_toolbox.utils.rl.returns import pyt_ret_gae

Get the weighted sum of the leaves in one nested container with
weights from a second container of the same structure.

In [None]:
import operator as op

def reduce(values, weight=None):
    """Sum the leaves of the nested container `values`.

    If `weight` is given, then each leaf is first multiplied by its
    counterpart in `weight`, as matched up by `plyr.apply`.
    """
    weighted = values if weight is None else plyr.apply(op.mul, values, weight)

    # `plyr.apply` is used here solely to visit each leaf
    leaves = []
    plyr.apply(leaves.append, weighted)
    return sum(leaves)

A function to compute the targets (GAE, returns) for policy grads and critic loss.

In [None]:
@torch.no_grad()
def pg_targets(rew, val, /, gam, lam, *, fin):
    r"""Get the GAE advantages and the returns for each reward stream.

    Parameters
    ----------
    rew, val, gam, lam :
        Containers indexed by the keys of `rew`, one entry per reward
        stream: the rewards $r_t$, the value estimates $v(s_t)$, the
        reward discount and the GAE discount, respectively.
    fin :
        The termination flags $d_t$, shared by all streams.

    Returns
    -------
    gae, ret : dict, dict
        The advantages and the returns, keyed like `rew`.

    Details
    -------
    The td-error terms in GAE depend on $r_{t+1}$, $d_{t+1}$ and on both
    $v(s_{t+1})$ and $v(s_t)$. This is why `rew` and `fin` are passed
    without their first record, whereas `val` is passed whole: `val[-1]`
    is the value-to-go estimate for the last state in the fragment.
    """
    gae, ret = {}, {}
    for key in rew:
        # `pyt_ret_gae` yields the returns first, and then the advantages
        returns, advantages = pyt_ret_gae(
            rew[key][1:],
            fin[1:],
            val[key],
            gam=gam[key],
            lam=lam[key],
        )
        ret[key], gae[key] = returns, advantages

    return gae, ret

The Advantage (GAE) Actor-Critic loss on a fragment.
* this one is on-policy, thus it expects the data in `vp` to be diff-able.

In [None]:
def a2c_gae(input, vp, *, gam, lam, alpha):
    """Get the terms of the GAE-A2C loss w. intrinsic motivation.

    `input` provides the `.rew`, `.fin` and `.act` of a trajectory fragment,
    `vp` holds the diff-able values `.val` and policy `.pol`. The advantages
    of the reward streams are mixed with the weights in `alpha`.

    Returns a dict with the `pg`, `entropy`, and `critic` terms.
    """
    # (gae) compute GAE and returns for all rewards
    # XXX `rew[t]`, `fin[t]`, and `val[t]` must be $r_t$, $d_t$ and $v(s_t)$,
    # respectively, where `v_t` is computed based on the historical data
    # $\cdot_s$ with `s < t`! Thus r_{t+1}, v_t, v_{t+1} -->> A_t, R_t.
    gae, ret = pg_targets(input.rew, vp.val, gam, lam, fin=input.fin)

    # (gae + motivation) the weighted sum of advantages is a constant
    # wrt the parameters, hence the detach
    adv = reduce(gae, alpha).detach()

    terms = {}

    # (sys) policy grad surrogate `sg(G_t) \log \pi_t(a_t)`
    terms['pg'] = plyr.apply(pyt_polgrad, vp.pol, input.act, adv=adv)

    # (sys) entropy of the policy
    terms['entropy'] = plyr.apply(pyt_entropy, vp.pol)

    # (sys) intrinsic/extrinsic critic loss
    terms['critic'] = plyr.apply(pyt_critic, vp.val, ret)

    return terms

A helper for computing V-trace targets, that takes care of chipping
off the initial `rew` and `fin`.

In [None]:
from nle_toolbox.utils.rl.returns import pyt_td_target, pyt_vtrace
from nle_toolbox.utils.rl.returns import trailing_broadcast

def pyt_vtrace_helper(rew, val, gam, *, fin, rho, r_bar, c_bar):
    """Call `pyt_vtrace` on `rew` and `fin` with their first record dropped."""
    rew_next, fin_next = rew[1:], fin[1:]
    return pyt_vtrace(
        rew_next, fin_next, val, rho,
        gam=gam, r_bar=r_bar, c_bar=c_bar,
    )

def pyt_td_target_helper(rew, vtrace, val, gam, *, fin, rho):
    """Get the importance-weighted td(0)-error advantages.

    The td target is computed from `rew` and `fin` with their first record
    dropped. The values `val[:-1]` are subtracted from it, and the result
    is multiplied by `rho`, broadcasted by `trailing_broadcast(rho, rew)`.
    """
    # [O(T B F)] see sec. "v-trace actor-critic algo" (p. 4) in
    # Espeholt et al. (2018)
    target = pyt_td_target(rew[1:], fin[1:], vtrace, gam=gam)
    td_error = target.sub(val[:-1])

    weight = trailing_broadcast(rho, rew)
    return td_error.mul(weight)

A procedure to get the likelihood of actions under a given policy sequence

In [None]:
from nle_toolbox.utils.rl.engine import pyt_logpact

def impala(input, vp, myu, *, gam, alpha, r_bar, c_bar):
    """Get the terms of the IMPALA (V-trace) actor-critic loss on a fragment.

    `input` is the trajectory fragment, `vp` holds the diff-able values and
    policy of the current agent, and `myu` the non-diffable values and policy
    of the behavioural one. `gam` are the discounts and `alpha` the weights of
    the advantages of the reward streams; `r_bar` and `c_bar` are forwarded
    to `pyt_vtrace` (`r_bar` also caps the weights of the advantages).

    Returns a pair: the dict with the `pg`, `entropy`, and `critic` terms,
    and the exponentiated `rho`, i.e. the ratio of the action likelihoods
    under `vp` and `myu`, assuming `pyt_logpact` returns log-likelihoods.
    """
    # (impala) For the advantages targets, IWs and critic targets we
    #  assume `vp` is diff-able, while `myu` isn't. Both are unstructured.
    val_, pol_ = plyr.apply(torch.Tensor.detach, vp)

    # (impala) get the importance weights
    # XXX `act[t]` is $a_{t-1}$, pol[t] is $\pi_t$ and $a_t \sim \pi_t$
    # NOTE(review): at this point `rho` is a DIFFERENCE of `pyt_logpact`
    #  outputs, i.e. presumably the log of the importance weights -- confirm
    rho = pyt_logpact(pol_, input.act) - pyt_logpact(myu.pol, input.act)

    # (impala) get the V-trace targets $v_t$, t=0..N. Note that IMPALA uses
    #  values from the behavioural policy `myu`, rather than the current
    #  `vp`. see sec. 4.1 (p. 3) in Espeholt et al. (2018)
    # XXX `rew[t], fin[t]` ($r_t, f_t$) PRECEDE `val[t]` ($v_t$), hence
    # we drop the first values from `rew` and `fin`.
    # NOTE(review): `rho` is passed here as is, without `.exp()` -- verify
    #  that `pyt_vtrace` expects the weights on this scale
    vtrace = plyr.apply(
        pyt_vtrace_helper, input.rew, myu.val, gam, fin=input.fin,
        rho=rho, r_bar=r_bar, c_bar=c_bar)

    # (impala) get the importance-weighted td(0)-error advantages.
    # XXX unclear if we backprop thru `vp.val` here if the `val`
    # and `pol` networks have shared parameters.
    # `.exp()` makes a new tensor, so the in-place clamp leaves `rho` intact
    adv = plyr.apply(
        pyt_td_target_helper, input.rew, vtrace, val_, gam,
        fin=input.fin, rho=rho.exp().clamp_(max=r_bar))

    # (motivation) get the weighted sum of advantages
    adv = reduce(adv, alpha).detach()
    # the critic's targets are the v-trace values without the last record
    vtarget = plyr.apply(lambda t: t[:-1], vtrace)
    return {
        # (sys) policy grad surrogate `sg(A_t) \log \pi_t(a_t)`
        'pg': pyt_polgrad(vp.pol, input.act, adv=adv),
        # (sys) entropy of the policy
        'entropy': plyr.apply(pyt_entropy, vp.pol),
        # (sys) intrinsic/extrinsic critic loss
        'critic': plyr.apply(pyt_critic, vp.val, vtarget),
    }, rho.exp()

Cache the hyper-parameters

In [None]:
# Module-level copies of the tracked hyper-parameters. Some of them are
# read as globals by the learner's `.learn` method.
n_total = wandb.config.n_total
n_fragment_length = wandb.config.n_fragment_length
n_batch = wandb.config.n_batch

f_lambda = wandb.config.f_lambda
f_gamma = wandb.config.f_gamma
# the extrinsic advantages always have unit weight in the mix
f_alpha = {'ext': 1., 'int': wandb.config.f_alpha}

f_h0_lerp = wandb.config.f_h0_lerp
f_rnd_loss_dropout = wandb.config.f_rnd_loss_dropout

# the weights of the loss terms: `-ve` maximizes, `+ve` minimizes
f_C = {
    'pg': -wandb.config.C_pg,
    'entropy': -wandb.config.C_entropy,
    # a dict with one weight per reward stream (`ext` and `int`)
    'critic': wandb.config.C_critic,
}

n_off_policy = wandb.config.n_off_policy

f_grad_norm = wandb.config.f_grad_norm

Progress bar update and termination condition checker.

In [None]:
import tqdm

def progress(bar, n):
    """Move the progress `bar` to the absolute position `n`.

    Returns `True` while `n` is still below the bar's total, i.e. when
    the loop driven by this check should continue.
    """
    # the bar's `.update` is incremental, hence we feed it the difference
    increment = n - bar.n
    bar.update(increment)

    keep_going = n < bar.total
    return keep_going

A service function to get diagnostic stats and exploration metrics from an episode.

In [None]:
import math
from nle.nethack import NLE_BL_SCORE
from nle_toolbox.utils.env.defs import MAX_ENTITY, GLYPH_CMAP_OFF, symbol
from nle.env.base import NLE

from collections import namedtuple
# an episode strand: the inputs, the outputs, and the info dicts
Episode = namedtuple('Episode', 'input,output,info')

def ep_stats(ep, *, S_stone=symbol.S_stone + GLYPH_CMAP_OFF):
    """Get the diagnostic stats and exploration metrics of one episode.

    Parameters
    ----------
    ep : Episode
        The episode's data. `ep.input` must provide `.fin`, `.rew`, and
        `.obs` with the `blstats`, `vicinity`, and `glyphs` keys.
    S_stone :
        The glyph treated as unexplored by the coverage metric.

    Returns
    -------
    dict
        An empty dict if the episode has no records besides the terminal
        one. Otherwise the `duration`, `return`, `score`, `n_unique`,
        `diversity` (in bits), `coverage`, `effectiveness`, and `b_death`
        of the episode. `b_death` is NaN if `ep.info` is empty.
    """
    assert isinstance(ep, Episode)

    met = {}
    # determine the offset for unfinished episodes: it is one if the last
    # record has its `fin` flag set, and zero otherwise
    offset = int(ep.input.fin[-1])
    if len(ep.input.fin) <= offset:
        return met

    # convert to numpy `npy = plyr.apply(np.asarray, ep)`
    # episode duration and total score
    n_length = len(ep.input.fin) - offset

    # score the episode (return = \sum_j r_{j+1} = sum r[j], j=1..N-1)
    f_return = plyr.apply(lambda r: r[1:].sum(), ep.input.rew)
    # the score is read off the last record before the terminal one
    f_score = ep.input.obs['blstats'][-1-offset, NLE_BL_SCORE]

    # count the number of unique glyphs seen during the episode
    vic = ep.input.obs['vicinity'][:(-1 if offset > 0 else None)]
    cnt = torch.bincount(vic.flatten(), minlength=MAX_ENTITY + 1)
    n_unique = cnt.gt(0).sum()

    # The empirical entropy is a fine proxy for the diversity, since
    # it measures the information content of the seen glyphs.
    proba = cnt.div(cnt.sum())
    # `kl_div` against zero log-probabilities gives `sum p log p`, the
    # negative of which is the entropy (in nats)
    f_entropy = F.kl_div(proba.new_zeros(()), proba, reduction='sum').neg()

    # coverage and action effectiveness
    gly = ep.input.obs['glyphs'][:(-1 if offset > 0 else None)]
    # XXX we exclude the terminal obs, because it is actually the init
    #  obs from the next episode
    # the share of non-stone glyphs in each frame (the mean is taken over
    # the last two dims)
    non_stone = (gly != S_stone).float().mean((-2, -1))
    # NOTE(review): the ratio is infinite if some frame is all stone
    f_cov = non_stone.max() / non_stone.min()
    # the total share of glyphs that changed between consecutive frames
    f_eff = sum((g0 != g1).sum() / g1.numel() for g0, g1 in zip(gly, gly[1:]))
    
    # inspect the last info dict of the episode
    b_death = np.nan
    if ep.info:
        nfo = plyr.apply(lambda x: x[-1].item(), ep.info)

        # indicate if the episode ended in agent's death
        b_death = float(nfo['end_status'] == NLE.StepStatus.DEATH)

    return {
        'duration': n_length,
        'return': plyr.apply(float, f_return),
        'score': int(f_score),
        'n_unique': int(n_unique),
        # convert nats to bits
        'diversity': float(f_entropy) / math.log(2),
        'coverage': float(f_cov),
        'effectiveness': float(f_eff),
        'b_death': b_death,
    }

Aggregate the diagnostic stats across several episodes

In [None]:
def ep_aggregate(episodes):
    """Aggregate the diagnostic stats across several episodes.

    Each metric is summarized by its median, except for `b_death`, which
    is replaced by the mean, i.e. the share of episodes that ended in
    agent's death. Returns an empty dict if no episode has any stats.
    """
    # episodes, for which `ep_stats` returns nothing, are dropped
    stats = [met for met in map(ep_stats, episodes) if met]
    if not stats:
        return {}

    # the share of deaths is computed before the per-metric medians
    f_deaths = np.mean([met['b_death'] for met in stats])

    summary = plyr.apply(np.median, *stats, _star=False)
    summary['b_death'] = f_deaths
    return summary

<br>

Create the vectorized env and the agent

In [None]:
from nle_toolbox.utils.rl.engine import SerialVecEnv

# a vectorized env of `n_batch` environments, each built by `factory`
env = SerialVecEnv(factory, n_envs=wandb.config.n_batch)

How to:
* do logging and capsule learning?
* organize the intrinsic motivator module?
  * ims are like regular actors, except their `actions` are rewards!
```python
    class SelfSupervisedRewards(nn.Module):
        def forward(self, obs, act=None, rew=None, fin=None, *, hx=None):
            return rew, (), hx
```
  * let's make intrinsic motivation modules be truly __exogenous__ wrt the agents
* set up the optimizers for various submodules?

In [None]:
import operator as op
from torch.nn.utils import clip_grad_norm_

class BaseLearner(nn.Module):
    """A GAE-A2C learner around `agent` with optional off-policy updates.

    The learner relies on the module-level hyper-parameters (`f_gamma`,
    `f_lambda`, `f_alpha`, `f_C`, `f_grad_norm`, `n_off_policy`, and
    `f_h0_lerp`), and expects the agent to have its optimizer attached
    as the `.optim` attribute.
    """

    def __init__(self, agent):
        super().__init__()
        self.agent = agent
        # reassembles complete episodes from consecutive trajectory fragments
        self.epx = EpisodeExtractor()

    def forward(self, input, *, hx=None):
        """Run the agent on the fields of `input` from the recurrent state `hx`."""
        return self.agent(**input._asdict(), hx=hx)

    def learn(self, input, vp, *, gx=None, hx=None, nfo=None):
        """Update the agent on a trajectory fragment.

        Parameters
        ----------
        input :
            The trajectory fragment, records t=0..N.
        vp :
            The diff-able values and policy (v_t, \\pi_t) on the fragment.
        gx :
            The recurrent state at the start of the fragment.
        hx :
            The recurrent state at the end of the fragment.
        nfo :
            The sequence of info dicts collected along the fragment.

        Returns
        -------
        out, hx : dict, object
            The loss terms and episode metrics to log, and the recurrent
            state recomputed with the updated agent (`None` if `hx` was).
        """
        # (sys) do GAE-A2C learning w. intrinsic motivation `vp` is (v_t, \pi_t)
        terms = a2c_gae(input, vp, gam=f_gamma, lam=f_lambda, alpha=f_alpha)

        # (sys) compute the loss
        loss = reduce(terms, f_C)

        # (sys) backprop through the agent
        self.agent.optim.zero_grad(True)
        loss.backward()
        # honour the configured `f_grad_norm`, i.e. the very same threshold,
        # which is used in the off-policy updates below
        grad = clip_grad_norm_(self.parameters(), f_grad_norm)
        self.agent.optim.step()

        # (sys) extra batch steps thru the current fragment with IMPALA
        myu = plyr.apply(torch.Tensor.detach, vp)
        for _ in range(n_off_policy):
            _, vp, _ = self(input, hx=gx)
            terms_, rho_ = impala(input, vp, myu, gam=f_gamma,
                                  alpha=f_alpha, r_bar=1., c_bar=1.)

            self.agent.optim.zero_grad(True)
            reduce(terms_, f_C).backward()
            grad = clip_grad_norm_(self.parameters(), f_grad_norm)
            self.agent.optim.step()

        # (sys) extract episode strands (drop the last record, due to overlap)
        vw_main = plyr.apply(lambda x: x[:-1], input)
        episodes = self.epx.extract(vw_main.fin, Episode(
            vw_main,
            (),
            plyr.apply(torch.as_tensor, *nfo, _star=False),
        ))

        # (sys) recompute the recurrent state `hx` AFTER the update over the
        # proper part of the fragment (t=0..N-1)
        if hx is not None:
            h0 = self.agent.initial_hx
            # `hx` is $h_N$ from $h_{t+1}, y_t = F(z_t, h_t; w)$, t=0..N-1.
            # We assume that one of update's side effects is that the recurrent
            # state in `hx` is stale. We should recompute `hx` over the entire
            # historical trajectory, however, since we do not store it whole,
            # the next best thing is to make a second pass over the just
            # collected fragment $(z_t)_{t=0}^{N-1}$ and use $h'_N$ as the new
            # `hx`, where $h'_{t+1}, y_t = F(z_t, h'_t; w')$ with $h'_0 = h_0$
            # given by `gx`.
            with torch.no_grad():
                _, _, hx = self(vw_main, hx=gx)

            # DO NOT backprop through `hx` from the next fragment (truncated bptt)
            hx = plyr.apply(torch.Tensor.detach, hx)
            if h0 is not None and f_h0_lerp > 0:
                # mix a share of the learnable initial state into `hx`
                hx = plyr.apply(torch.lerp, hx, h0, weight=f_h0_lerp)

        # only the terms of the on-policy update are reported
        terms = plyr.apply(float, terms)
        # XXX the batch-size normalization (TxB) can be carried out on
        #  the logger's side as has been done with other metrics.
        terms['entropy'] /= vw_main.fin.numel()
        # `grad` is the grad norm returned by the last clipping call
        out = {'loss/loss': float(loss), 'loss/grad': float(grad)}
        out.update({'loss/' + k: v for k, v in terms.items()})

        # report diagnostic stats for completed episodes
        out.update({'metrics/' + k: v for k, v in ep_aggregate(episodes).items()})

        return out, hx

RND motivator

In [None]:
class RNDMotivator(nn.Module):
    """Intrinsic motivation by Random Network Distillation.

    `rnd` must provide the `.online` and `.target` networks, and the
    optimizer for the former as `.optim`. The intrinsic reward is the
    squared error between the outputs of the two networks, summed over
    the last dim, and optionally clamped from above by `max`. The loss
    is the sum of the same errors after dropout with rate `f_dropout`.
    """

    def __init__(self, rnd, *, f_dropout=0.0, max=None):
        super().__init__()
        self.rnd = rnd

        self.f_dropout = f_dropout
        self.max = max

    def forward(self, input, *, hx=None):
        """Return the rewards, the diff-able errors, and `None` for the state."""
        kwargs = input._asdict()

        # (rnd) compute the diff-able intrinsic reward
        #     r^I_t = \ell(f(x_t), \bar{f}(x_t)), t=0..N
        online, _ = self.rnd.online(**kwargs, hx=hx)
        with torch.no_grad():
            target, _ = self.rnd.target(**kwargs, hx=hx)

        error = F.mse_loss(online, target, reduction='none').sum(-1)

        # (rnd) intrinsic rewards are non-diffable (optionally clamped);
        # they are cloned, so that the in-place clamp cannot affect `error`
        rewards = error.detach().clone()
        if self.max is not None:
            rewards.clamp_(max=self.max)

        # (api) return the rewards and the loss
        return rewards, error, None

    def learn(self, input, error, *, gx=None, hx=None, nfo=None):
        """Make one step of the online network on the dropped-out errors."""
        # (rnd) the dropout is always active here (`training=True`)
        loss = F.dropout(error, p=self.f_dropout, training=True).sum()

        optim = self.rnd.optim
        optim.zero_grad(True)
        loss.backward()
        optim.step()

        return {'loss/rnd': float(loss)}, None

In [None]:
from functools import wraps

def capture(fn, to):
    """Wrap `fn` so that the first item of its result also updates `to`.

    `fn` must return a pair, the first item of which can be fed to the
    `.update` method of `to`. The pair itself is returned unaltered.
    """
    def _wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)

        # divert the stats into the shared container
        nfo, _ = result
        to.update(nfo)

        return result

    # make the wrapper look like `fn`
    return wraps(fn)(_wrapper)

In [None]:
# the learner wraps the agent, which carries its own optimizer in `.optim`
learner = BaseLearner(agent)

In [None]:
# use the tracked `f_rnd_loss_dropout` hyper-parameter as the rate of the
# RND loss dropout, so that the run's logged config matches what is trained
motivator = RNDMotivator(rnd, f_dropout=f_rnd_loss_dropout)

learning process

In [None]:
from nle_toolbox.utils.rl.engine import Input, prepare

# The dict that the `capture`-wrapped `.learn` methods update with their
#  diagnostics; it is flushed to wandb and cleared inside the loop below.
# NOTE(review): `capsule`, `launch`, `progress`, `tqdm`, `wandb`, `env`
#  and `n_total` are not defined in this cell -- presumably they come from
#  earlier cells; confirm on a fresh kernel.
log = {}

# the agent's capsule: `capture` mirrors whatever `learner.learn`
#  reports into `log`
caps = capsule(
    learner,
    capture(learner.learn, log),
    length=wandb.config.n_fragment_length,
)

# the motivator's capsule, with its own fragment length
motv = capsule(
    motivator,
    capture(motivator.learn, log),
    length=wandb.config.n_rnd_fragment_length,
)

n_steps = 0
with tqdm.tqdm(initial=n_steps, total=n_total, ncols=70, disable=False) as bar:
    # the local runtime state for the main exec flow
    # env is reset, and act is sampled from the env
    npy, pyt = prepare(env, rew=0., fin=True)
    
    # specify the motivator and get its first reward
    # NOTE(review): the second field of `Input` is given `npy.rew` here,
    #  whereas the agent's capsule below gets `npy.act` in the same slot
    #  -- confirm this is intended.
    npy_mot = launch(motv, Input(npy.obs, npy.rew, npy.rew, npy.fin))

    # recv the first action of the encapsulated agent
    act = launch(caps, Input(
        npy.obs,
        npy.act,
        {'ext': npy.rew, 'int': npy_mot},
        npy.fin,
    ))

    while progress(bar, n_steps):
        # step thru the env and write x_{t+1}, a_t, r_{t+1}, d_{t+1}
        obs, rew, fin, nfo = env.step(act)
        plyr.apply(np.copyto, npy, Input(obs, act, rew, fin))
        n_steps += len(env)

        # compute intrinsic motivation (self-supervised rewards)
        # XXX `rew` is rebound from the raw env reward to a dict with
        #  the extrinsic and intrinsic streams
        rew = {'ext': rew, 'int': motv.send((obs, rew, fin, nfo))}

        # decide which capsule to route the data to
        act = caps.send((obs, rew, fin, nfo))

        # (log) the training progress
        # XXX `log` is non-empty only after a `.learn` call went through
        #  `capture` since the last flush
        if log:
            wandb.log({'n_steps': n_steps, **log}, commit=True)
            log.clear()

In [None]:
import os
from nle_toolbox.utils.io import mkstemp

# make sure the checkpoint folder exists next to the notebook
target = os.path.abspath('./checkpoints')
os.makedirs(target, exist_ok=True)

# the checkpoint name is tagged with the wandb run id and the step count
# NOTE(review): presumably `mkstemp(suffix, prefix, dir=...)` returns
#  a path to a fresh file -- confirm against `nle_toolbox.utils.io`.
checkpoint = mkstemp('.pt', f'ckpt__{wandb.run.id}__{n_steps}__', dir=target)
torch.save({
    # NOTE(review): `NLEAgent` is not defined in the visible cells (the top
    #  of the notebook imports `NLENeuralAgent`) -- confirm the name.
    'cls': str(NLEAgent),
    # the weights and the optimizer state of the agent and the rnd module
    'agent': agent.state_dict(),
    'agent.optim': agent.optim.state_dict(),
    'rnd': rnd.state_dict(),
    'rnd.optim': rnd.optim.state_dict(),
    # the hyperparameters of the run
    'config': dict(wandb.config),
}, checkpoint)
print(checkpoint)

In [None]:
wandb.finish(quiet=True)

In [None]:
assert False

In [None]:
# Scratch cell: pokes at the state left behind by the training loop.
# XXX only the last bare expression of a cell is displayed by the notebook.
npy.fin

# we need to .send the data

# a per-env counter, which is zeroed wherever `fin` is set
npy_cnt = np.zeros(fin.shape, int)

npy_cnt += 1
npy_cnt *= ~fin

npy_cnt

# NOTE(review): `counter` is not defined in the visible cells -- confirm
counter += fin

n_steps

# the observation dict extended with a random 32-dim `goal` per env
{**obs, 'goal': torch.randn(*fin.shape, 32).numpy()}

fin

tuple(npy.obs)

pyt.act

In [None]:
{k: p.grad.flatten().norm() for k, p in rnd.online.named_parameters() if p.grad is not None}

In [None]:
{k: p.grad.flatten().norm() for k, p in agent.named_parameters() if p.grad is not None}

In [None]:
assert False

<br>

We finally got something:
* reducing the dim of the RND output vector from 64 to 16 was, perhaps,
the most significant step in the direction of the agent learning at least
something.
* the next thing was removing the build embedding and using `condition`,
`hunger`, and `vitals`
* the effect of positive `f_h0_lerp` has not been investigated
  * reducing it from `.05` to `0.0` seems to **adversely** impact learning


Next idea to try: like token + position in BERT and GPT, let's exploit additive embeddings.
* in glyphs: entity + group embedding -- entities semantics is modulated by its group (relevant to MON, PET, RIDE, STATUE), also add special `ego`-embedding at the centre of vicinity.
Let $g_{uv} \in \mathbb{G}$ be glyph at position $u, v$ relative to the `hero` (bls-x, -y coords).
$$
f_{uv}
    = w_E\bigl[\operatorname{entity}(g_{uv})\bigr]
    + w_G\bigl[\operatorname{group}(g_{uv})\bigr]
    + 1_{(0,0)}{(u, v)} w_{\mathrm{ego}}
  \,, $$
where $w_\cdot$ are game map-related embeddings
* in blstats: modulate an `ego`-embedding by the `vitals`, `condition` and other stats' embeddings.
  * should the `ego`-embedding be shared between map and state?

<br>

A ranked buffer for episode rollouts.

* what do we do with the missing `hx`? clone episodes in full?
  * `you wake up on a cold stone floor in the middle of a vast chamber without a shred of memory of how you got here. What do you do?` Maybe it is OK to take contiguous fragments of a long episode and start with a wiped-out memory.
* do we clone just the actions, or also the value function?

In [None]:
from heapq import heapreplace, heappushpop, heappush, heappop

from dataclasses import dataclass, field
from typing import Any

class RankedBuffer:
    """A bounded buffer, which retains the highest-ranked items pushed to it.

    The items are kept in a min-heap keyed by their rank, so that
    the lowest-ranked item is the first one to be evicted when
    the buffer is at its capacity.
    """

    @dataclass(order=True, frozen=True, repr=False)
    class RankedItem:
        # the ranking key: the only field that takes part in comparisons
        rank: float
        # the payload: excluded from ordering, hence need not be comparable
        item: Any = field(compare=False)

    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity

    def push(self, rk, it):
        """Add `it` with rank `rk`, evicting the lowest-ranked if full.

        Returns None while the buffer is below its capacity, otherwise
        the evicted `RankedItem`, which may be the new item itself.
        """
        item = self.RankedItem(rk, it)
        # push the current item
        if len(self.buffer) < self.capacity:
            return heappush(self.buffer, item)
        # ... pop the lowest-ranking one, if we exceed capacity
        return heappushpop(self.buffer, item)

    def extend(self, pairs):
        """Push every `(rank, item)` pair, return the last push's result."""
        last = None
        for rk, it in pairs:
            last = self.push(rk, it)
        return last

    def __bool__(self):
        return bool(self.buffer)

    def __getitem__(self, index):
        # XXX indexes the heap array as is: `[0]` is the lowest-ranked
        #  item, but the rest is NOT sorted
        return self.buffer[index]

    def __repr__(self):
        return type(self).__name__ + f"({len(self.buffer)}/{self.capacity})"

    def __iter__(self):
        # XXX iterates in the heap order, not in the order of the rank
        return ((el.rank, el.item) for el in self.buffer)

    def sample(
        self,
        n_samples=8,
        n_steps=64,
        *,
        rng=None,
    ):
        """Draw `n_samples` random fragments of `n_steps` steps each.

        Parameters
        ----------
        n_samples : int
            The number of fragments to draw (with replacement).
        n_steps : int
            The length of each fragment.
        rng : None, int, or numpy.random.Generator
            The source of randomness, passed through `np.random.default_rng`,
            so a fresh generator is made on None, and an existing generator
            is used as is.

        Returns
        -------
        The fragments stacked along dim 1 by `plyr.apply(torch.stack, ...)`.

        Raises
        ------
        ValueError
            If the buffer has no episode with at least `n_steps` steps.
        """
        # a generator made at def-time would be silently shared by all
        #  calls and all buffers, so it is created at call-time instead
        rng = np.random.default_rng(rng)

        # determine the sufficiently long episodes
        eligible = []
        for j, ep in enumerate(self.buffer):
            fin = ep.item.input.fin
            # the record at a set trailing `fin` flag is not counted
            dur_ = len(fin) - int(fin[-1])
            if dur_ >= n_steps:
                eligible.append((j, dur_,))

        if not eligible:
            raise ValueError(
                f"No episode in {self!r} has at least {n_steps} steps."
            )

        # sample starting strands from episodes
        chunks = []
        for i in rng.choice(len(eligible), size=n_samples):
            k, dur_ = eligible[i]
            # the start is uniform over all positions, that fit the fragment
            j = rng.integers(dur_ - n_steps + 1)

            chunk = plyr.apply(lambda t: t[j:j + n_steps],
                               self.buffer[k].item)
            chunks.append(chunk)

        return plyr.apply(torch.stack, *chunks, _star=False, dim=1)

A common routine to plot the computed rollout metrics

In [None]:
def plot_ep_metrics(metrics):
    """Draw a 2x2 grid of histograms, one per collated episode metric.

    `metrics` is a sequence of per-episode metric dicts, which are merged
    by `plyr.apply(list, ...)` (presumably into one dict of lists -- confirm
    against plyr). The `score` and `return` panels have log-scaled counts.
    Returns the figure and the axes.
    """
    log_scaled = ('score', 'return',)
    fig, axes = plt.subplots(2, 2, figsize=(7, 3), dpi=300)

    collated = plyr.apply(list, *metrics, _star=False)
    for panel, (name, values) in zip(axes.flat, collated.items()):
        use_log = name in log_scaled
        panel.hist(values, label=name, log=use_log, bins=20)
        panel.set_title(name)

    fig.tight_layout()
    return fig, axes

Rank the episode and put it into a buffer

In [None]:
def add_episodes(buf, iterable, *, C_cov=0.1):
    """Rank each episode and push it into `buf`.

    The rank is the `return` metric plus `C_cov` times the `coverage`
    per unit of `duration`, all taken from `ep_stats`. Episodes for which
    `ep_stats` gives a falsy result are skipped.
    """
    for episode in iterable:
        stats = ep_stats(episode)
        if stats:
            total = stats['return']
            bonus = C_cov * stats['coverage'] / stats['duration']
            buf.push(total + bonus, episode)

One step in the joint differentiable rollout collection and a procedure to collect a differentiable rollout

In [None]:
from nle_toolbox.utils.rl.engine import step

def collect(env, agent, npyt, hx, *, n_steps, visualize=None, fps=0.01):
    """Yield the results of `n_steps` consecutive calls to `step`.

    Each result is passed on as is, and its second item is fed back as
    the `hx` of the next call. If `visualize` is not None, then the slice
    of `npyt.npy.obs` at that index is rendered before every step.
    """
    # (sys) get a view into numpy's observation arrays
    if visualize is None:
        vw_vis = None
    else:
        vw_vis = plyr.apply(plyr.getitem, npyt.npy.obs, index=visualize)

    for _ in range(n_steps):
        if vw_vis is not None:
            ipynb_render(vw_vis, clear=True, fps=fps)

        # (sys) get $(x_t, a_{t-1}, r_t, d_t), v_t, \pi_t$
        out = step(env, agent, npyt, hx)

        # (sys) carry the runtime state over to the next step
        _, hx, _ = out
        yield out

Prepare a strand extractor and a buffer for ranking episodes, and ready the vectorized env and the runtime context.

In [None]:
# the strand extractor, a ranked buffer with room for 128 episodes,
#  and the list of the metrics of the completed episodes
epx, buf = EpisodeExtractor(), RankedBuffer(128)
metrics = []

# four envs stepped in one batch, and the runtime context: the buffers
#  from `prepare` and an empty recurrent state
# NOTE(review): `SerialVecEnv` and `factory` are not defined in the visible
#  cells -- presumably they come from earlier ones; confirm.
env = SerialVecEnv(factory, n_envs=4)
npyt, hx = prepare(env, rew=0., fin=True), None

A visualized evaluation run.

In [None]:
# the step budget of the run; `visualize` is the index of the env to
#  render (the progress bar is disabled, unless it is None)
n_total = 16384
n_steps, visualize = 0, 0
with torch.no_grad(), tqdm.tqdm(
    initial=n_steps, total=n_total, ncols=80,
    disable=visualize is not None,
) as bar:
    # the info dict carried over from the end of the previous fragment
    nfo_ = {}
    while progress(bar, n_steps):
        # (sys) collect a fragment of the episode time `t` afterstates, t=0..N-1
        fragment, hxx, nfo = zip(*collect(
            env, agent, npyt, hx, n_steps=128,
            visualize=visualize, fps=0.05
        ))
        # XXX `fragment` is ((x_t, a_{t-1}, r_t, d_t), v_t, \mu_t), t=0..N-1

        # (sys) retain running state `hx`, but detach its grads (truncated bptt)
        # ATTN do not update `npyt` and `hx`!
        if hxx[-1] is not None:
            hx = plyr.apply(torch.Tensor.detach, hxx[-1])
        
        # (sys) shift and collate the info dicts
        #    `d0, (d1, ..., dn)` -->> `(d0, ..., d{n-1}), dn`
        # XXX on the very first fragment `nfo_` is empty, and so
        #  the first info dict is duplicated in its place
        *nfo, nfo_ = nfo_ or nfo[0], *nfo
        nfo = plyr.apply(torch.as_tensor, *nfo, _star=False)

        # (sys) repack the fragment data
        # XXX note, `.act[t]` is $a_{t-1}$, but the other `*[t]` are $*_t$,
        #  e.g. `.rew[t]` is $r_t$, and `pol[t]` is `$\pi_t$.
        input, output = plyr.apply(torch.cat, *fragment, _star=False)

        # (sys) increment the step count
        n_steps += input.fin.numel()

        # (sys) extract episode strands with log-probs of the taken actions
        episodes = epx.extract(input.fin, Episode(input, output, nfo))
        add_episodes(buf, episodes, C_cov=0.5)

        # (evl) compute the metrics of the completed episodes
        # NOTE(review): `episodes` is iterated for the second time here,
        #  which yields nothing if `.extract` returns a one-shot iterator
        #  -- confirm that it returns a list.
        for ep in episodes:
            met = ep_stats(ep)
            if met:
                metrics.append(met)

    # (sys) extract the residual episode strands
    unfinished = epx.finish()
    add_episodes(buf, unfinished, C_cov=0.5)

plot_ep_metrics(metrics);

The stats of the episodes in the buffer.

In [None]:
from matplotlib import pyplot as plt

# Replay every episode in the ranked buffer, and gather their metrics.
out, fps = [], None
for rk, ep in buf:
    # convert the episode's input into numpy arrays for rendering
    npy = plyr.apply(np.asarray, ep.input)
    off = int(npy.fin[-1])  # offset for unfinished strands
    for t in range(len(npy.fin) - off):
        # render the observation at step `t`
        obs = plyr.apply(plyr.getitem, npy.obs, index=t)
        ipynb_render(obs, fps=fps)

    # keep the metrics of the episodes that have any
    met = ep_stats(ep)
    if met:
        out.append(met)

plot_ep_metrics(out);

Get ready to clone the successful episodes in the ranked buffer

In [None]:
n_samples, n_steps = 8, 64
r_bar, c_bar = 1.01, 1.1
rng = np.random.default_rng()

losses = []

Behaviour cloning
* let's have a look at [Self-Imitation Learning](https://proceedings.mlr.press/v80/oh18b.html)
> ... off-policy actor-critic algorithm that learns to reproduce the agent’s past good decisions.
... that exploiting past good experiences can indirectly drive deep exploration.

In [None]:
# Fifty steps of likelihood-only cloning on fragments from the buffer.
for k in tqdm.tqdm(range(50), ncols=70):
    # sample a batch of trajectory fragments
    input, myu, _ = out = buf.sample(n_samples, n_steps, rng=rng)

    # recompute the policy and value-to-go estimates for the episode
    # XXX amnesia training: forget the hx
    _, vp, _ = agent(input.obs, input.act, input.rew, hx=None, fin=input.fin)
    # XXX this is not EXACTLY identical to `fin=ep_.fin`, which is guaranteed
    #  to contain a reset `fin[0]` and possibly a `fin[-1]` (not in case when
    #  the episode is unfinished). We ignore pol[-1] $\pi_{T}$ and val[-1]
    #  $v(s_{T})$, both of which pertain to the next episode. `fin` affects
    #  only the recurrent state anyway and 1) we set the initial to `None`,
    #  and 2) do not ever use the

    # the pol-grad surrogate with unit advantages, negated into a loss
    L_loglik = pyt_polgrad(vp.pol, input.act, adv=1.)
    ell = - L_loglik

    # get the v-trace target for the critic and the advantages to pol-grad
    # XXX here `.fin[-1]` properly blocks the last state-value backup
#     ret, _, rho = pyt_impala(
#         input.rew, input.fin, input.act,
#         val['ext'], pol, myuval['ext'], myupol,
#         gam=f_gamma['ext'], r_bar=r_bar, c_bar=c_bar
#     )
#     L_critic = pyt_critic(val['ext'], ret)

#     ell = (L_critic * (C_critic['ext'] / 2) - L_loglik)

    # one optimizer step on the cloning loss
    agent.optim.zero_grad()
    ell.backward()
    agent.optim.step()

#     losses.append((float(L_loglik), float(L_critic)))
    # track the value of the surrogate for plotting
    losses.append((float(L_loglik),))

In [None]:
L_loglik, = map(np.array, zip(*losses))

plt.semilogy(-L_loglik, c='C1')

In [None]:
assert False

In [None]:
from minihack.envs import register
from nle import nethack
from nle.env.tasks import NetHackScore

class WalkingNetHack(NetHackScore):
    """`NetHackScore` with the actions defaulting to compass directions.

    `observation_keys` is a mandatory keyword, and an explicitly given
    `actions` takes precedence over the default. Everything else is
    passed to the base class untouched.
    """

    def __init__(self, *args, observation_keys, **kwargs):
        # fall back to the movement-only action set
        if "actions" not in kwargs:
            kwargs["actions"] = tuple(nethack.CompassDirection)

        super().__init__(*args, observation_keys=observation_keys, **kwargs)

# make the movement-only task available under its own env id
register(
    id="NetHackChallengeMovingOnly-v0",
    entry_point=WalkingNetHack,
)

# the env factory with the id fixed to the new task
# NOTE(review): `factory` is not defined in the visible cells -- confirm
#  that it accepts the `id` keyword.
from functools import partial
fac = partial(factory, id='NetHackChallengeMovingOnly-v0')



<br>

In [None]:
f_max_grad = 1.
epx = EpisodeExtractor()

In [None]:
def learner(input, hx=None):
    """Run the notebook-level `agent` on the fields of `input` and `hx`."""
    fields = input._asdict()
    return agent(hx=hx, **fields)

Block grads through the runtime state `hx`

In [None]:
def truncate(hx, h0, *, alpha):
    """Detach the runtime state `hx` and optionally blend it with `h0`.

    Returns None if `hx` is None. Otherwise every tensor in `hx` is
    detached, and, if `h0` is given and `alpha` is positive, mixed with
    `h0` as `(1 - alpha) * hx + alpha * h0`.
    """
    if hx is None:
        return None

    # (sys) truncated bptt: no grads flow into the previous fragment
    detached = plyr.apply(torch.Tensor.detach, hx)
    if h0 is None or not alpha > 0:
        return detached

    # (sys) `h0` still gets some grad feedback from the fragment
    #   `.lerp: hx <<-- (1 - w) * hx + w * h0` (broadcasts correctly).
    return plyr.apply(torch.lerp, detached, h0, weight=alpha)

import operator as op

def reduce(values, weight=None):
    """Sum the leaves of the nested `values`, optionally weighted.

    If `weight` is given, then `values` is first multiplied by it through
    `plyr.apply(op.mul, ...)`. Returns `0` if there are no leaves.
    """
    weighted = values
    if weight is not None:
        weighted = plyr.apply(op.mul, values, weight)

    # gather the leaves into a flat list, and add them up
    leaves = []
    plyr.apply(leaves.append, weighted)
    return sum(leaves)

def timeshift(state, *, shift=1):
    """Split the nested `state` into two views offset by `shift` steps.

    Returns the pair of `state[:-shift]` and `state[shift:]`, with
    the slices taken leaf-by-leaf.
    """
    # `xgetitem` lets None leaves through
    def _view(index):
        return plyr.apply(plyr.xgetitem, state, index=index)

    # XXX `head[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    head = _view(slice(None, -shift))

    # XXX `tail[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    tail = _view(slice(+shift, None))

    return head, tail

In [None]:
from torch.nn.utils import clip_grad_norm_

def update(input, output, hx=None):
    """Do one step of GAE-A2C learning w. intrinsic motivation.

    Runs the agent over the recorded fragment `input` starting from `hx`,
    takes one optimizer step on the weighted sum of the policy-gradient,
    entropy and critic terms, and recomputes the recurrent state for
    the next fragment with the updated agent. `output` is not used.

    Relies on the notebook-level `agent`, `epx` and the hyperparameters
    `f_h0_lerp`, `f_gamma`, `f_lambda`, `f_alpha`, `f_max_grad`, `C_pg`,
    `C_entropy`, and `C_critic`.

    Returns the pair `(hx, log)` of the new recurrent state and a dict
    of diagnostics.

    NOTE(review): the `capture` wrapper in this notebook merges the FIRST
    item of the wrapped function's result into its target dict, and this
    function returns `hx` first -- confirm the intended order before
    wrapping `update` with `capture`.
    """
    # (t-bptt) truncate grads thru the initial recurrent state
    hx = truncate(hx, agent.initial_hx, alpha=f_h0_lerp)

    # (sys) do a diff-able forward pass thru the agent to get
    #  (v_t, \pi_t), t=0..N, starting with h_0 over the recorded
    #  fragment (x_t, a_{t-1}, r_t, f_t), t=0..N, but ignore,
    #  the updated `hx` (h_{N+1})
    _, (val, pol), _ = learner(input, hx=hx)

    # (gae) compute GAE and returns for all rewards
    # XXX `rew`, `fin`, `val` must be $r_t$, $d_t$ and $v(s_t)$!
    gae, ret = pg_targets(input.rew, val, f_gamma, f_lambda, fin=input.fin)
    
    # (sys) policy grad surrogate (uses common gae!), entropy of
    #  the policy, and the intrinsic/extrinsic critic loss.
    # XXX r_{t+1}, v_t, v{t+1} -->> A_t \log \pi_t(a_t)
    # XXX the advantage is the extrinsic GAE plus `f_alpha` times
    #  the intrinsic one, with the grads blocked
    adv = reduce(gae, {'ext': 1., 'int': f_alpha}).detach()
    terms = {
        'pg': plyr.apply(pyt_polgrad, pol, input.act, adv=adv),
        'entropy': plyr.apply(pyt_entropy, pol),
        'critic': plyr.apply(pyt_critic, val, ret),        
    }

    # (sys) compute the loss
    # XXX the critic coefficients are halved for both reward streams
    loss = reduce(terms, {
        'pg': C_pg,
        'entropy': C_entropy,
        'critic': {
            'ext': C_critic['ext'] / 2,
            'int': C_critic['int'] / 2,
        }
    })

    # (sys) backprop through the agent
    # XXX `grad` is the value returned by `clip_grad_norm_`
    agent.optim.zero_grad()
    loss.backward()
    grad = clip_grad_norm_(agent.parameters(), f_max_grad)
    agent.optim.step()

    # (sys) get the primary slice of the fragment (the last one overlaps)
    #  and extract episode strands
    vw_main = plyr.apply(plyr.xgetitem, input, index=slice(None, -1))
    episodes = epx.extract(vw_main.fin, vw_main)

    # (sys) recompute the runtime state for the next fragment
    with torch.no_grad():
        _, _, hx = learner(vw_main, hx=hx)
        # (sys) compute the new starting `hx`: passed thru the
        # __updated__ agent and properly mixed with the new `h0`.
        hx = truncate(hx, agent.initial_hx, alpha=f_h0_lerp)

    # (sys) collect the data for logging
    # XXX the entropy is divided by the number of records in the main slice
    terms = plyr.apply(float, terms)
    terms['entropy'] /= vw_main.fin.numel()
    terms.update({
        'loss': float(loss),
        'grad': float(grad),
    })

    # XXX this local `log` shadows the notebook-level dict of the same name
    log = {'loss': terms}
    if episodes:
        # the per-metric median over the extracted episodes
        log['metrics'] = plyr.apply(
            np.median, *map(ep_stats, episodes), _star=False)

    return hx, log

In [None]:
from nle_toolbox.utils.rl.engine import Input, prepare
from nle_toolbox.utils.rl.capsule import collect, launch

# A short smoke run of the capsule-driven loop: the env is stepped for
#  two fragments' worth of steps per env.
# the dict that `capture` fills whenever `update` is called
log = {}

# NOTE(review): `collect` imported from `nle_toolbox.utils.rl.capsule` in
#  this cell rebinds the name of the rollout generator defined earlier in
#  the notebook -- confirm which one each cell expects.
caps = collect(
    learner,
    n_fragment_length,
    update=capture(update, to=log),
)

# the local runtime state for the main exec flow
# env is reset, and act is sampled from the env
npy, pyt = prepare(env, rew=0., fin=True)

# recv the first action of the encapsulated agent
# XXX the initial reward buffer stands in for both reward streams
act = launch(caps, Input(
    npy.obs, npy.act, {'ext': npy.rew, 'int': npy.rew}, npy.fin
))

n_steps = 0
while n_steps < 2*n_fragment_length * len(env):
    # step thru the env and write x_{t+1}, a_t, r_{t+1}, d_{t+1}
    obs, rew, fin, nfo = env.step(act)
    plyr.apply(np.copyto, npy, Input(obs, act, rew, fin))
    n_steps += len(env)

    # compute intrinsic motivation (self-supervised rewards)
    # XXX a placeholder: the intrinsic reward is gaussian noise here
    rew = {'ext': rew, 'int': torch.randn(rew.shape)}

    # decide which capsule to route the data to
    act = caps.send((obs, rew, fin, nfo))

    # flush the diagnostics, if any have been captured
    if log:
        wandb.log({'n_steps': n_steps, **log}, commit=True)
        log.clear()

<br>

In [None]:
# Inspect the norms of the core's input-to-hidden weights, which are
#  applied to the inputs past the first 128 columns.
with torch.no_grad():
    # `(W_ii|W_if|W_ig|W_io)`
    blsw = agent.core.weight_ih_l0[:, 128:]
    # group the remaining columns into consecutive chunks of five
    # NOTE(review): presumably each chunk is the 5-dim embedding of one
    #  stat -- confirm against the agent's recipe.
    blsw = blsw.reshape(len(blsw), -1, 5)
    # the l2 norm of every chunk, per row of the weight matrix
    norms = blsw.norm(p=2, dim=-1)

# name the chunks in the order they are assumed to be laid out
norms = dict(zip([
    'hunger', 'status', 'hp', 'mp',
    'str', 'dex', 'con', 'int', 'wis', 'cha',
    'strprc', 'AC', 'encumberance',
], norms.T))

# the boundaries of the four gate blocks along the rows
breaks = 0, 128, 256, 384, 512
c_pair = "C0", "C1"
fig, axes = plt.subplots(2,2, figsize=(7, 7,), dpi=300, sharey=True, sharex=True)
# XXX `zip` stops at the four available axes, so only the first four
#  named chunks are plotted
for nom, ax in zip(norms, axes.flat):
    ax.semilogy(norms[nom], label=nom)
    # shade the blocks in alternating colours; the pairs are `(hi, lo)`
    for j, (a, b) in enumerate(zip(breaks[1:], breaks)):
        ax.axvspan(a, b, color=c_pair[j&1], alpha=0.05, zorder=-10)

    ax.legend(fontsize='xx-small')

<br>

In [None]:
# plyr.ragged(lambda v, C: C * v.sum(), pg, C_pg)

In [None]:
# Scratch: gather the coefficient-weighted loss terms into a flat list.
# NOTE(review): `pg_gae`, `entropy` and `critic` are not defined in the
#  visible cells -- confirm before running on a fresh kernel.
flat = []
plyr.ragged(lambda v, C: flat.append(C * v), pg_gae, C_pg)
plyr.ragged(lambda v, C: flat.append(C * v), entropy, C_entropy)
plyr.ragged(lambda v, C: flat.append(C * v), critic, C_critic)

In [None]:
flat

Remove currently unused fields from the observations

In [None]:
from nle_toolbox.utils.env.defs import MAX_GLYPH

def filter(
    glyphs,
    blstats,
    inv_letters,
    inv_glyphs,
    **ignore,
):
    """Keep only the four observation fields that are in use.

    Any other keyword argument is accepted and dropped.
    XXX this shadows the builtin `filter` in the cells that follow.
    """
    return {
        'glyphs': glyphs,
        'blstats': blstats,
        'inv_letters': inv_letters,
        'inv_glyphs': inv_glyphs,
    }

<br>

     y  k  u  
      \ | /   
    h - . - l 
      / | \   
     b  j  n  

Spell tracker

In [None]:
def spellcaster(obs, mask, *, dir='.', letter='a', ctoa):
    """Yield the actions for the keys `Z`, the spell `letter`, and `dir`.

    Parameters
    ----------
    obs, mask
        Not used: kept for the signature shared by the policies.
    dir : str
        The key(s) of the direction to cast in.
    letter : str
        The key of the spell to cast. This used to be read from
        an undefined global, which raised a NameError on the first `next`.
    ctoa : mapping
        Maps a character to an action (None for unknown characters).

    The keys are yielded one-by-one in a plain loop, rather than through
    `yield from map(...)`, since the latter raises an AttributeError as
    soon as the generator is driven with `.send(...)`, which is how the
    other policies in this notebook receive `(obs, mask)`.
    """
    for key in f'Z{letter}{dir}':
        # whatever is sent in response is deliberately ignored
        yield ctoa.get(key)

Random policy

In [None]:
def linger(obs, mask, n=16, *, seed=None, ctoa=None):
    """Yield up to `n` random actions, that are not flagged in `mask`.

    The generator expects a fresh `(obs, mask)` pair to be sent in after
    every action, and stops early, once every action is masked. `seed`
    seeds the generator's own rng, and `ctoa` is not used.
    """
    rng = np.random.default_rng(seed)

    n_taken = 0
    while True:
        if mask.all() or not n_taken < n:
            return

        # pick uniformly among the non-forbidden actions
        # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
        allowed = np.logical_not(mask).nonzero()
        act = rng.choice(*allowed)

        obs, mask = (yield act)
        n_taken += 1

def search(obs, mask, n=6, *, ctoa):
    """Yield the actions for the digits of the count `n`, and then `s`.

    Parameters
    ----------
    obs, mask
        Not used: kept for the signature shared by the policies.
    n : int
        The repeat count, which is typed digit-by-digit before `s`.
    ctoa : mapping
        Maps a character to an action (None for unknown characters).

    The keys are yielded one-by-one in a plain loop, rather than through
    `yield from map(...)`, since the latter raises an AttributeError as
    soon as the generator is driven with `.send(...)`, which is how the
    other policies in this notebook receive `(obs, mask)`.
    """
    for key in f'{n:d}s':
        # whatever is sent in response is deliberately ignored
        yield ctoa.get(key)

Level and dungeon mapper

In [None]:
from nle.nethack import (
    NLE_BL_X,
    NLE_BL_Y,
    NLE_BL_DNUM,
    NLE_BL_DLEVEL,
    # NLE_BL_DEPTH,  # derived from DNUM and DLEVEL
    # XXX does not uniquely identify floors,
    #  c.f. [`depth`](./nle/src/dungeon.c#L1086-1084)
    DUNGEON_SHAPE,
    MAX_GLYPH,
)

from nle_toolbox.utils.env.defs import \
    glyph_is, dt_glyph_ext, ext_glyphlut

from nle_toolbox.bot.level import Level, DungeonMapper

Determine the walkability of the observed tiles

In [None]:
from nle_toolbox.utils.env.defs import symbol, GLYPH_CMAP_OFF, glyph_group, get_group
from nle_toolbox.utils.env.defs import glyphlut, ext_glyphlut

# Boolean lookup tables, meant to be indexed by a glyph array.
# NOTE(review): presumably `get_group` resolves the named symbols into
#  glyph ids offset by `GLYPH_CMAP_OFF` -- confirm against
#  `nle_toolbox.utils.env.defs`.

# closed doors and raised drawbridges
closed_doors = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_vcdoor', 'S_hcdoor',
    'S_vcdbridge', 'S_hcdbridge',
])

# doorways, open doors, and lowered drawbridges
open_doors = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_ndoor',
    'S_vodoor', 'S_hodoor',
    'S_vodbridge', 'S_hodbridge',
])

# membership masks over the entries of `ext_glyphlut`
is_closed_door = np.isin(ext_glyphlut.id.value, np.array(list(closed_doors)))
is_actor = np.isin(ext_glyphlut.id.group, np.array(list(glyph_group.ACTORS)))
is_pet = ext_glyphlut.id.group == glyph_group.PET

is_open_door = np.isin(ext_glyphlut.id.value, np.array(list(open_doors)))
is_object = np.isin(ext_glyphlut.id.group, np.asarray(list(glyph_group.OBJECTS)))
# a tile is walkable if it is accessible, an open door, or has an object
is_walkable = ext_glyphlut.is_accessible | is_open_door | is_object

In [None]:
# The symbols that the explorer treats as traps: it makes them costly to
#  cross and never picks them as a destination.
traps = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_arrow_trap',
    'S_dart_trap',
    'S_falling_rock_trap',
    'S_squeaky_board',
    'S_bear_trap',
    'S_land_mine',
    'S_rolling_boulder_trap',
    'S_sleeping_gas_trap',
    'S_rust_trap',
    'S_fire_trap',
    'S_pit',
    'S_spiked_pit',
    'S_hole',
    'S_trap_door',
    'S_teleportation_trap',
    'S_level_teleporter',
    'S_magic_portal',
    'S_web',
    'S_statue_trap',
    'S_magic_trap',
    'S_anti_magic_trap',
    'S_polymorph_trap',
    'S_vibrating_square',
])

# the membership mask over the entries of `ext_glyphlut`
is_trap = np.isin(ext_glyphlut.id.value, np.array(list(traps)))

The core of the "smart" dungeon explorer

In [None]:
from scipy.special import softmax

def crawler(obs, mask, *, dir, seed=None):
    """A two-state explorer policy: random lingering and path following.

    The generator yields an action, and expects the next `(obs, mask)`
    pair to be sent in. It starts in the `linger` state, in which it takes
    sixteen random non-forbidden actions, then draws a destination on
    the current level, plans the shortest path to it, and follows the plan
    in the `crawl` state, before going back to lingering.

    `dir` maps the ascii direction keys to actions, and `seed` seeds
    the generator's own rng. Relies on the notebook-level `dij`, `backup`,
    `dir_to_ascii`, and the glyph lookup tables.
    """
    dng = DungeonMapper()

    # own random number generator
    rng = np.random.default_rng(seed)

    # a simple state machine: linger <<-->> crawler
    # XXX `stack` holds the planned moves with the next one on top
    state, n_linger, stack = 'linger', 16, []
    while True:
        dng.update(obs)
        pos = dng.level.trace[-1]

        if state == 'crawl':
            if stack:
                # XXX `plan` is popped in lock-step with `stack`, but its
                #  values are not read here
                plan.pop()
                act = dir[stack.pop()]

            else:
                # the plan is exhausted: linger again, without yielding
                state, n_linger = 'linger', 16
                continue

        elif state == 'linger':
            if n_linger > 0:
                n_linger -= 1

                # if we're in LINGER state, pick a random non-forbidden action
                # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
                act = rng.choice(*np.logical_not(mask).nonzero())

            else:
                lvl = dng.level

                # we've run out linger moves, time to pick a random destination
                # and go to it
                state = 'crawl'

                # get the walkability cost
                cost = np.where(
                    # is_walkable[lvl.bg_tiles.glyph]
                    (is_walkable | is_pet)[lvl.bg_tiles.glyph]
                    , .334, np.inf)
                # XXX adjust `cost` for hard-to-pass objects?
                cost[is_trap[lvl.bg_tiles.glyph]] = 10.

                # get the shortest paths from the current position
                value, path = dij(cost, pos)

                # draw a destination, the further the better
                # XXX closed doors get a fixed logit of `100.`, reachable
                #  non-trap tiles get their path value, and the rest are
                #  excluded
                prob = softmax(np.where(
                    is_closed_door[lvl.bg_tiles.glyph],
                    100.,
                    np.where(
                        np.logical_and(
                            np.isfinite(value),
                            np.logical_not(
                                is_trap[lvl.bg_tiles.glyph]
                            )
                        ), value, -np.inf
                    ))
                )
                # turn the flat index into the row-column pair
                dest = divmod(rng.choice(prob.size, p=prob.flat), prob.shape[1])

                # reconstruct the path to the destination in reverse order
                plan = list(backup(path, dest))
                for (r1, c1), (r0, c0) in zip(plan, plan[1:]):
                    stack.append(dir_to_ascii[r1-r0, c1-c0])

                # drop the last node of the plan, and re-enter the loop
                #  in the `crawl` state without yielding
                plan.pop()
                continue

        obs, mask = yield act

How do we want to explore?
* open closed doors
* explore tunnels

Implementing the random dungeon crawler

In [None]:
dng = getgeneratorlocals(gen).get('dng')
# dng.level.trace[-1]

In [None]:
plt.imshow(dng.level.bg_tiles.info.is_accessible)

     y  k  u  
      \ | /   
    h - . - l 
      / | \   
     b  j  n  

In [None]:
# Scratch: the cost / value / destination pipeline of the crawler, done
#  directly on the glyphs of the last observation.
# get the walkability cost
cost = np.where((
    is_walkable
    | is_pet
)[obs['glyphs']], 1., np.inf)
# XXX adjust `cost` for hard-to-pass objects?
cost[is_trap[obs['glyphs']]] = 10.

# get shortest paths from the current position
bls = obs['blstats']
value, path = dij(cost, (bls[NLE_BL_Y], bls[NLE_BL_X]))

# the destination distribution: reachable non-trap tiles only,
#  weighted by the softmax of their path value
prob = softmax(np.where(
    np.logical_and(
        np.isfinite(value),
        np.logical_not(
            is_trap[obs['glyphs']]
        )
    ), value, -np.inf
))

plt.imshow(value)

<hr>

Test the algo

In [None]:
# the source tile of the test
r, c = 12, 12

rng = np.random.default_rng()  #248675)

# a random cost map: exponential costs, and about a half of
#  the tiles are made impassable
cost = -np.log(rng.random((21, 79)))
# cost = np.ones((21, 79))
cost[rng.random(cost.shape) < .5] = np.inf

value, path = dij(cost, (r, c))


# mask = is_walkable[lvl.bg_tiles.glyph] | is_walkable[lvl.stg_tiles.glyph]
# the candidate destinations are the reachable tiles
mask = np.isfinite(value)
mask[r, c] = False  # mask the current position

from scipy.special import softmax

# uniform over the candidates with the path value above five
# XXX `value` is overwritten with the logits here
value = np.where(value > 5, 0., -np.inf)
prob = softmax(np.where(mask, value, -np.inf))

Play around with the shortest path.

In [None]:
# draw a destination: a flat index turned into a row-column pair
r, c = divmod(rng.choice(prob.size, p=prob.flat, ), prob.shape[1])

# paint the tiles of the path over a copy of the cost map
displ = cost.copy()
plan = list(backup(path, (r, c)))
for ij in plan:
    displ[ij] = 10
# highlight the source tile of the test
displ[12, 12] = 11


fig, ax = plt.subplots(1, 1, dpi=300)
ax.imshow(displ)

In [None]:
# translate the consecutive steps of the plan into ascii direction keys
# XXX `plan` runs from the destination back to the source, hence
#  the commands are reversed before being joined
commands = []
for (r1, c1), (r0, c0) in zip(plan, plan[1:]):
    commands.append(dir_to_ascii[r1-r0, c1-c0])

''.join(reversed(commands))

<br>

In [None]:
assert False

A non-illegal random action exploration.

In [None]:
from copy import deepcopy
from nle_toolbox.bot.chassis import get_wrapper


def random_explore(seed=None, n_steps=1000, *, auto=False, fps=None, copy=False):
    """A non-illegal random action explorer.

    A generator, which plays random non-forbidden actions in a fresh env
    and yields the observation before each step.

    Parameters
    ----------
    seed
        The entropy for `np.random.SeedSequence`, from which separate
        seeds for the policy's rng and for the env are spawned.
    n_steps : int
        The maximal number of steps to take.
    auto : bool
        Whether to reset the env and keep going when an episode ends.
    fps
        Passed to `ipynb_render` as is.
    copy : bool
        Whether to yield a deep copy of the observation.

    Relies on the notebook-level `factory`, `Chassis`, `ActionMasker`
    and `ipynb_render`.
    """
    ss_pol, ss_env = np.random.SeedSequence(seed).spawn(2)

    # XXX `n_linger` and `pf` are not used in the body
    rng, j, n_linger, pf = np.random.default_rng(ss_pol), 0, 0, None
    with factory(seed=ss_env) as env:
        # we need access to the Chassis for additional meta state variables
        cha = get_wrapper(env, Chassis)

        # ActionMasker caches the escape action id
        ESC = get_wrapper(env, ActionMasker).escape

        # setup the dungeon mapper
        dng = DungeonMapper()

        # launch the episode
        (obs, mask), fin = env.reset(), False
        # XXX the observation is rendered first, and the loop also stops
        #  if `ipynb_render` returns a falsy value
        while (
            ipynb_render(obs, clear=True, fps=fps)
            and not (fin or j >= n_steps)
        ):
            # though nle reuses buffers, we do not deep copy them
            #  delegating this to the downstream user instead
            yield deepcopy(obs) if copy else obs

            # default to immediately escaping from any menu or prompt
            act = ESC
            if not (cha.in_menu or cha.prompt):
                dng.update(obs)

                # if we're in LINGER state, pick a random non-forbidden action
                # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
                act = rng.choice(*np.logical_not(mask).nonzero())

            (obs, mask), rew, fin, info = env.step(act)
            j += 1

            # render the terminal observation and start a new episode
            if fin and auto:
                ipynb_render(obs, clear=True, fps=fps)
                (obs, mask), fin = env.reset(), False

Get a random episode

In [None]:
# from inspect import getgeneratorlocals
# Play one random episode, and keep the glyphs of every observation.
episode = random_explore(
    seed=None,
    n_steps=256,
    auto=False,
    copy=True,
    fps=0.01,
)


# advance the generator by one step first, so that its locals exist by
#  the time they are inspected (see the commented-out line below)
# XXX the first element used to be the WHOLE observation, unlike the rest
#  of the list, which holds the `glyphs` arrays only
glyphs = [next(episode)['glyphs']]
# dng = getgeneratorlocals(episode).get('dng')

glyphs.extend(obs['glyphs'] for obs in episode)

In [None]:
assert False

In [None]:
from scipy.special import softmax

def dstination_prob(lvl, pos):
    """Get the destination distribution over the tiles of `lvl`.

    The logit of a tile is the max of its absolute row and column offsets
    from `pos`, capped at five. Tiles that are walkable neither by their
    `bg_tiles` nor by their `stg_tiles` glyph, and `pos` itself, get
    zero probability.
    """
    row, col = pos
    tiles = lvl.bg_tiles

    # the chessboard distance from `pos` to every tile
    d_row = abs(tiles.rc.r - row)
    d_col = abs(tiles.rc.c - col)
    distance = np.maximum(d_row, d_col)

    candidates = is_walkable[tiles.glyph] | is_walkable[lvl.stg_tiles.glyph]
    candidates[row, col] = False  # mask the current position

    logits = np.where(candidates, distance, -np.inf)
    return softmax(np.minimum(logits, 5))

rng = np.random.default_rng()
prob = dstination_prob(dng.level, dng.level.trace[-1])
cost = np.where(prob > 0, 1., float('inf'))

plt.imshow(prob)

In [None]:
plt.imshow(cost)

In [None]:
def backup(path, dest):
    """Yield the nodes from `dest` back to the root of the `path` tree.

    `path[node]` is the predecessor of `node`, or None for the root.
    The first yielded node is `dest` itself, and the last one is the root.
    """
    node = dest
    while True:
        # the predecessor is looked up before the node is handed out
        parent = path[node]
        yield node

        if parent is None:
            return

        node = parent

#         (r0, c0), (r1, c1) = p0, p1
#         yield directions[r1-r0, c1-c0]
        

In [None]:
value, path = dij(cost, dng.level.trace[-1])

In [None]:
# draw a destination tile according to `prob`, and show the path to it
val = value.copy()
r, c = rng.choice(dng.level.bg_tiles.rc.flat, p=prob.flat)

fig, ax = plt.subplots(1, 1, dpi=300)
# zero out the tiles of the path in the copy of the value map
for i, j in backup(path, (r, c)):
    val[i, j] = 0.

# mark the destination itself
val[r, c] = np.inf

# show only the columns from 10 up to 40
ax.imshow(val[:, 10:40])

<br>

<br>

In [None]:
from einops import repeat, rearrange
from transformers import ViTModel, ViTConfig

from transformers.models.vit.modeling_vit import to_2tuple


class NLEViTEmbeddings(nn.Module):
    """A drop-in embedding layer for ViT fed with pre-embedded windows.

    The input is a `B x D x H x W` tensor, which is flattened into
    `H * W` tokens of dim `D`, prepended with a learnt cls-token, and
    offset by learnt positional embeddings. The window has
    `2 * config.window + 1` rows and columns, and both parameters
    are zero-initialized.
    """

    def __init__(self, config):
        super().__init__()
        n_rows, n_cols = to_2tuple(2 * config.window + 1)
        n_tokens = 1 + n_rows * n_cols

        self.cls = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.posemb = nn.Parameter(torch.zeros(n_tokens, config.hidden_size))

    def forward(self, input, **ignore):
        """Return the `B x (1 + H W) x D` sequence of token embeddings."""
        tokens = rearrange(input, 'B D H W -> B (H W) D')

        # cls-token and positional embedding
        head = repeat(
            self.cls.unsqueeze(0), '() N D -> B N D', B=len(tokens))
        sequence = torch.cat((head, tokens), dim=1)
        return sequence + self.posemb


# A single-layer ViT over 128-dim tokens.
# XXX `window` is an extra config entry, which is read by
#  `NLEViTEmbeddings`: `window=2` gives a 5x5 grid, i.e. 1 + 25 tokens
model = ViTModel(ViTConfig(
    hidden_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    intermediate_size=512,
    window=2,
    # image_size,
    # patch_size,  # are ignored
))

# swap the stock patch embeddings for the pre-embedded window ones
model.embeddings = NLEViTEmbeddings(model.config)


# obs = input.obs
# gly = agent.features.obs(obs)

# win = gly['vicinity']

# size = dict(zip("TBCHW", win.shape[:3]))
# x = rearrange(win, 'T B C H W -> (T B) C H W')

# out = rearrange(model(x).pooler_output, '(T B) C -> T B C', **size)

In [None]:
model

<br>