# Let's try a non-hierarchical RL agent

In [None]:
import time
import gym
import nle

import matplotlib.pyplot as plt

In [None]:
# Remove the `__getattr__` hook from `gym.Wrapper` for the whole session.
# NOTE(review): presumably this stops wrappers from transparently forwarding
#  unknown attribute look-ups to the wrapped env -- confirm against
#  the installed gym version.
del gym.Wrapper.__getattr__

Import other useful modules

In [None]:
import plyr
import torch
import numpy as np

from torch import nn
from torch.nn import functional as F

We hide the NLE under several layers of wrappers. From the core to the shell:
1. `ReplayToFile` handles seeding and logs the taken actions and seed into a file for later inspection and replay.

2. `NLEPatches` patches tty-screens, botched by the cr-lf misconfiguration of the NLE's tty terminal emulator and NetHack's displays (lf only).

3. `Chassis` handles skippable gui events that do not require a decision, such as collecting menu pages unless an interaction is required, fetching consecutive topline or log messages.

4. `ActionMasker` computes the mask of actions that are **forbidden** in the current game state (_gui_ or _play_)

5. `RecentHistory` keeps a brief log of actions taken in the environment (partially duplicates the functionality of the `Replay` wrapper).

6. `NLEAtoN` maps ascii actions to opaque actions accepted by the NLE.

7. (**unused**) `NLEFeatures` adds extra features generated on-the-fly from the current NLE's observation.
  * see `NLEFeaturesVicinity` below

In [None]:
from nle_toolbox.wrappers.replay import ReplayToFile, Replay

from nle_toolbox.wrappers.features import NLEPatches, NLEAtoN
from nle_toolbox.bot.chassis import Chassis, ActionMasker

A temporary wrapper that bails out of any menu or prompt

In [None]:
from nle_toolbox.bot.chassis import InteractiveWrapper
from nle_toolbox.bot.chassis import get_wrapper

class AutoEscape(InteractiveWrapper):
    """Immediately escape from any menu or prompt reported by the `Chassis`.

    Parameters
    ----------
    env : gym.Env
        The environment to wrap. A `Chassis` wrapper is looked up in it
        with `get_wrapper`.

    escape : object, default='\\033'
        The action sent to the wrapped env to dismiss a menu or a prompt.
    """
    def __init__(self, env, escape='\033'):
        super().__init__(env)
        self.chassis = get_wrapper(env, Chassis)
        self.escape = escape

    def update(self, obs, rew=0., done=False, info=None):
        """Keep sending `escape` while the chassis is in a menu or a prompt.

        Returns the most recent `(obs, rew, done, info)` transition, i.e.
        the arguments themselves if no escaping was necessary.
        """
        # default to immediately escaping from any menu or prompt, but never
        #  step through an env that has already reported the end of the episode
        #  (otherwise the loop could spin forever on a finished game).
        # XXX the rewards of the intermediate escape steps are not accumulated
        while not done and (self.chassis.in_menu or self.chassis.prompt):
            obs, rew, done, info = self.env.step(self.escape)

        # update must always return the most recent relevant transition data
        return obs, rew, done, info

A wrapper that keeps track of the action history

In [None]:
from collections import deque

class RecentHistory(gym.Wrapper):
    """Keep a bounded log of the most recent actions sent to the env.

    The wrapper is a middleman: it records every action passed to `.step`
    in the `recent` deque (at most `n_recent` entries, oldest dropped first)
    and forwards it unchanged to the wrapped env.

    Parameters
    ----------
    env : gym.Env
        The environment to wrap.

    n_recent : int, default=0
        The capacity of the log. If it is less than one, then no wrapper is
        created and `env` itself is returned by the constructor.

    map : callable, optional
        Applied to each action before it is logged (the env always gets
        the original action). Non-callable values mean `identity`.
    """
    def __new__(cls, env, *, n_recent=0, map=None):
        # nothing to keep track of: hand back the env itself. Since it is not
        #  an instance of `cls`, python does not call `__init__` on it.
        if n_recent < 1:
            return env
        return object.__new__(cls)

    def __init__(self, env, *, n_recent=0, map=None):
        super().__init__(env)
        self.recent = deque([], n_recent)
        self.map = map if callable(map) else lambda x: x

    def reset(self, seed=None):
        """Reset the wrapped env and start a fresh action log.

        NOTE(review): `seed` is accepted but NOT forwarded to the wrapped
        env, same as before.
        """
        # actions of the previous episode must not leak into the new one
        self.recent.clear()
        return self.env.reset()

    def step(self, action):
        """Log the (mapped) action and pass the original one downstream."""
        self.recent.append(self.map(action))
        return self.env.step(action)

A wrapper that keeps the specified observation keys.

In [None]:
from nle_toolbox.wrappers.features import ObservationWrapper

class ObservationKeyFilter(ObservationWrapper):
    """Keep only the listed fields of the observation dict.

    The declared `observation_space` is filtered in the same way. The kept
    values are passed through as is (no copy is made).
    """
    def __init__(self, env, *keys):
        super().__init__(env)
        self.keys = frozenset(keys)

        # the space is a mapping of sub-spaces, so the very same filter applies
        kept_spaces = self.observation(self.observation_space)
        self.observation_space = gym.spaces.Dict(kept_spaces)

    def observation(self, observation):
        """Return a new dict with the whitelisted items of `observation`."""
        selected = {}
        for name, value in observation.items():
            if name in self.keys:
                selected[name] = value

        return selected

A wrapper, which pre-extracts the field-of-view around the agent
* `NLEFeatures` is a little bit outdated, but hopefully, if the wrapper below helps,
then it will be updated and merged

In [None]:
from nle_toolbox.wrappers.features import ObservationWrapper

from nle_toolbox.utils.fold import npy_fold2d
from nle.nethack import (
    MAX_GLYPH,
    NLE_BL_X,
    NLE_BL_Y,
    DUNGEON_SHAPE,
)

class NLEFeaturesVicinity(ObservationWrapper):
    """Add the `vicinity` field: a window of glyphs taken at the position
    read from `blstats`.

    Parameters
    ----------
    env : gym.Env
        The environment to wrap. Its observation space must have `glyphs`.

    k : int, default=3
        The width of the `MAX_GLYPH` border added on each side of the map,
        also passed to `npy_fold2d` as its `k`.
        NOTE(review): presumably the window is (2k+1) x (2k+1) and centred
        at the agent -- confirm against `npy_fold2d`.
    """
    def __init__(self, env, *, k=3):
        super().__init__(env)

        n_rows, n_cols = DUNGEON_SHAPE
        glyph_dtype = self.observation_space['glyphs'].dtype

        # the glyph map with a border of `MAX_GLYPH`, `k` cells wide
        padded = np.full(
            (k + n_rows + k, k + n_cols + k), MAX_GLYPH, dtype=glyph_dtype,
        )
        self.glyphs = padded

        # a view onto the interior (the actual map) for fast in-place updates
        self.vw_glyphs = padded[k:-k, k:-k]

        # a view of windows, indexed by the (row, col) of the original map
        # XXX pytorch does not like read-only views, hence `writeable`
        self.vw_vicinity = npy_fold2d(
            padded, k=k, n_leading=0, writeable=True,
        )

        # declare the extra field in the observation space
        window_shape = self.vw_vicinity.shape[2:]
        self.observation_space['vicinity'] = gym.spaces.Box(
            0, MAX_GLYPH,
            dtype=self.vw_vicinity.dtype,
            shape=window_shape,
        )

    def observation(self, observation):
        """Refresh the padded map and attach a copy of the current window.

        The `observation` dict is updated in-place and returned.
        """
        # overwrite the interior of the bordered array with the new map
        np.copyto(self.vw_glyphs, observation['glyphs'], 'same_kind')

        botl = observation['blstats']
        row, col = botl[NLE_BL_Y], botl[NLE_BL_X]

        # make sure to produce a copy of the window, since the view
        #  is overwritten on the next call
        window = self.vw_vicinity[row, col].copy()
        observation.update(vicinity=window)
        return observation

The strength stat in AD&D 2ed, upon which the mechanics of NetHack are based,
comes as two ints: strength and percentage.

In [None]:
from nle.nethack import (
    NLE_BL_STR25,
    NLE_BL_STR125,
)

class NLEFeaturesStrengthPatch(ObservationWrapper):
    """Re-encode the strength fields of `blstats`.

    The raw value at `NLE_BL_STR125` is split into a whole part, written
    to `NLE_BL_STR25`, and a percentage part (times 100), written back
    to `NLE_BL_STR125`. A patched copy replaces the `blstats` array.
    """
    def observation(self, observation):
        botl = observation['blstats'].copy()

        # strength percentage is more detailed than `str` stat
        # XXX compare src/winrl.cc#L538 with src/attrib.c#L1072-1085
        #     e.g. src/dokick.c#L38 sums the transformed str with dex and con
        raw = botl[NLE_BL_STR125]
        if raw >= 122:
            # the top of the range: shift by 100 and cap at 25
            whole, frac = min(raw - 100, 25), 0.

        elif raw >= 19:
            # the fractional part of the division is the `percentage`
            whole, frac = divmod(19 + raw / 50, 1)  # divmod-by-one :)

        else:
            # ordinary strength: no percentage
            whole, frac = raw, 0.

        botl[NLE_BL_STR25] = int(whole)
        botl[NLE_BL_STR125] = int(frac * 100)  # original step .02, so ok

        # replace the original blstats array
        observation.update(blstats=botl)
        return observation

The factory for collecting random exploration rollouts

In [None]:
# from nle_toolbox.utils import seeding
import minihack
# env = gym.make("MiniHack-River-v0")

def factory(seed=None, folder=None, sticky=False):
    """Build the MiniHack environment wrapped for the agent.

    Parameters
    ----------
    seed : object, optional
        Passed to `.seed` of the replay wrapper.

    folder : str, optional
        If given, then `ReplayToFile` is used with this folder and
        `save_on='done'`, otherwise the plain `Replay` wrapper is used.

    sticky : bool, default=False
        Passed through to the replay wrapper.

    Returns
    -------
    env : gym.Env
        The env wrapped, from the core to the shell, in: replay,
        `NLEPatches`, `RecentHistory`, `Chassis`, `NLEFeaturesVicinity`,
        `NLEFeaturesStrengthPatch`, `ObservationKeyFilter`, `ActionMasker`.
    """
    # the fields requested from the NLE
    # XXX an alternative env is 'NetHackChallenge-v0'
    requested = (
        'glyphs',
        'chars',
        'colors',
        'specials',
        'blstats',
        'message',
        'inv_glyphs',
        'inv_strs',
        'inv_letters',
        'inv_oclasses',
        'tty_chars',
        'tty_colors',
        'tty_cursor',
        'misc',
        'screen_descriptions',
    )
    env = gym.make(
        'MiniHack-Room-Ultimate-15x15-v0', observation_keys=requested,
    )

    from nle.nethack import ACTIONS  # XXX not used below

    # char-to-action-index and action-index-to-char lookups
    raw_actions = env.unwrapped._actions
    ctoa = {chr(a): j for j, a in enumerate(raw_actions)}
    atoc = tuple(map(chr, raw_actions))

    # provide seeding capabilities and full action tracking
    if folder is not None:
        env = ReplayToFile(env, sticky=sticky, folder=folder, save_on='done')

    else:
        env = Replay(env, sticky=sticky)
    env.seed(seed)

    # patch bugged tty output
    env = NLEPatches(env)

    # log recent actions as chars
    # XXX atoc IS NOT a dict, but a tuple indexed by the action
    env = RecentHistory(env, n_recent=128, map=atoc.__getitem__)

    # skippable gui abstraction layer. `space` is `None` if the action
    #  space does not have the SPACE action.
    env = Chassis(env, space=ctoa.get(' '), split=False)

    # XXX `AutoEscape(env, escape=ctoa['\033'])` could be inserted here
    #  to auto-skip any menu or prompt

    # a feature extractor to potentially reduce
    #  the runtime complexity of the agent.
    env = NLEFeaturesVicinity(env, k=3)

    # properly handle str and str125 stats
    env = NLEFeaturesStrengthPatch(env)

    # filter unused observation keys
    # XXX this wrapper should be applied before any container-type
    #  modifications of the NLE's observation space.
    retained = (
        # the map, bottom line stats and inventory
        #  (dropped: chars, colors, specials, inv_strs, inv_letters,
        #   inv_oclasses)
        'glyphs',
        'blstats',
        'inv_glyphs',

        # used for in-notebook rendering
        'tty_chars',
        'tty_colors',
        'tty_cursor',

        # `message` and `misc` are used by the GUI abstraction
        #  layer (Chassis) upstream, and are dropped here

        # extra features produced by the upstream wrappers
        #  (`is_objpile` is not kept)
        'vicinity',
    )
    env = ObservationKeyFilter(env, *retained)

    # compute an action mask based on the current NLE mode: gui or play
    return ActionMasker(env)

A renderer for this **factory**

In [None]:
import pprint as pp
from time import sleep

from nle_toolbox.utils.env.render import render as tty_render
from IPython.display import clear_output

def ipynb_render(obs, clear=True, fps=None):
    """Print the tty screen of an `(observation, mask)` pair in the notebook.

    Parameters
    ----------
    obs : tuple
        A pair of the observation dict, passed to `tty_render` as kwargs,
        and the action mask, which is ignored.

    clear : bool, default=True
        Whether to clear the cell's output before printing.

    fps : float, optional
        Nothing is rendered if `None`. A positive value is the number
        of seconds to sleep after printing (despite the name).

    Returns
    -------
    True, always, so that the call can be chained in a loop condition.
    """
    # rendering is disabled
    if fps is None:
        return True

    if clear:
        clear_output(wait=True)

    # drop the action mask and draw the screen
    screen, _ = obs
    print(tty_render(**screen))

    if fps > 0:
        sleep(fps)

    return True

We start with implementing a simple command evaluator.

In [None]:
from collections import deque

def gui_run(env, *commands):
    """Reset the env and feed it the commands, one action at a time.

    Each command is an iterable of actions. After a command has been played
    out (or the episode has ended in the middle of it) the latest
    observation is yielded. No commands are started after the episode end.
    """
    pending = deque()
    obs, done = env.reset(), False
    for command in commands:
        # the episode is over: ignore the remaining commands
        if done:
            return

        # queue up the actions and play them until the end of the episode
        pending.extend(command)
        while pending and not done:
            obs, _, done, _ = env.step(pending.popleft())

        yield obs

Interesting historical seeds

In [None]:
# The seed passed to `factory` in the cells below: uncomment one of the pairs
#  to replay a particular game.
# NOTE(review): `None` presumably lets the env pick a seed itself -- confirm.
seed = None
# seed = 13765371332493407478, 12246923801353953927
# seed = 12301533412141513004, 11519511065143048485
# seed = 1632082041122464284, 11609152793318129379
# seed = 5009195464289726085, 12625175316870653325
# seed = 8962210393456991721, 8431607288866012881
# seed = 14729177660914946268, 9187177962698747861
# seed = 16892554419799916328, 6562518563582851317

# seed = 12513325507677477210, 18325590921330101247  # Ranger, arrows, dualwields
# seed = 1251332550767747710, 18325590921330101247  # Monk, martial arts, single
# seed = 125133255076774710, 18325590921330101247  # single
# seed = 12604736832047991440, 12469632217503715839  # Wizard, three spells, exploding wand
# seed = 14278027783296323177, 11038440290352864458  # valkyrie, dual-wield
# seed = 5009195464289726085, 12625175316870653325  # priestess, can loot lots of spells

The code below is used to debug certain events and gui

In [None]:
# Play a sequence of ascii commands and dump the state of the gui
#  abstraction layer (Chassis) after each one.
with NLEAtoN(factory(seed, sticky=True)) as env:
    from nle_toolbox.bot.chassis import get_wrapper
    cha = get_wrapper(env, Chassis)

    # each command is a string of keystrokes, for example
    #   ';j:'  -- a paragraph about a cat
    #   'acy'  -- break a wand "of slow" and blow up
    # the empty command only shows the initial screen
    for obs in gui_run(env, ''):
        # the collected messages, the prompt and the gui mode flags
        # XXX the topline is obs['tty_chars'][0].view('S80')[0].strip()
        pp.pprint((
            cha.messages,
            cha.prompt,
            cha.in_getlin,
            cha.in_menu,
            cha.in_yn_function,
            cha.xwaitingforspace,
        ))

        # XXX or dump(env.env, obs[0])
        ipynb_render(obs, clear=False, fps=0.01)

Random agent

In [None]:
def random(obs, n=float('inf'), *, seed=None):
    """A generator-based agent, which picks random non-forbidden actions.

    Parameters
    ----------
    obs : tuple
        The initial `(observation, mask)` pair. Only the mask is used: its
        nonzero entries mark the FORBIDDEN actions.

    n : int or float, default=inf
        The maximal number of actions to produce.

    seed : object, optional
        The seed for `numpy.random.default_rng`.

    Yields
    ------
    act
        The picked action. The next `(observation, mask)` pair is expected
        to be sent back into the generator. The generator stops when every
        action is forbidden or `n` actions have been produced.
    """
    _, mask = obs
    rng = np.random.default_rng(seed)

    n_taken = 0
    while not mask.all() and n_taken < n:
        # pick uniformly among the actions that are not forbidden
        # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
        allowed = np.logical_not(mask).nonzero()

        _, mask = yield rng.choice(*allowed)
        n_taken += 1

Do a limited step run

In [None]:
# Run the random agent in a freshly built env until the end of the episode,
#  or until the agent runs out of actions.
gen = None
try:
    with factory(seed=seed, sticky=True) as env:
        cha = get_wrapper(env, Chassis)
        msk = get_wrapper(env, ActionMasker)

        # reset the env and get the initial obs FIRST: the agent must react
        #  to the observation of this very env, and not to a stale `obs`
        #  left over from another cell.
        obs, fin = env.reset(), False

        # init the agent and get its first reaction
        gen = random(obs)
        act = gen.send(None)

        while ipynb_render(obs, clear=True, fps=0.01) and not fin:
            obs, rew, fin, info = env.step(act)
            act = gen.send(obs)

# Although the crawler is an infinite loop
#  we still trivially protect against `StopIteration`
except StopIteration:
    pass

finally:
    # the agent might not have been created if the env failed to start
    if gen is not None:
        gen.close()

<br>

### Let's train an A2C agent

An object to extract full episodes from their trajectory fragments.

In [None]:
from nle_toolbox.utils.rl.tools import EpisodeExtractor

One step in the joint differentiable rollout collection

In [None]:
from nle_toolbox.utils.rl.engine import step

Prepare the runtime context for the advantage-actor-critic

In [None]:
from nle_toolbox.utils.rl.engine import prepare

A procedure to collect a differentiable rollout

In [None]:
def collect(env, agent, npyt, hx, *, n_steps, visualize=None):
    """Collect a fragment of the trajectory.

    A generator, which makes `n_steps` calls to `step(env, agent, npyt, hx)`
    and yields the result of each one as is. The second element of each
    result is used as the `hx` of the next call.

    If `visualize` is not `None`, then it is used as the index into
    `npyt.npy.obs`, and the indexed observation is rendered before each step.
    """
    # (sys) get a view into numpy's observation arrays
    if visualize is None:
        view = None

    else:
        view = plyr.apply(plyr.getitem, npyt.npy.obs, index=visualize)

    for _ in range(n_steps):
        if view is not None:
            ipynb_render(view, clear=True, fps=0.01)

        # (sys) get $(x_t, a_{t-1}, r_t, d_t), v_t, \pi_t$
        out = step(env, agent, npyt, hx)

        # carry the recurrent state over to the next step
        _, hx = out
        yield out

Compute the policy gradient surrogate, the entropy and other loss components

In [None]:
from nle_toolbox.utils.rl.engine import pyt_polgrad

from nle_toolbox.utils.rl.engine import pyt_entropy

from nle_toolbox.utils.rl.engine import pyt_critic

We shall use GAE in policy gradients and returns for the critic.

In [None]:
from nle_toolbox.utils.rl.returns import pyt_ret_gae

A function to compute the targets (GAE, returns) for policy grads and critic loss.

In [None]:
@torch.no_grad()
def pg_targets(rew, val, /, gam, lam, *, fin):
    r"""Compute the targets (GAE, returns) for policy grads and critic loss.

    All of `rew`, `val`, `gam` and `lam` are dicts with the same keys as
    `rew`, while `fin` is shared between the keys. Returns the pair of
    dicts `(gae, ret)`.

    Details
    -------
    The arguments `rew`, `fin`, and `val` are $r_t$, $d_t$ and $v(s_t)$,
    respectively! The td-error terms in GAE depend on $r_{t+1}$, $d_{t+1}$
    and on both $v(s_{t+1})$ and $v(s_t)$, hence on `rew[1:]`, `fin[1:]`,
    `val[1:]` and `val[:-1]`. `val[-1]` is the value-to-go estimate for
    the last state in the related trajectory fragment.
    """
    gae, ret = {}, {}
    for name, reward in rew.items():
        # the helper yields the returns first, then the advantages
        ret[name], gae[name] = pyt_ret_gae(
            reward[1:],
            fin[1:],
            val[name],
            gam=gam[name],
            lam=lam[name],
        )

    return gae, ret

A nested parameter container, which keeps references to the leaf parameters.

In [None]:
class ParameterContainerModule(nn.Module):
    """A module that holds the items of a container as its attributes.

    Base parameters are collected into parameter containers, and containers
    into module containers, since `nn.ParameterList` or `nn.ParameterDict`
    are actually subclasses of `nn.Module`.

    The items of a tuple or a list are stored under their stringified
    position, the items of a dict (or the fields of a namedtuple) under
    their key. Containers of other types leave the module empty.
    """
    def __init__(self, parameters):
        super().__init__()

        # treat namedtuples as dicts keyed by the field names
        if isinstance(parameters, tuple) and hasattr(parameters, '_fields'):
            parameters = parameters._asdict()

        if isinstance(parameters, (tuple, list)):
            named = ((str(pos), item) for pos, item in enumerate(parameters))

        elif isinstance(parameters, dict):
            named = parameters.items()

        else:
            named = ()

        # `nn.Module.__setattr__` takes care of the proper registration
        for name, item in named:
            setattr(self, name, item)

    def __getitem__(self, key):
        """Fetch the item by its position or key."""
        return getattr(self, str(key))

    def extra_repr(self):
        # use ParameterList's extra repr implementation
        text = nn.ParameterList.extra_repr(self)
        return '\n'.join(line.strip() for line in text.splitlines())

The following network feeds the **obs**ervation $x_t$, **act**ion $a_{t-1}$, and **rew**ard $r_t$
through the provided `features` network to their joint representations. They are then passed into
the LSTM core along with the recurrent state `hx` $h_t$, and finally through value and policy heads.

In [None]:
from nle_toolbox.utils.nn import multinomial
from nle_toolbox.utils.nn import masked_rnn, rnn_hx_shape
from nle_toolbox.utils.nn import LinearSplitter, ModuleDict

def masked_multinomial(raw, mask=None, dim=-1):
    """Draw a variate from the categorical rv, specified by
    the unnormalized logits `raw` at the indicated `dim`, optionally
    masked by `mask` boolean array of the same shape as `raw`.

    Parameters
    ----------
    raw : torch.Tensor
        The unnormalized logits. They are detached before sampling.

    mask : torch.Tensor, optional
        The nonzero entries mark the categories that must NOT be drawn:
        their logits are set to `-inf`. No masking is done if `None`.

    dim : int, default=-1
        The dimension of the categories, which is sampled from and then
        squeezed out of the result.
    """
    logits = raw.detach()

    # the mask is optional: sample from the unmasked logits without it
    if mask is not None:
        logits = logits.masked_fill(mask, -float('inf'))

    return multinomial(logits.softmax(dim=dim), 1, dim).squeeze(dim)


class NLENeuralAgent(nn.Module):
    """A generic recurrent agent.

    Parameters
    ----------
    env : object
        Not used by the constructor.

    features : ModuleDict
        The feature extractor, which is called with the dict of the local
        variables of `.forward` (see the note there).

    pol : dict
        The kwargs of the `LinearSplitter` for the policy head.

    val : dict
        The kwargs of the `LinearSplitter` for the value head.

    core : dict, optional
        The kwargs of the `nn.LSTM` core. Anything but a dict means
        the `nn.Identity` core.

    h0 : bool, default=True
        Whether the agent owns a learnable initial recurrent state.

    Raises
    ------
    TypeError
        If `features` is not a `ModuleDict`.
    """
    def __init__(self, env, features, *, pol, val, core=None, h0=True):
        if not isinstance(features, ModuleDict):
            # the two literals are concatenated, hence the trailing space
            raise TypeError('Expected `features` to be a `ModuleDict`, '
                            f'got `{type(features)}` instead.')

        super().__init__()
        self.features = features

        self.core = nn.Identity()
        if isinstance(core, dict):
            self.core = nn.LSTM(**core)
        
        if h0:
            # XXX why should the agent OWN the initial hx? Make it an external
            #  parameter maybe?
            h0 = plyr.apply(torch.zeros, rnn_hx_shape(self.core))

            # assigning a single element tuple intentionally bypasses nn.Module's
            #  parameter registration logic. Furthermore any device moves and
            #  dtype changes to the module are reflected in the parameters in-place
            #  and thus are still reflected by this nested container, since the
            #  Parameter objects themselves are not changed and still referenced.
            self._hidden = (plyr.apply(nn.Parameter, h0),)

            # let the `nn.Module` machinery handle unstructured `h0`
            self.h0 = plyr.apply(
                lambda x: x, *self._hidden, _finalizer=ParameterContainerModule)

        else:
            self._hidden = (None,)
            self.register_parameter('h0', None)

        self.pol = LinearSplitter(**pol)
        self.val = LinearSplitter(**val)

    @property
    def initial_hx(self):
        """The nested container of the initial state parameters, or `None`."""
        return self._hidden[0]

    def forward(self, obs, act=None, rew=None, hx=None, *, fin=None):
        """Sample the next action and evaluate the value and the policy.

        Returns the triplet of the sampled actions, the pair of the value
        estimates (last dim squeezed) and the log-probabilities
        of the policy, and the updated recurrent state.
        """
        # make sure to extract the mask from the obs
        obs, mask = obs

        # `.features` is a ModuleDict, which ignores kwargs NOT declared
        #  at its `__init__`, which makes `locals()` work really neatly here.
        # XXX do NOT rename or introduce local variables above this line!
        x = self.features(locals())
        out, hx = masked_rnn(self.core, x, hx, reset=fin, h0=self.initial_hx)

        # sampling before logsoftmax, because masking is
        #  applied to unnormalized logits.
        pol = self.pol(out)
        act = plyr.apply(masked_multinomial, pol, mask)

        return act, (
            plyr.apply(torch.squeeze, self.val(out), dim=-1),
            plyr.apply(F.log_softmax, pol, dim=-1),
        ), hx

<br>

### Redesigning the building blocks of the NLE feature extractor

We simplify the shared glyph embedding layer to rely on glyph entities only.

In [None]:
from nle_toolbox.bot.model.glyph import GlyphEmbedding

We re-use the original glyph feature extractor layer.

In [None]:
from nle_toolbox.bot.model.glyph import GlyphFeatures
from einops import rearrange

class RestrictedGlyphFeatures(GlyphFeatures):
    """Glyph features of the pre-extracted vicinity and of the inventory.

    The parent is initialized with `window=None`, because the windows come
    with the observation in the `vicinity` field.
    """
    def __init__(self, glyphs):
        super().__init__(glyphs, window=None)

    def forward(self, obs):
        """Embed `vicinity` and `inv_glyphs` with the shared `glyphs` layer.

        Returns a dict with `vicinity` (the embedding dim moved in front of
        the spatial dims) and `inventory` (the embedding dim moved in front
        of the item dim, made contiguous).
        """
        # get the pre-extracted vicinities
        emb_vicinity = self.glyphs(obs['vicinity'])

        # embed inventory glyphs
        # XXX need to replace NO_GLYPH with MAX_GLYPH, unless they coincide.
        emb_inventory = self.glyphs(obs['inv_glyphs'])

        return {
            'vicinity': rearrange(
                emb_vicinity, 'T B H W C -> T B C H W',
            ),
            'inventory': rearrange(
                emb_inventory, 'T B N ... -> T B ... N',
            ).contiguous(),
        }

A design of the new embedding:
* the vicinity is an ego-centric view into the map
* embed entities and their groups additively
* also add the `ego` embedding offset to the centre of the vicinity

In [None]:
# from einops import rearrange

from typing import Optional
from nle_toolbox.utils.env.defs import glyphlut, glyph_group, MAX_ENTITY

class GlyphVicinityEmbedding(nn.Module):
    """The combined ego-centric vicinity - inventory embedding.

    A glyph in the vicinity is embedded as the sum of the embeddings of its
    entity and of its group, both looked up in the `lut` buffers. On top of
    that, a learnable `ego` offset is added at the centre of the vicinity.
    The inventory glyphs are embedded by their entity only.
    """
    def __init__(self, embedding_dim):
        super().__init__()

        # glyph-to-entity embedding, with the lookup table kept as a buffer
        #  of the embedding layer itself
        self.entity = nn.Embedding(
            MAX_ENTITY + 1, embedding_dim, padding_idx=MAX_ENTITY,
        )
        entity_lut = torch.tensor(glyphlut.entity).clone()
        self.entity.register_buffer('lut', entity_lut)

        # glyph-to-group embedding
        self.group = nn.Embedding(
            glyph_group.MAX + 1, embedding_dim, padding_idx=glyph_group.MAX,
        )
        group_lut = torch.tensor(glyphlut.group).clone()
        self.group.register_buffer('lut', group_lut)

        # ego embedding offset
        self.ego = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, obs):
        """Embed `vicinity` and `inv_glyphs` from the observation dict.

        Returns a dict with `vicinity` (the embedding dim moved in front of
        the spatial dims) and `inventory` (the embedding dim moved in front
        of the item dim, made contiguous).
        """
        # get the pre-extracted ego-centric vicinities
        # XXX or use `GlyphFeatures.forward`
        glyphs = obs['vicinity'].long()
        radius = obs['vicinity'].shape[-1] // 2

        # glyph -->> W_e[glyph.entity] + W_g[glyph.group] + ego
        emb = self.entity(self.entity.lut[glyphs]) \
            + self.group(self.group.lut[glyphs])

        # zero-pad the two spatial dims of the 1x1 offset, but not the last
        #  (embedding) dim, so that it affects the central cell only
        ego = F.pad(self.ego, (0, 0, radius, radius, radius, radius))
        emb = emb + ego

        # embed inventory glyphs
        # XXX need to replace NO_GLYPH with MAX_GLYPH, unless they coincide.
        inventory = self.entity(self.entity.lut[obs['inv_glyphs'].long()])

        return {
            'vicinity': rearrange(emb, 'T B H W C -> T B C H W'),
            'inventory': rearrange(
                inventory, 'T B N ... -> T B ... N',
            ).contiguous(),
        }

Now let's redo the bottom line stats: vitals first!

In [None]:
from typing import Optional

class BLSHungerEmbedding(nn.Embedding):
    """Embedding of the hunger field of the bottom line stats.

    The table has `hunger.MAX + 1` rows, the last of which is the padding.
    """
    from nle_toolbox.utils.env.defs import hunger
    from nle.nethack import NLE_BL_HUNGER

    def __init__(
        self,
        embedding_dim: int = 8,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
    ):
        padding = self.hunger.MAX
        super().__init__(
            num_embeddings=padding + 1,
            embedding_dim=embedding_dim,
            padding_idx=padding,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=False,
        )

    def forward(self, blstats):
        """Embed the hunger field, read from the last dim of `blstats`."""
        status = blstats[..., self.NLE_BL_HUNGER]
        return super().forward(status)

In [None]:
from nle_toolbox.utils.nn import OneHotBits

class BLSConditionEmbedding(nn.Module):
    """Linear embedding of the condition field of the bottom line stats.

    NOTE(review): `OneHotBits` presumably unpacks the condition bitmask
    into `condition.N_BITS` indicator features -- confirm.
    """
    from nle_toolbox.utils.env.defs import condition
    from nle.nethack import NLE_BL_CONDITION

    def __init__(self, embedding_dim: int = 8):
        super().__init__()

        n_bits = self.condition.N_BITS
        self.onehot = OneHotBits(n_bits)
        self.linear = nn.Linear(n_bits, embedding_dim)

    def forward(self, blstats):
        """Embed the condition field, read from the last dim of `blstats`."""
        bits = self.onehot(blstats[..., self.NLE_BL_CONDITION])
        return self.linear(bits)

In [None]:
from nle_toolbox.utils.nn import EquispacedEmbedding

class BLSVitalsEmbedding(EquispacedEmbedding):
    """Embedding of the hit point and the energy levels relative to their
    maxima, both embedded by the same parent layer and then concatenated.
    """
    from nle.nethack import (
        NLE_BL_HP,
        NLE_BL_HPMAX,
        NLE_BL_ENE,
        NLE_BL_ENEMAX,
    )

    def __init__(self, num_bins: int):
        super().__init__(0, 1, num_bins - 1, scale='lin')

    def forward(self, blstats):
        """Embed the hp and energy fractions of `blstats` (fields are
        read from its last dim).
        """
        hp = blstats[..., self.NLE_BL_HP] / blstats[..., self.NLE_BL_HPMAX]
        mp = blstats[..., self.NLE_BL_ENE] / blstats[..., self.NLE_BL_ENEMAX]

        # a zero maximum yields nan or inf, which are replaced in-place
        emb_hp = super().forward(hp.nan_to_num_())
        emb_mp = super().forward(mp.nan_to_num_())
        return torch.cat([emb_hp, emb_mp], dim=-1)

Now we redo the stats and build

In [None]:
class BLSStatsEmbedding(nn.ModuleDict):
    """Embeddings of the six base stats and of the strength percentage.

    Each base stat gets its own table with 26 rows, while the percentage
    strength is embedded by an `EquispacedEmbedding`. The outputs are
    concatenated along the last dim.
    """
    from nle.nethack import (
        NLE_BL_STR25,
        NLE_BL_STR125,
        NLE_BL_DEX,
        NLE_BL_CON,
        NLE_BL_INT,
        NLE_BL_WIS,
        NLE_BL_CHA,
    )

    # 6 base stats (luck is hidden) range 0..25
    index = {
        'str': NLE_BL_STR25,
        # 'strprc': NLE_BL_STR125,  # handled manually
        'dex': NLE_BL_DEX,
        'con': NLE_BL_CON,
        'int': NLE_BL_INT,
        'wis': NLE_BL_WIS,
        'cha': NLE_BL_CHA,
    }

    def __init__(self, embedding_dim: int = 16):
        layers = {}
        for name in self.index:
            layers[name] = nn.Embedding(25 + 1, embedding_dim)

        # embedding (adjusted) percentage strength for warrior classes
        layers['strprc'] = EquispacedEmbedding(
            0, 1, embedding_dim-1, scale='lin',
        )
        super().__init__(layers)

    def forward(self, blstats):
        """Embed the stats fields of `blstats` (read from its last dim)."""
        # 'str', 'dex', 'con', 'int', 'wis', 'cha', in this order
        pieces = [
            self[name](blstats[..., col]) for name, col in self.index.items()
        ]

        # 'strength_percentage' goes last, rescaled by 99
        percentage = blstats[..., self.NLE_BL_STR125].div(99)
        pieces.append(self['strprc'](percentage))

        return torch.cat(pieces, dim=-1)

In [None]:
class BLSArmorClassEmbedding(nn.Embedding):
    """Embedding of the armor class, binned into 24 categories by
    the `lookup` table.
    """
    from nle.nethack import NLE_BL_AC

    def __init__(
        self,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
    ) -> None:
        super().__init__(
            num_embeddings=24,  # the AC is mapped to 24 bins by the table below
            embedding_dim=embedding_dim,
            padding_idx=None,  # no padding index,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=False,
        )

        # a bin lookup table for armor_class, a categorical variable. It has
        #  256 entries, so negative AC values index it from the end.
        table = []
        table.extend(range(11, 0, -1))   # 0..10 mapped to 11..1
        table.extend([0] * 117)          # 11..127 mapped to 0
        table.extend([23] * 117)         # 128..244 mapped to 23
        table.extend(range(22, 11, -1))  # 245..255 mapped to 22..12
        self.register_buffer('lookup', torch.tensor(table))

    def forward(self, blstats: torch.Tensor) -> torch.Tensor:
        """Embed the binned armor class, read from the last dim of `blstats`."""
        # 'armor_class' in NetHack is descending just like in adnd. In the
        # [code](src/do_wear.c#L2107-2153) it appears that AC is confined
        # to the range of a `signed char`, however to adnd 2e mechanics it
        # is sufficient to consider the range [-10, 10] for the player's AC,
        # since we make d20 rolls anyway.
        # https://merricb.com/2014/06/08/a-look-at-armour-class-in-original-dd-and-first-edition-add/
        # XXX Also, NetHack, just why?! include/hack.h#L499-500
        bins = self.lookup[blstats[..., self.NLE_BL_AC]]
        return super().forward(bins)

In [None]:
class BLSEncumberanceEmbedding(nn.Embedding):
    """Embedding of the carrying capacity field of the bottom line stats.

    The table has `encumberance.MAX + 1` rows, the last of which is
    the padding.
    """
    from nle_toolbox.utils.env.defs import encumberance
    from nle.nethack import NLE_BL_CAP

    def __init__(
        self,
        embedding_dim: int = 8,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
    ):
        #  'carrying_capacity'
        padding = self.encumberance.MAX
        super().__init__(
            num_embeddings=padding + 1,
            embedding_dim=embedding_dim,
            padding_idx=padding,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=False,
        )

    def forward(self, blstats):
        """Embed the capacity field, read from the last dim of `blstats`."""
        capacity = blstats[..., self.NLE_BL_CAP]
        return super().forward(capacity)

The following layer puts all the preceding layers together.

In [None]:
# botl stats that were not accounted for.
from nle.nethack import (
    NLE_BL_X,
    NLE_BL_Y,
    NLE_BL_SCORE,
    NLE_BL_DEPTH,
    NLE_BL_GOLD,
    NLE_BL_HD,
    NLE_BL_XP,
    NLE_BL_EXP,
    NLE_BL_TIME,
    NLE_BL_DNUM,
    NLE_BL_DLEVEL,
)

class BLSEmbedding(nn.ModuleDict):
    """The bottom line stats embedding, assembled from a recipe.

    Each keyword argument names a layer from `bls_map` and supplies
    the dict of its kwargs. The arguments, which are `None`, are skipped.
    The outputs of the layers are concatenated along the last dim in
    the order of the recipe.

    Raises
    ------
    RuntimeError
        If the recipe is empty, or some of its values are not dicts.

    KeyError
        If the recipe has a name that is not in `bls_map`.
    """
    bls_map = {
        'hunger': BLSHungerEmbedding,
        'condition': BLSConditionEmbedding,
        'vitals': BLSVitalsEmbedding,
        'stats': BLSStatsEmbedding,
        'armorclass': BLSArmorClassEmbedding,
        'encumberance': BLSEncumberanceEmbedding,
    }

    def __init__(self, **recipe):
        #  ignore `None`
        specs = {}
        for name, kwargs in recipe.items():
            if kwargs is not None:
                specs[name] = kwargs

        if not specs:
            raise RuntimeError(f"`{type(self).__name__}` got an empty recipe.")

        # check if the ingredients are dicts
        not_dicts = [
            name for name, kwargs in specs.items()
            if not isinstance(kwargs, dict)
        ]
        if not_dicts:
            raise RuntimeError(f"These are not dicts: `{not_dicts}`.")

        # build the layers before registering them all at once
        layers = {}
        for name, kwargs in specs.items():
            layers[name] = self.bls_map[name](**kwargs)

        super().__init__(layers)

    def forward(self, blstats):
        """Apply every layer to `blstats` and concatenate the outputs."""
        pieces = [layer(blstats) for layer in self.values()]
        return torch.cat(pieces, dim=-1)

Now we put the glyph- and botl- related features in one module

In [None]:
from einops import rearrange
from nle_toolbox.utils.nn import ModuleDict

from nle_toolbox.bot.model.vit import ViTEncoder

class GlyphViTEncoder(nn.Module):
    """Encode the embedded vicinity of glyphs with a `ViTEncoder`.

    The extractor is `RestrictedGlyphFeatures` on top of a fresh
    `GlyphEmbedding`. The first argument of the ViT encoder is the side
    of the vicinity, `2 * radius + 1`.
    """
    def __init__(
        self,
        radius: int,
        embedding_dim: int,
        num_attention_heads: int,
        intermediate_size: int,
        head_size:int = None,
        dropout: float = 0.0,
        *,
        n_layers: int = 1,
        b_mean: bool = False,
    ):
        super().__init__()

        # the vicinity is pre-extracted, hence no `window=radius` here
        embedding = GlyphEmbedding(embedding_dim)
        self.extractor = RestrictedGlyphFeatures(embedding)

        side = radius + 1 + radius
        self.vit = ViTEncoder(
            side,
            embedding_dim,
            num_attention_heads,
            intermediate_size,
            head_size=head_size,
            dropout=dropout,
            n_layers=n_layers,
            b_mean=b_mean,
        )

    def forward(self, obs):
        """Return the pair of the ViT's output and attention, both with
        the leading `T` and `B` dims restored.
        """
        patch = self.extractor(obs)['vicinity']

        # the ViT works on a flat batch: fold `T x B` before, unfold after
        dims = dict(zip("TB", patch.shape[:2]))
        flat = rearrange(patch, 'T B ... -> (T B) ...')
        out, attn = self.vit(flat)

        out = rearrange(out, '(T B) ... -> T B ...', **dims)
        attn = rearrange(attn, '(T B) ... -> T B ...', **dims)
        return out, attn

An encoder of glyphs, simpler than ViT

In [None]:
from typing import List
from einops import rearrange
from einops.layers.torch import Rearrange

class GlyphSimpleEncoder(nn.Module):
    """Encode the embedded vicinity of glyphs with a two-layer perceptron.

    Parameters
    ----------
    radius : int
        The vicinity is expected to be `2 * radius + 1` cells on each side.

    embedding_dim : int
        The size of the glyph embedding.

    intermediate_size : int
        The number of the output features of both linear layers.

    dropout : float, default=0.0
        Not used by this encoder.

    kind : str, default='old'
        Either 'new' for the `GlyphVicinityEmbedding` extractor, or 'old'
        for `RestrictedGlyphFeatures` over `GlyphEmbedding`.

    **ignore
        The extra kwargs are silently ignored.
    """
    def __init__(
        self,
        radius: int,
        embedding_dim: int,
        intermediate_size: int,
        dropout: float = 0.0,
        kind='old',
        **ignore,
    ):
        side = radius + 1 + radius
        assert kind in ('old', 'new',)

        super().__init__()
        if kind != 'new':
            # the vicinity is pre-extracted, hence no `window=radius` here
            self.extractor = RestrictedGlyphFeatures(
                GlyphEmbedding(embedding_dim),
            )

        else:
            self.extractor = GlyphVicinityEmbedding(embedding_dim)

        # normalize each cell's embedding, flatten the cells, and project
        n_features = (side * side) * embedding_dim
        self.encoder = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            Rearrange('B N C -> B (N C)'),
            nn.Linear(n_features, intermediate_size, bias=True),
            nn.GELU(),
            nn.Linear(intermediate_size, intermediate_size, bias=True),
            nn.GELU(),
        )

    def forward(self, obs):
        """Return the pair of the encoded vicinity, with the leading `T`
        and `B` dims restored, and `None` in place of the attention.
        """
        patch = self.extractor(obs)['vicinity']
        dims = dict(zip("TB", patch.shape[:2]))

        # a flat batch of sequences of cells, the embedding dim last
        cells = rearrange(patch, 'T B C H W -> (T B) (H W) C')
        encoded = self.encoder(cells)

        return rearrange(encoded, '(T B) ... -> T B ...', **dims), None

A mutating encoder class

In [None]:
class GlyphEncoder(nn.Module):
    """A factory disguised as a class: returns the encoder of the given kind.

    Parameters
    ----------
    kind : str
        'vit' for `GlyphViTEncoder`, 'simple' for `GlyphSimpleEncoder` with
        the 'old' glyph feature extractor, and 'ego' for `GlyphSimpleEncoder`
        with the 'new' ego-centric vicinity embedding.

    **passthrough
        The kwargs passed to the chosen encoder.

    Raises
    ------
    RuntimeError
        If the kind is not recognized.
    """
    def __new__(cls, kind, **passthrough):
        # the returned objects are not instances of `cls`, and thus python
        #  does not call `GlyphEncoder.__init__` on them
        if kind == 'vit':
            return GlyphViTEncoder(**passthrough)

        # `simple` used to be a verbatim copy of `ego`, which made the `old`
        #  extractor of the simple encoder unreachable
        if kind == 'simple':
            return GlyphSimpleEncoder(**passthrough, kind='old')

        if kind == 'ego':
            return GlyphSimpleEncoder(**passthrough, kind='new')

        raise RuntimeError(f'Unknown encoder `{kind}`.')

The combined feature extractor

In [None]:
class NLEFeatures(nn.Module):
    """Joint feature extractor: glyph encoding concatenated with blstats embedding.

    Parameters
    ----------
    glyphs : dict
        Keyword arguments for `GlyphEncoder`.
    blstats : dict
        Keyword arguments for `BLSEmbedding`.
    **ignore
        Any other recipe keys are discarded.
    """

    def __init__(self, glyphs, blstats, **ignore):
        super().__init__()

        self.glyphs = GlyphEncoder(**glyphs)
        self.blstats = BLSEmbedding(**blstats)

    def forward(self, obs):
        """Concatenate the glyph and the bottom-line features along the last dim."""
        # the encoder may return extras (e.g. attention), which we drop
        glyph_features, *_ = self.glyphs(obs)
        stat_features = self.blstats(obs['blstats'])

        return torch.cat([glyph_features, stat_features], dim=-1)

Intrinsic motivation via Random Network distillation

In [None]:
class RNDModule(nn.Module):
    """Feature network used for Random Network Distillation.

    Embeds the glyph vicinity, layer-normalizes each cell, flattens
    the window and applies a GELU perceptron with widths `sizes`.

    Parameters
    ----------
    radius : int
        Half-size of the square window; its side is `2 * radius + 1`.
    embedding_dim : int
        Size of the per-cell glyph embedding.
    sizes : list of int
        Output widths of the consecutive dense layers; `sizes[-1]` is
        the dimension of the produced feature vector.
    """

    def __init__(
        self,
        radius: int,
        embedding_dim: int,
        sizes: list[int],
    ):
        side = 2 * radius + 1

        super().__init__()
        self.extractor = RestrictedGlyphFeatures(
            GlyphEmbedding(embedding_dim),
            # window=radius,
        )

        # the input block: normalize, flatten, and project to `sizes[0]`
        stack = [
            nn.LayerNorm(embedding_dim),
            Rearrange('B N C -> B (N C)'),
            nn.Linear(side * side * embedding_dim, sizes[0], bias=True),
        ]

        # each subsequent dense layer is preceded by a non-linearity
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            stack += [nn.GELU(), nn.Linear(fan_in, fan_out, bias=True)]

        self.encoder = nn.Sequential(*stack)

    def forward(self, obs):
        """Return the features with the two leading dims of the patch restored."""
        vicinity = self.extractor(obs)['vicinity']

        cells = rearrange(vicinity, 'T B C H W -> (T B) (H W) C')
        n_seq, n_batch = vicinity.shape[:2]

        return rearrange(
            self.encoder(cells), '(T B) ... -> T B ...', T=n_seq, B=n_batch,
        )

Create the agent

In [None]:
# sizes shared by the agent's glyph encoder and the RND networks
embedding_dim = 16 # 256
intermediate_size = 128 # 1024

# the agent's recipe: feature extractors, recurrent core, and the heads
recipe = {
    'features': {
        'glyphs': {
            # `ego` is routed by `GlyphEncoder` to `GlyphSimpleEncoder`, which
            #  discards the keys it does not declare (`num_attention_heads`,
            #  `head_size`, `n_layers`, `b_mean`) through its `**ignore`
            'kind': 'ego',
            # radius is currently hardcoded to be `k=3` in `factory()`
            'radius': 3,  # (3 + 1 + 3) * (3 + 1 + 3) * embedding_dim
            'embedding_dim': embedding_dim,
            'num_attention_heads': 4,
            'intermediate_size': intermediate_size,
            'head_size': None,
            'dropout': 0.0,
            'n_layers': 1,
            'b_mean': False,
        },
        'blstats': {
            'hunger': {'embedding_dim': 5},
            'condition': {'embedding_dim': 5},
            'vitals': {'num_bins': 5},  # * 2
            'stats': None,  # {'embedding_dim': 5},  # * (6 + 1)
            'armorclass': None,  # {'embedding_dim': 5},
            'encumberance': None,  # {'embedding_dim': 5},
        },  # (4 + 7 + 2) * 5 with every group enabled, 4 * 5 as configured
    },
    'core': {
        # blstats features (4 groups of 5) concatenated with the glyph features
        'input_size':  4 * 5 + intermediate_size, # (4 + 7 + 2) * 5 + intermediate_size,
        'hidden_size': 128,
        'num_layers': 2,  # it appears that more layers makes the agent learn more than nothing!
    },
    'pol': {
        'in_features': 128,
        'out_features': 8,  # len(ActionMasker._raw_nethack_actions)
    },
    'val': {
        'in_features': 128,
        # separate value heads for the extrinsic and the intrinsic rewards
        'out_features': {
            'ext': 1,
            'int': 1,
        },
    },
}

# the recipe of the RND target and online networks (see `RNDModule`)
rnd_recipe = {
    'radius': 3,  # hardcoded! see `factory()`
    'embedding_dim': embedding_dim,  # 32,
    'sizes': [
        intermediate_size,  # 256,
        16,  # 32,
    ],
}

The fragmented a2c parameters

In [None]:
import wandb

# start a wandb run and register the hyper-parameters in its config
wandb.init(
    project='nle-toolbox',
    job_type='nethack',
    tags=[
        'VicinityWrapper',
        'ego',
#         'ViT',
    ],
    # mode='disabled',
    config=dict(
        # int weight in the gae mix for the polgrads
        f_alpha=0.5,

        # extrinsic/intrinsic reward PV discount
        f_gamma={'ext': 0.999, 'int': 0.99},

        # the GAE discount
        f_lambda={'ext': 0.96, 'int': 0.96},

        # the share of the runtime state's hx's grad to be passed to h0
        f_h0_lerp=0.05,

        # policy-grad, critic (both ext and int), and entropy weights in the loss
        C_pg=1.,
        C_critic={'ext': 0.5, 'int': 0.5},
        C_entropy=0.01,

        # the truncated-bptt length rollout
        n_fragment_length=20,

        # the number of envs run simultaneously
        n_batch=16,

        # the total number of steps allotted to training (summed across all envs)
        n_total=2_592_000 * 6,

        # also track the recipes
        recipe=recipe,
        rnd_recipe=rnd_recipe,
    ),
);

Create the vectorized env and the agent

In [None]:
from nle_toolbox.utils.rl.engine import SerialVecEnv

# a vectorized env of `n_batch` environments, each created by `factory`
env = SerialVecEnv(factory, n_envs=wandb.config.n_batch)

Build an agent and the RND motivator from the recipe

In [None]:
# from nle_toolbox.bot.model.network import NetworkFeatures
from collections import OrderedDict

# the agent takes the recipe as is, except for `features`, which is replaced
#  by an already built feature module
agent = NLENeuralAgent(env, **{
    **recipe,
    'features': ModuleDict({
        # XXX ignores kwargs not declared at `__init__`
        'obs': NLEFeatures(**recipe['features']),
    }, dim=-1)
})

# RND pair: the `target` net has `requires_grad` switched off for all its
#  parameters, and only the `online` net is given to an optimizer
rnd = ModuleDict(dict(
    target=RNDModule(**rnd_recipe).requires_grad_(False),
    online=RNDModule(**rnd_recipe),
))

Reset the bias terms in the recurrent core

In [None]:
from torch.nn import init
from nle_toolbox.utils.nn import rnn_reset_bias

# `.apply` calls `rnn_reset_bias` on the agent and on each of its submodules
agent.apply(rnn_reset_bias);

In [None]:
agent

In [None]:
rnd

<br>

A clunky tool to split the parameters into biases and weights.

Unlike `.bias` parameters, which **should not** be decayed, not every `.weight` parameter
**should be** decayed. For example, learnt positional encodings can be seen as bias terms
to the "linear" operation effected by an `nn.Embedding`, at the same time ordinary token
embeddings should also not be decayed either, since they effectively serve as the input
data representation.

`nn.LayerNorm` layers perform within-layer normalization and then re-scale and translate
it, meaning that their weight should be regularized towards unit scale and not towards zero.

Hence:
* all biases, and weights of `nn.LayerNorm` and `nn.Embedding` should not be regularized
by weight decay


In [None]:
def split_parameters(module):
    """Partition the parameters of `module` for selective weight decay.

    Only the `weight*` parameters owned directly by `nn.Linear` and
    `nn.LSTM` modules are decayed. Everything else -- biases, weights
    of `nn.LayerNorm` and `nn.Embedding`, positional or class tokens,
    and any unrecognized parameter -- is exempt from the decay.

    Returns
    -------
    decay, no_decay : list, list
        The parameters in the order they were encountered.

    Raises
    ------
    AssertionError
        If the same parameter is owned by more than one submodule.
    """
    dense_types = nn.Linear, nn.LSTM

    decay, no_decay, seen = [], [], set()
    for prefix, mod in module.named_modules():
        is_dense = isinstance(mod, dense_types)

        # look only at the parameters owned by this very submodule
        for name, par in mod.named_parameters(prefix='', recurse=False):
            assert par not in seen
            seen.add(par)

            # names starting with `bias` never start with `weight`, hence
            #  one test suffices to tell dense weights from everything else
            if is_dense and name.startswith('weight'):
                decay.append(par)

            else:
                no_decay.append(par)

    return decay, no_decay

Adam with high `weight_decay` may push many parameters' values
into denormalized fp mode, which is ultra slow on CPU (but not
as bad on GPU), see the answer and a reply from *njuffa* in
this [stackoverflow](https://stackoverflow.com/questions/36781881) thread.

In [None]:
# flush denormal floats to zero on CPU; the returned flag, which tells
#  if the mode is supported and was set, is not checked here
torch.set_flush_denormal(True)

Init agent's optimizer

In [None]:
# the first group inherits the optimizer-wide `weight_decay`, the second
#  one overrides it with zero
decay, no_decay = split_parameters(agent)

# AdamW doesn't do what you expect it to do, Ivan! (although it
#  correctly decouples the objective's grad and the ell-2 weight reg).
#  See https://arxiv.org/abs/1711.05101
agent.optim = torch.optim.AdamW([
    dict(params=decay),
    dict(params=no_decay, weight_decay=0.),
], lr=1e-3, eps=1e-5, weight_decay=0.001)

An optimizer for the RND

In [None]:
# only the online network is optimized; note the weight decay, which is
#  ten times larger than the agent's
decay, no_decay = split_parameters(rnd.online)
rnd.optim = torch.optim.AdamW([
    dict(params=decay),
    dict(params=no_decay, weight_decay=0.),
], lr=1e-3, eps=1e-5, weight_decay=0.01)

Cache the hyper-parameters

In [None]:
# rollout sizes: total step budget, fragment length, and the number of envs
n_total = wandb.config.n_total
n_fragment_length = wandb.config.n_fragment_length
n_batch = wandb.config.n_batch

# GAE lambda, reward discounts, intrinsic advantage weight, and h0 lerp share
f_lambda = wandb.config.f_lambda
f_gamma = wandb.config.f_gamma
f_alpha = wandb.config.f_alpha
f_h0_lerp = wandb.config.f_h0_lerp

# the weights of the loss terms
C_pg = wandb.config.C_pg
C_entropy = wandb.config.C_entropy
C_critic = wandb.config.C_critic

Progress bar update and termination condition checker.

In [None]:
import tqdm

def progress(bar, n):
    """Move the progress bar to the absolute position `n`.

    Returns True while `n` is still below the bar's total, so that
    the call can serve as the condition of a `while` loop.
    """
    # the bar only accepts increments, so convert the position to a delta
    delta = n - bar.n
    bar.update(delta)

    return bar.total > n

a service function to get diagnostic stats from an episode.

In [None]:
import math
from nle_toolbox.utils.env.defs import MAX_ENTITY

def ep_stats(ep):
    """Compute diagnostic statistics of an episode strand.

    Returns
    -------
    n_length : int
        The number of records, less one if the last record is flagged `fin`.
    f_score : float
        The sum of `ep.rew[1:]`.
    n_unique : int
        The number of distinct glyph entities seen in the vicinity.
    f_entropy : float
        The entropy (in bits) of the empirical distribution of the
        encountered entities, a proxy for their diversity.
    """
    ended = ep.fin[-1]

    # a trailing `fin` record is not counted towards the duration
    n_length = len(ep.fin) - int(ended)
    f_score = float(ep.rew[1:].sum())

    # a trailing `fin` record is dropped (end-of-episode/reset correction)
    vicinity = ep.obs[0]['vicinity']
    if ended:
        vicinity = vicinity[:-1]

    # how many times each entity has been seen during the episode
    counts = torch.bincount(vicinity.flatten(), minlength=MAX_ENTITY + 1)
    n_unique = int((counts > 0).sum())

    # kl-div against a zero log-input yields \sum_n p_n \log p_n, i.e.
    #  the negative entropy in nats, which we flip and convert to bits
    proba = counts / counts.sum()
    neg_entropy = F.kl_div(proba.new_zeros(()), proba, reduction='sum')
    f_ent = -neg_entropy

    return n_length, f_score, n_unique, float(f_ent) / math.log(2)

learning process

In [None]:
from torch.nn.utils import clip_grad_norm_

# The A2C + RND training loop over fixed-length rollout fragments. Each
#  iteration collects a fragment, bootstraps the value, computes the intrinsic
#  reward and the GAE targets, then takes one optimizer step for the agent
#  and one for the online network of RND.
n_steps, epx = 0, EpisodeExtractor()
npyt, hx = prepare(env, rew=0., fin=True), None
with tqdm.tqdm(
    initial=n_steps, total=n_total, ncols=80, disable=False,
) as bar:
    while progress(bar, n_steps):
        # (sys) collect a fragment of the episode time `t` afterstates, t=0..N-1
        fragment, hxx = zip(*collect(
            env, agent, npyt, hx, n_steps=n_fragment_length, visualize=None,
        ))

        # (sys) bootstrap the one-step value-to-go approximation
        # XXX do not update `npyt` and `hx`! Also be careful not to
        # add this last record to the trajectory!
        input = plyr.apply(torch.clone, npyt.pyt)
        act_, (val_, pol_), _ = agent(**input._asdict(), hx=hxx[-1])
        fragment += ((input, val_, pol_),)

        # (sys) repack data ((x_t, a_{t-1}, r^E_t, d_t), v_t, \pi_t)
        input, val, pol = plyr.apply(torch.cat, *fragment, _star=False)
        # XXX note, `.act[t]` is $a_{t-1}$, but the other `*[t]` are $*_t$,
        #  e.g. `.rew[t]` is $r_t$, and `pol[t]` is `$\pi_t$.

        # (sys) retain running state `hx`, but detach its grads (truncated bptt)
        # XXX although val[-1] is used as a non-diffable bootstrap for value-to-go
        #  estimate, and pol[-1] does not participate in either policy-grad or
        #  the entropy computation, it is imperative that we DO NOT `.detach` prior
        #  to computing these values. Although torch does not backprop through unused
        #  tensors, for some reason, it destabilizes a reference implementation.
        hx = plyr.apply(torch.Tensor.detach, hxx[-1])

        # (sys) pass some feedback from the next fragment into `h0` by lerp-ing
        #  from `hx` to `h0`. `.lerp` does all the necessary broadcasting itself.
        #    `hx <<-- (1 - w) * hx + w * h0`
        if f_h0_lerp > 0:
            hx = plyr.apply(torch.lerp, hx, agent.initial_hx, weight=f_h0_lerp)

        # (sys) extract episode strands
        # XXX the current `input` overlaps the next one! due to the one-step-ahead bootstrap
        episodes = epx.extract(
            input.fin[:-1], plyr.apply(lambda t: t[:-1], input)
        )

        # (rnd) compute the diff-able intrinsic reward
        #     r^I_t = \| f(x_t) - \bar{f}(x_t) \|^2, t=0..N
        # XXX the error is the squared difference summed over the last dim;
        #  neither the huber loss nor the clamp to [0, 1] are currently in use
        with torch.no_grad():
            rnd_target = rnd.target(input.obs[0])
        rnd_mse = F.mse_loss(
            rnd.online(input.obs[0]), rnd_target, reduction='none',
        ).sum(-1)

        rew_int = rnd_mse.detach()  # .clamp(max=1)  # torch.zeros_like(input.rew)

        # (gae) compute the extrinsic and intrinsic GAE and returns
        # XXX `rew`, `fin`, `val` are $r_t$, $d_t$ and $v(s_t)$!
        rew = {'ext': input.rew, 'int': rew_int}
        gae, ret = pg_targets(rew, val, f_gamma, f_lambda, fin=input.fin)
        # XXX had we used `plyr.apply()` we would need to invert the structure

        # (sys) policy grad surrogate (uses common gae!)
        # XXX r_{t+1}, v_t, v{t+1} -->> A_t \log \pi_t(a_t)
        adv = gae['ext'].add(gae['int'], alpha=f_alpha).detach()
        pg_gae = plyr.apply(pyt_polgrad, pol, input.act, adv=adv)

        # (sys) entropy of the policy
        # XXX kl-div computes \sum_n e^{\log p_n} \log p_n, so we flip the sign
        entropy = plyr.apply(pyt_entropy, pol)

        # (sys) critic losses (both extrinsic and intrinsic)
        critic = plyr.apply(pyt_critic, val, ret)

        # (sys) compute the loss
        loss = \
            - pg_gae * C_pg \
            - entropy * C_entropy \
            + critic['ext'] * (C_critic['ext'] / 2) \
            + critic['int'] * (C_critic['int'] / 2)

        # (rnd) the distillation loss of the online network
        loss_rnd = rnd_mse.sum() / 2

        # (sys) backprop through the agent and the online network of RND
        # XXX this would fail if there is gradient leakage between RND
        #  and the agent.
        agent.optim.zero_grad()
        rnd.optim.zero_grad()

        loss.backward()
        loss_rnd.backward()
        # XXX only the agent's grads are clipped, the RND's are left as is
        grad = clip_grad_norm_(agent.parameters(), 5.)

        agent.optim.step()
        rnd.optim.step()

        # (sys) increment the step count
        n_steps += n_fragment_length * n_batch

        # (log) the training progress
        with torch.no_grad():
            # the medians of the stats over the episodes extracted in this iteration
            if episodes:
                dur_, ret_, unq_, div_ = \
                    map(np.median, zip(*map(ep_stats, episodes)))
                wandb.log({
                    'metrics/return': float(ret_),
                    'metrics/duration': float(dur_),
                    'metrics/n_unique': float(unq_),
                    'metrics/diversity': float(div_),
                }, commit=False)

            wandb.log({
                'n_steps': n_steps,
                'loss/loss': float(loss),
                'loss/ext': float(critic['ext']),
                'loss/int': float(critic['int']),
                'loss/entropy': float(entropy) / (n_fragment_length * n_batch),
                'loss/pg': float(pg_gae),
                'loss/rnd': float(loss_rnd),
            })
            bar.set_postfix_str(f"{float(loss):10.1e} {float(grad):10.1e}")

    # (sys) extract the residual episode strands
    # NOTE(review): the name is misspelled (`unfninished`) and is not read
    #  anywhere in the visible code -- confirm before renaming
    unfninished = epx.finish()

In [None]:
import os
from nle_toolbox.utils.io import mkstemp


# save the states of the agent, the RND, and their optimizers, along with
#  the run's config, into a uniquely named file under `./checkpoints`
target = os.path.abspath('./checkpoints')
os.makedirs(target, exist_ok=True)

# NOTE(review): the prefix says `vit` regardless of the encoder kind in
#  the recipe -- confirm that this is intended
checkpoint = mkstemp('.pt', f'ckpt__vit__{n_steps}__', dir=target)
torch.save({
    'agent': agent.state_dict(),
    'agent.optim': agent.optim.state_dict(),
    'rnd': rnd.state_dict(),
    'rnd.optim': rnd.optim.state_dict(),
    'config': dict(wandb.config),
}, checkpoint)

In [None]:
# close the wandb run started above
wandb.finish(quiet=True)

In [None]:
# force a full garbage collection (the oldest generation included)
import gc; gc.collect(2)

In [None]:
# always raises `AssertionError`; presumably a guard that stops
#  a top-to-bottom run of the notebook at this point
assert False

We finally got something:
* reducing the dim of the RND output vector from 64 to 16 was, perhaps,
the most significant step in the direction of the agent learning at least
something.
* the next thing was removing the build embedding and using `condition`,
`hunger`, and `vitals`
* the effect of positive `f_h0_lerp` has not been investigated
  * reducing it from `.05` to `0.0` seems to **adversely** impact learning


Next idea to try: like token + position embeddings in BERT and GPT, let's exploit additive embeddings.
* in glyphs: entity + group embedding -- entities semantics is modulated by its group (relevant to MON, PET, RIDE, STATUE), also add special `ego`-embedding at the centre of vicinity.
Let $g_{uv} \in \mathbb{G}$ be glyph at position $u, v$ relative to the `hero` (bls-x, -y coords).
$$
f_{uv}
    = w_E\bigl[\operatorname{entity}(g_{uv})\bigr]
    + w_G\bigl[\operatorname{group}(g_{uv})\bigr]
    + 1_{(0,0)}{(u, v)} w_{\mathrm{ego}}
  \,, $$
where $w_\cdot$ are game map-related embeddings
* in blstats: modulate an `ego`-embedding by the `vitals`, `condition` and other stats' embeddings.
  * should the `ego`-embedding be shared between map and state?

<br>

A ranked buffer for episode rollouts.

* what do we do with the missing `hx`? clone episodes in full?
  * `you wake up on a cold stone floor in the middle of a vast chamber without a shred of memory of how you got here. What do you do?` maybe it is OK to take contiguous fragments of a long episode and start with a wiped out memory.
* do we clone just the actions, or also the value function?

In [None]:
from heapq import heapreplace, heappushpop, heappush, heappop

from dataclasses import dataclass, field
from typing import Any

class RankedBuffer:
    """A fixed-capacity buffer that retains the highest-ranking items.

    The items are stored in a binary min-heap ordered by rank, hence
    the lowest-ranking item is at index `0`, and it is the one evicted
    when the buffer overflows. Indexing and iteration follow the heap
    layout, which is NOT sorted by rank.

    Parameters
    ----------
    capacity : int
        The maximal number of items kept in the buffer.
    """

    @dataclass(order=True, frozen=True, repr=False)
    class RankedItem:
        # the ordering key, the only field taking part in comparisons
        rank: float
        # the payload, excluded from comparisons
        item: Any = field(compare=False)

    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity

    def push(self, rk, it):
        """Add item `it` with rank `rk`.

        Returns None while the buffer is below capacity, otherwise
        the evicted lowest-ranking `RankedItem`, which may be the new
        item itself.
        """
        item = self.RankedItem(rk, it)
        # push the current item
        if len(self.buffer) < self.capacity:
            return heappush(self.buffer, item)
        # ... pop the lowest-ranking one, if we exceed capacity
        return heappushpop(self.buffer, item)

    def extend(self, pairs):
        """Push every `(rank, item)` pair, return the result of the last push."""
        last = None
        for rk, it in pairs:
            last = self.push(rk, it)
        return last

    def __bool__(self):
        return bool(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]

    def __repr__(self):
        return type(self).__name__ + f"({len(self.buffer)}/{self.capacity})"

    def __iter__(self):
        return ((el.rank, el.item) for el in self.buffer)

    def sample(
        self,
        n_samples=8,
        n_steps=64,
        *,
        rng=None,
    ):
        """Draw random contiguous fragments from the stored episodes.

        Each stored item is expected to be a triple, the first element
        of which has a `.fin` sequence. Episodes shorter than `n_steps`
        (not counting a trailing `fin` record) are not eligible.

        Parameters
        ----------
        n_samples : int
            The number of fragments to draw (with replacement).
        n_steps : int
            The length of each fragment.
        rng : numpy.random.Generator, optional
            The source of randomness. A fresh generator is created on
            every call if omitted, instead of a single one being built
            at class definition time and shared between all calls.

        Returns
        -------
        The fragments stacked along `dim=1`.

        Raises
        ------
        ValueError
            If no stored episode is at least `n_steps` long.
        """
        if rng is None:
            rng = np.random.default_rng()

        # determine the sufficiently long episodes
        eligible = []
        for j, ep in enumerate(self.buffer):
            strand, _, _ = ep.item
            dur_ = len(strand.fin) - int(strand.fin[-1])
            if dur_ >= n_steps:
                eligible.append((j, dur_,))

        # fail with a readable message, rather than inside the rng
        if not eligible:
            raise ValueError(
                f'No episode with at least {n_steps} steps'
                f' among {len(self.buffer)} in the buffer.'
            )

        # sample starting strands from episodes
        chunks = []
        for i in rng.choice(len(eligible), size=n_samples):
            k, dur_ = eligible[i]
            j = rng.integers(dur_ - n_steps + 1)

            chunk = plyr.apply(lambda t: t[j:j + n_steps],
                               self.buffer[k].item)
            chunks.append(chunk)

        return plyr.apply(torch.stack, *chunks, _star=False, dim=1)

A procedure to get the likelihood of actions under a given policy sequence

In [None]:
from nle_toolbox.utils.rl.returns import pyt_vtrace

def logpact(logpol, act):
    """Pick the log-probabilities of the taken actions.

    Since `act[t]` holds the PREVIOUS action $a_{t-1}$, the value
    $\\log\\mu_t(a_t)$ is read from `logpol[t]` at index `act[t+1]`,
    t=0..T-1. The output has one step fewer than the inputs.
    """
    taken = act[1:].unsqueeze(-1)
    return torch.gather(logpol[:-1], -1, taken).squeeze(-1)

def td_target(rew, fin, val, *, gam):
    """Compute the one-step td target `rew + gam * (1 - fin) * val[1:]`.

    `val` is expected to have one more step than `rew` and `fin`.
    The discount is zeroed wherever `fin` is set, which blocks
    the value backup through such steps.
    """
    # add extra trailing unitary dims for broadcasting
    n_extra = max(rew.ndim - fin.ndim, 0)
    mask = fin.reshape(tuple(fin.shape) + (1,) * n_extra)

    discount = rew.new_full(mask.shape, gam)
    discount.masked_fill_(mask, 0.)

    return torch.addcmul(rew, discount, val[1:])

@torch.no_grad()
def pyt_impala(rew, fin, act, val, pol, myuval, myupol, *, gam, r_bar, c_bar):
    """Compute the v-trace value targets and importance-weighted advantages.

    `pol` and `val` belong to the current policy, while `myupol` and
    `myuval` belong to the behaviour policy, under which the actions
    `act` were taken. The whole computation runs without grad tracking.

    Returns a triple of `vtr[:-1]`, the advantages, and `rho`.
    """
    # the log importance ratio between the current and the behaviour policies
    rho = logpact(pol, act) - logpact(myupol, act)
    vtr = pyt_vtrace(rew[1:], fin[1:], myuval, rho=rho,
                     gam=gam, r_bar=r_bar, c_bar=c_bar)

    # [O(T B F)] get the importance-weighted td(0) errors
    adv = td_target(rew[1:], fin[1:], vtr, gam=gam)
    rho_ = rho.reshape(rho.shape + (1,) * max(rew.ndim - fin.ndim, 0))
    # NOTE(review): `.reshape` returns a view whenever it can, in which case
    #  the in-place `exp_` and `clamp_` also overwrite `rho`, so the returned
    #  `rho` holds the clamped ratios and not their logs -- confirm the intent
    rho_.exp_().clamp_(max=r_bar)
    adv.sub_(val[:-1]).mul_(rho_)

    return vtr[:-1], adv, rho

Compute the exploration metrics

In [None]:
from nle.nethack import NLE_BL_SCORE
from nle_toolbox.utils.env.defs import GLYPH_CMAP_OFF, symbol

def ep_metrics(ep, *, S_stone=symbol.S_stone + GLYPH_CMAP_OFF):
    """Compute the score and exploration metrics of an episode strand.

    Returns an empty dict if the strand has no records other than
    a trailing `fin` one, otherwise a dict with the keys (in this order)

    * `ret` -- the sum of `rew[1:]`
    * `scr` -- the score in the blstats of the last proper record
    * `cov` -- the ratio of the largest to the smallest share of
      the non-stone glyphs on the map over the episode
    * `eff` -- the sum over consecutive maps of the share of changed glyphs
    * `len` -- the number of records, not counting a trailing `fin` one
    """
    # convert to numpy and check if the last record is flagged `fin`
    arrays = plyr.apply(np.asarray, ep)
    n_skip = int(arrays.fin[-1])
    n_records = len(arrays.fin)
    if n_records <= n_skip:
        return {}

    obs, _ = arrays.obs

    # XXX we exclude the terminal obs, because it is actually the init
    #  obs from the next episode
    glyphs = obs['glyphs']
    if n_skip > 0:
        glyphs = glyphs[:-1]

    # the share of the revealed map at each step
    revealed = (glyphs != S_stone).mean((-2, -1))

    return {
        # score the episode
        'ret': float(arrays.rew[1:].sum()),
        'scr': obs['blstats'][-1-n_skip, NLE_BL_SCORE],
        # coverage and action effectiveness
        'cov': revealed.max() / revealed.min(),
        'eff': sum(
            (before != after).mean()
            for before, after in zip(glyphs, glyphs[1:])
        ),
        'len': n_records - n_skip,
    }

A common routine to plot the computed rollout metrics

In [None]:
def plot_ep_metrics(metrics):
    """Plot the histograms of the per-episode metrics on a 2x2 grid.

    `metrics` is a sequence of dicts sharing the same keys. Only as
    many metrics as there are axes (four) get plotted. The `scr` and
    `ret` histograms use the log-scale for the counts.

    Returns the figure and the array of axes.
    """
    fig, axes = plt.subplots(2, 2, figsize=(7, 3), dpi=300)

    # transpose the list of dicts into a dict of lists
    columns = plyr.apply(list, *metrics, _star=False)
    for ax, name in zip(axes.flat, columns):
        use_log = name in ('scr', 'ret',)
        ax.hist(columns[name], label=name, log=use_log, bins=20)
        ax.set_title(name)

    fig.tight_layout()
    return fig, axes

Rank the episode and put it into a buffer

In [None]:
def add_episodes(buf, iterable, *, C_cov=0.1):
    """Rank the episodes and push them into the buffer `buf`.

    The rank is the episode's return plus the coverage-per-step bonus
    weighted by `C_cov`. Episodes with no metrics are skipped.
    """
    for input, val, pol in iterable:
        met = ep_metrics(input)
        if met:
            bonus = C_cov * met['cov'] / met['len']
            buf.push(met['ret'] + bonus, (input, val, pol))

Prepare a strand extractor and a buffer for ranking episodes, and ready the vectorized env and the runtime context.

In [None]:
# a fresh strand extractor, a buffer for the 128 best episodes, and
#  the list of metrics of the evaluated episodes
epx, buf = EpisodeExtractor(), RankedBuffer(128)
metrics = []

# a smaller vectorized env (four envs) with a reset runtime context
env = SerialVecEnv(factory, n_envs=4)
npyt, hx = prepare(env, rew=0., fin=True), None

A visualized evaluation run.

In [None]:
# Evaluation rollout without grad tracking: the extracted episodes are
#  ranked and put into `buf`, and their metrics are appended to `metrics`.
n_total = 16384
# XXX the progress bar is disabled whenever `visualize` is not None
n_steps, visualize = 0, 0
with torch.no_grad(), tqdm.tqdm(
    initial=n_steps, total=n_total, ncols=80,
    disable=visualize is not None,
) as bar:
    while progress(bar, n_steps):
        # (sys) collect a fragment of the episode time `t` afterstates, t=0..N-1
        fragment, hxx = zip(*collect(env, agent, npyt, hx, n_steps=128, visualize=visualize))
        # XXX `fragment` is ((x_t, a_{t-1}, r_t, d_t), v_t, \mu_t), t=0..N-1

        # (sys) retain the last running state `hx` for the next fragment
        hx = plyr.apply(torch.Tensor.detach, hxx[-1])

        # (sys) repack the fragment data
        # XXX note, `.act[t]` is $a_{t-1}$, but the other `*[t]` are $*_t$,
        #  e.g. `.rew[t]` is $r_t$, and `pol[t]` is `$\pi_t$.
        input, _, _ = fragment = plyr.apply(torch.cat, *fragment, _star=False)

        # (sys) increment the step count
        n_steps += input.fin.numel()

        # (sys) extract the episode strands together with the values and policies
        episodes = epx.extract(input.fin, fragment)
        add_episodes(buf, episodes, C_cov=0.5)

        # (evl) compute the metrics of the completed episodes
        for input_, val_, pol_ in episodes:
            met = ep_metrics(input_)
            if met:
                metrics.append(met)

    # (sys) extract the residual episode strands
    unfinished = epx.finish()
    add_episodes(buf, unfinished, C_cov=0.5)

plot_ep_metrics(metrics);

The stats of the episodes in the buffer.

In [None]:
from matplotlib import pyplot as plt

# replay every episode in the buffer and collect its metrics
out, fps = [], None
for rk, (input, val, pol) in buf:
    npy = plyr.apply(np.asarray, input)
    off = int(npy.fin[-1])  # one if the last record is flagged `fin`, else zero
    # render all records, but a trailing `fin` one
    for t in range(len(npy.fin) - off):
        obs = plyr.apply(plyr.getitem, npy.obs, index=t)
        ipynb_render(obs, fps=fps)

    met = ep_metrics(input)
    if met:
        out.append(met)

plot_ep_metrics(out);

Get ready to clone the successful episodes in the ranked buffer

In [None]:
# the number and the length of the fragments sampled from the buffer
n_samples, n_steps = 8, 64
# the clamping thresholds for `pyt_impala` (its call below is commented out)
r_bar, c_bar = 1.01, 1.1
rng = np.random.default_rng()

# the history of the cloning losses
losses = []

Behaviour cloning
* let's have a look at [Self-Imitation Learning](https://proceedings.mlr.press/v80/oh18b.html)
> ... off-policy actor-critic algorithm that learns to reproduce the agent’s past good decisions.
... that exploiting past good experiences can indirectly drive deep exploration.

In [None]:
# Behaviour cloning: maximize the log-likelihood of the actions taken
#  in the fragments sampled from the ranked buffer (50 optimizer steps).
for k in tqdm.tqdm(range(50), ncols=70):
    # sample a batch of trajectory fragments
    input, myuval, myupol = buf.sample(n_samples, n_steps, rng=rng)

    # recompute the policy and value-to-go estimates for the episode
    # XXX amnesia training: forget the hx
    _, (val, pol), _ = agent(input.obs, input.act, input.rew, hx=None, fin=None)
    # XXX this is not EXACTLY identical to `fin=ep_.fin`, which is guaranteed
    #  to contain a reset `fin[0]` and possibly a `fin[-1]` (not in case when
    #  the episode is unfinished). We ignore pol[-1] $\pi_{T}$ and val[-1]
    #  $v(s_{T})$, both of which pertain to the next episode. `fin` affects
    #  only the recurrent state anyway and 1) we set the initial to `None`,
    #  and 2) do not ever use the ... (the sentence is cut short in the
    #  original; presumably, the state returned by the agent, which is
    #  discarded above)

    # the policy-grad surrogate with unit advantages is the log-likelihood
    L_loglik = pyt_polgrad(pol, input.act, adv=1.)
    ell = - L_loglik

    # get the v-trace target for the critic and the advantages to pol-grad
    # XXX here `.fin[-1]` properly blocks the last state-value backup
#     ret, _, rho = pyt_impala(
#         input.rew, input.fin, input.act,
#         val['ext'], pol, myuval['ext'], myupol,
#         gam=f_gamma['ext'], r_bar=r_bar, c_bar=c_bar
#     )
#     L_critic = pyt_critic(val['ext'], ret)

#     ell = (L_critic * (C_critic['ext'] / 2) - L_loglik)

    agent.optim.zero_grad()
    ell.backward()
    agent.optim.step()

#     losses.append((float(L_loglik), float(L_critic)))
    losses.append((float(L_loglik),))

In [None]:
# transpose the loss history into a single array of log-likelihoods
L_loglik, = map(np.array, zip(*losses))

# plot the negative log-likelihood on the log-scale
plt.semilogy(-L_loglik, c='C1')

In [None]:
# always raises `AssertionError`; presumably a guard that stops
#  a top-to-bottom run of the notebook at this point
assert False

<br>

In [None]:
# Inspect how strongly the first layer of the core reads the blstats
#  features, by the ell-2 norms of the weights over groups of five columns.
with torch.no_grad():
    # `(W_ii|W_if|W_ig|W_io)`
    # the columns past the first 128 (the `intermediate_size` glyph features)
    #  presumably act on the blstats embeddings -- verify against the recipe
    blsw = agent.core.weight_ih_l0[:, 128:]
    blsw = blsw.reshape(len(blsw), -1, 5)
    norms = blsw.norm(p=2, dim=-1)

# XXX `zip` stops at the shortest sequence, so only as many names are
#  used, as there are groups of five columns in `norms`
norms = dict(zip([
    'hunger', 'status', 'hp', 'mp',
    'str', 'dex', 'con', 'int', 'wis', 'cha',
    'strprc', 'AC', 'encumberance',
], norms.T))

# the row ranges to shade, presumably the four gates stacked in `weight_ih_l0`
breaks = 0, 128, 256, 384, 512
c_pair = "C0", "C1"
fig, axes = plt.subplots(2,2, figsize=(7, 7,), dpi=300, sharey=True, sharex=True)
# one panel per feature group (at most four on the 2x2 grid)
for nom, ax in zip(norms, axes.flat):
    ax.semilogy(norms[nom], label=nom)
    # shade the consecutive row ranges with alternating colours
    for j, (a, b) in enumerate(zip(breaks[1:], breaks)):
        ax.axvspan(a, b, color=c_pair[j&1], alpha=0.05, zorder=-10)

    ax.legend(fontsize='xx-small')

<br>

In [None]:
# plyr.ragged(lambda v, C: C * v.sum(), pg, C_pg)

In [None]:
# collect the loss terms multiplied by their coefficients into a flat list
# NOTE(review): presumably `plyr.ragged` pairs the leaves of the nested
#  values with the matching coefficients -- confirm against plyr's docs
flat = []
plyr.ragged(lambda v, C: flat.append(C * v), pg_gae, C_pg)
plyr.ragged(lambda v, C: flat.append(C * v), entropy, C_entropy)
plyr.ragged(lambda v, C: flat.append(C * v), critic, C_critic)

In [None]:
flat

Remove currently unused fields from the observations

In [None]:
from nle_toolbox.utils.env.defs import MAX_GLYPH

def filter(
    glyphs,
    blstats,
    inv_letters,
    inv_glyphs,
    **ignore,
):
    """Keep only the observation fields that are currently in use.

    Call as `filter(**obs)`: the four named fields are returned in
    a new dict, and every other field is dropped.

    NOTE: shadows the builtin `filter` from this cell onwards.
    """
    return {
        'glyphs': glyphs,
        'blstats': blstats,
        'inv_letters': inv_letters,
        'inv_glyphs': inv_glyphs,
    }

<br>

     y  k  u  
      \ | /   
    h - . - l 
      / | \   
     b  j  n  

Spell tracker

In [None]:
def spellcaster(obs, mask, *, dir='.', ctoa):
    """Yield the actions for the key sequence `Z`, spell letter, direction.

    The keys are translated into actions through `ctoa.get`.

    NOTE(review): `letter` is neither a parameter nor a local, so it
    must exist as a global by the time the generator is advanced,
    otherwise a `NameError` is raised -- confirm where it is defined.
    """
    to_action = ctoa.get
    for key in f'Z{letter}{dir}':
        yield to_action(key)

Random policy

In [None]:
def linger(obs, mask, n=16, *, seed=None, ctoa=None):
    """Yield up to `n` random actions among those not forbidden by `mask`.

    A generator-based policy: each yielded action must be answered
    with a `.send((obs, mask))` of the new observation and mask. It
    stops early as soon as every action is forbidden.
    """
    rng = np.random.default_rng(seed)

    n_taken = 0
    while True:
        if mask.all() or n_taken >= n:
            return

        # pick a random non-forbidden action
        # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
        allowed = np.logical_not(mask).nonzero()
        obs, mask = yield rng.choice(*allowed)

        n_taken += 1

def search(obs, mask, n=6, *, ctoa):
    """Yield the actions for the keys of the `<n>s` command.

    The digits of the count `n` come first, followed by `s`. Each
    key is translated into an action through `ctoa.get`.
    """
    to_action = ctoa.get
    for key in f'{n:d}s':
        yield to_action(key)

Level and dungeon mapper

In [None]:
from nle.nethack import (
    NLE_BL_X,
    NLE_BL_Y,
    NLE_BL_DNUM,
    NLE_BL_DLEVEL,
    # NLE_BL_DEPTH,  # derived from DNUM and DLEVEL
    # XXX does not uniquely identify floors,
    #  c.f. [`depth`](./nle/src/dungeon.c#L1086-1084)
    DUNGEON_SHAPE,
    MAX_GLYPH,
)

from nle_toolbox.utils.env.defs import \
    glyph_is, dt_glyph_ext, ext_glyphlut

from nle_toolbox.bot.level import Level, DungeonMapper

Determine the walkability of the observed tiles

In [None]:
from nle_toolbox.utils.env.defs import symbol, GLYPH_CMAP_OFF, glyph_group, get_group
from nle_toolbox.utils.env.defs import glyphlut, ext_glyphlut

# glyph groups of the closed doors and the raised drawbridges
closed_doors = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_vcdoor', 'S_hcdoor',
    'S_vcdbridge', 'S_hcdbridge',
])

# ... and of the doorways, the open doors and the lowered drawbridges
open_doors = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_ndoor',
    'S_vodoor', 'S_hodoor',
    'S_vodbridge', 'S_hodbridge',
])

# boolean lookup tables aligned with the entries of `ext_glyphlut`
# NOTE(review): presumably they are to be indexed by a glyph id, as
#  in `is_walkable[glyph]` -- confirm against `ext_glyphlut`
is_closed_door = np.isin(ext_glyphlut.id.value, np.array(list(closed_doors)))
is_actor = np.isin(ext_glyphlut.id.group, np.array(list(glyph_group.ACTORS)))
is_pet = ext_glyphlut.id.group == glyph_group.PET

# a tile is walkable if it is accessible, an open door, or has an object on it
is_open_door = np.isin(ext_glyphlut.id.value, np.array(list(open_doors)))
is_object = np.isin(ext_glyphlut.id.group, np.asarray(list(glyph_group.OBJECTS)))
is_walkable = ext_glyphlut.is_accessible | is_open_door | is_object

In [None]:
# the glyphs of the trap symbols (the portal and the vibrating square included)
traps = get_group(symbol, GLYPH_CMAP_OFF, *[
    'S_arrow_trap',
    'S_dart_trap',
    'S_falling_rock_trap',
    'S_squeaky_board',
    'S_bear_trap',
    'S_land_mine',
    'S_rolling_boulder_trap',
    'S_sleeping_gas_trap',
    'S_rust_trap',
    'S_fire_trap',
    'S_pit',
    'S_spiked_pit',
    'S_hole',
    'S_trap_door',
    'S_teleportation_trap',
    'S_level_teleporter',
    'S_magic_portal',
    'S_web',
    'S_statue_trap',
    'S_magic_trap',
    'S_anti_magic_trap',
    'S_polymorph_trap',
    'S_vibrating_square',
])

# a boolean lookup table aligned with the entries of `ext_glyphlut`
is_trap = np.isin(ext_glyphlut.id.value, np.array(list(traps)))

The core of the "smart" dungeon explorer

In [None]:
from scipy.special import softmax

def crawler(obs, mask, *, dir, seed=None):
    """Generator policy alternating between random moves and path following.

    The generator is driven with `.send((obs, mask))`: every `yield` hands
    out the chosen action and receives the next observation and action mask.

    It is a two-state machine:

    * `linger`: for 16 consecutive steps pick, uniformly at random, an action
      whose `mask` entry is falsy (i.e. a non-forbidden one);
    * `crawl`: sample a destination tile on the mapped level, compute
      a shortest path to it with `dij`, and replay the path one move per step
      until the queued moves run out, then go back to `linger`.

    Parameters
    ----------
    obs :
        The initial observation, passed to `DungeonMapper.update`.
    mask :
        Action mask, truthy entries mark forbidden actions.
    dir :
        Lookup from the entries of `dir_to_ascii` to the action to yield.
        NOTE(review): presumably ascii direction keys -> env actions; confirm.
    seed :
        Seed for the policy's own `np.random.default_rng`.

    Yields
    ------
    act
        The action to take in the environment.

    Notes
    -----
    Relies on `dij`, `backup` and `dir_to_ascii`, which are defined elsewhere
    in the notebook.
    """
    dng = DungeonMapper()

    # own random number generator
    rng = np.random.default_rng(seed)

    # a simple state machine: linger <<-->> crawler
    #  `stack` holds the queued moves of the current plan, first move on top
    state, n_linger, stack = 'linger', 16, []
    while True:
        # NOTE the state transitions below `continue` without yielding, so
        #  the mapper is updated again with the very same observation
        dng.update(obs)
        pos = dng.level.trace[-1]

        if state == 'crawl':
            if stack:
                # consume one node of the plan together with its move
                plan.pop()
                act = dir[stack.pop()]

            else:
                # the plan has been exhausted, go back to random moves
                state, n_linger = 'linger', 16
                continue

        elif state == 'linger':
            if n_linger > 0:
                n_linger -= 1

                # if we're in LINGER state, pick a random non-forbidden action
                # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
                act = rng.choice(*np.logical_not(mask).nonzero())

            else:
                lvl = dng.level

                # we've run out linger moves, time to pick a random destination
                # and go to it
                state = 'crawl'

                # get the walkability cost: walkable or pet-occupied tiles
                #  cost .334, anything else is impassable
                cost = np.where(
                    # is_walkable[lvl.bg_tiles.glyph]
                    (is_walkable | is_pet)[lvl.bg_tiles.glyph]
                    , .334, np.inf)
                # XXX adjust `cost` for hard-to-pass objects?
                # traps can be crossed, but at a much higher cost
                cost[is_trap[lvl.bg_tiles.glyph]] = 10.

                # get the shortest paths from the current position
                value, path = dij(cost, pos)

                # draw a destination, the further the better: the logit of
                #  a reachable non-trap tile is its path value, closed doors
                #  get a fixed logit of 100., everything else is excluded
                # NOTE(review): closed doors are favoured regardless of
                #  whether `value` is finite there -- confirm that `backup`
                #  copes with a destination that `dij` could not reach
                prob = softmax(np.where(
                    is_closed_door[lvl.bg_tiles.glyph],
                    100.,
                    np.where(
                        np.logical_and(
                            np.isfinite(value),
                            np.logical_not(
                                is_trap[lvl.bg_tiles.glyph]
                            )
                        ), value, -np.inf
                    ))
                )
                # flat index -> (row, col)
                dest = divmod(rng.choice(prob.size, p=prob.flat), prob.shape[1])

                # reconstruct the path to the destination in reverse order
                #  `plan` runs from `dest` back to the start, hence the move
                #  pushed last onto the `stack` is the first one to be made
                plan = list(backup(path, dest))
                for (r1, c1), (r0, c0) in zip(plan, plan[1:]):
                    stack.append(dir_to_ascii[r1-r0, c1-c0])

                # drop the starting position, so that `plan` and `stack`
                #  have the same length
                plan.pop()
                continue

        # hand the action out and wait for the next observation and mask
        obs, mask = yield act

How do we want to explore?
* open closed doors
* explore tunnels

Implementing the random dungeon crawler

In [None]:
# fetch the `dng` local variable of the generator `gen` for inspection
# NOTE(review): neither `getgeneratorlocals` nor `gen` is defined in this
#  part of the notebook -- confirm that an earlier cell sets them up
dng = getgeneratorlocals(gen).get('dng')
# dng.level.trace[-1]

In [None]:
# show the `is_accessible` flag of the background tiles of the mapped level
plt.imshow(dng.level.bg_tiles.info.is_accessible)

     y  k  u  
      \ | /   
    h - . - l 
      / | \   
     b  j  n  

In [None]:
# per-cell traversal cost of the currently observed glyphs: unit cost for
#  walkable or pet-occupied cells, infinite cost for everything else
cost = np.where((is_walkable | is_pet)[obs['glyphs']], 1., np.inf)
# XXX adjust `cost` for hard-to-pass objects?
# traps stay passable, but are made ten times as costly
cost[is_trap[obs['glyphs']]] = 10.

# get the shortest paths from the current position (row, col from blstats)
bls = obs['blstats']
value, path = dij(cost, (bls[NLE_BL_Y], bls[NLE_BL_X]))

# destination distribution: reachable non-trap cells weighted by the softmax
#  of their path value, all other cells get zero probability
prob = softmax(np.where(
    np.isfinite(value) & ~is_trap[obs['glyphs']],
    value,
    -np.inf,
))

plt.imshow(value)

<hr>

Test the algo

In [None]:
# the source cell of the synthetic shortest path problem
r, c = 12, 12

rng = np.random.default_rng()  #248675)

# random positive costs on a 21 x 79 grid: minus log of a uniform variate
cost = -np.log(rng.random((21, 79)))
# cost = np.ones((21, 79))
# each cell is made impassable with probability one half
cost[rng.random(cost.shape) < .5] = np.inf

value, path = dij(cost, (r, c))


# only the cells with finite path value may serve as destinations
# mask = is_walkable[lvl.bg_tiles.glyph] | is_walkable[lvl.stg_tiles.glyph]
mask = np.isfinite(value)
mask[r, c] = False  # mask the current position

from scipy.special import softmax

# uniform distribution over the unmasked cells with path value above five
# NOTE(review): `prob` is all-nan if no unmasked cell has value above five
value = np.where(value > 5, 0., -np.inf)
prob = softmax(np.where(mask, value, -np.inf))

Play around with the shortest path.

In [None]:
# sample a destination: flat index drawn from `prob` -> (row, col)
r, c = divmod(rng.choice(prob.size, p=prob.flat, ), prob.shape[1])

# overlay the path on a copy of the cost map
displ = cost.copy()
plan = list(backup(path, (r, c)))
for ij in plan:
    displ[ij] = 10
# highlight the source cell
# NOTE(review): (12, 12) is hard-coded, since `r, c` now hold the destination
displ[12, 12] = 11


fig, ax = plt.subplots(1, 1, dpi=300)
ax.imshow(displ)

In [None]:
# `plan` runs from the destination back to the source, so each consecutive
#  pair is (cell, its predecessor); look up the move between the two
commands = []
for (r1, c1), (r0, c0) in zip(plan[:-1], plan[1:]):
    commands.append(dir_to_ascii[r1-r0, c1-c0])

# the moves were collected destination-first, flip them into walking order
''.join(commands[::-1])

<br>

In [None]:
# NOTE(review): unconditional failure -- presumably a deliberate stop, so
#  that running all cells does not proceed past this point
assert False

A non-illegal random action exploration.

In [None]:
from copy import deepcopy
from nle_toolbox.bot.chassis import get_wrapper


def random_explore(seed=None, n_steps=1000, *, auto=False, fps=None, copy=False):
    """A non-illegal random action explorer.

    A generator, which plays uniformly random non-forbidden actions in
    a freshly created environment, and yields the observation seen before
    each step. Menus and prompts are never interacted with: they are
    dismissed with the escape action.

    Parameters
    ----------
    seed : optional
        Entropy for `np.random.SeedSequence`, from which two independent
        child sequences are spawned: one for the random policy and one
        passed to the environment `factory`.
    n_steps : int, default=1000
        The maximal total number of environment steps to make. The steps
        are counted across episodes if `auto` is set.
    auto : bool, default=False
        Whether to reset the environment when an episode finishes, instead
        of stopping.
    fps : optional
        Passed to `ipynb_render` as is.
    copy : bool, default=False
        Whether to yield a deep copy of the observation. Since nle reuses
        its buffers, a caller that stores the yielded observations should
        enable this, or make copies on its own.

    Yields
    ------
    obs
        The observation, on which the next action is about to be decided.
        The terminal observation of an episode is rendered, but not yielded.

    Notes
    -----
    The loop also stops as soon as `ipynb_render` returns a falsy value.
    """
    # independent random streams for the policy and for the environment
    ss_pol, ss_env = np.random.SeedSequence(seed).spawn(2)
    rng = np.random.default_rng(ss_pol)

    with factory(seed=ss_env) as env:
        # we need access to the Chassis for additional meta state variables
        cha = get_wrapper(env, Chassis)

        # ActionMasker caches the escape action id
        ESC = get_wrapper(env, ActionMasker).escape

        # setup the dungeon mapper
        # XXX deliberately kept under the name `dng`, so that it can be
        #  looked up among the locals of the running generator
        dng = DungeonMapper()

        # launch the episode
        (obs, mask), fin, j = env.reset(), False, 0
        while (
            ipynb_render(obs, clear=True, fps=fps)
            and not (fin or j >= n_steps)
        ):
            # though nle reuses buffers, we do not deep copy them unless
            #  asked to, delegating this to the downstream user instead
            yield deepcopy(obs) if copy else obs

            # default to immediately escaping from any menu or prompt
            act = ESC
            if not (cha.in_menu or cha.prompt):
                dng.update(obs)

                # otherwise pick a random non-forbidden action
                # XXX whelp... tilde on int8 is `two's complement`, not the `logical not`
                act = rng.choice(*np.logical_not(mask).nonzero())

            # the reward and the info dict are of no use to a random policy
            (obs, mask), _, fin, _ = env.step(act)
            j += 1

            if fin and auto:
                # show the terminal observation before starting a new episode
                ipynb_render(obs, clear=True, fps=fps)
                (obs, mask), fin = env.reset(), False

Get a random episode

In [None]:
# from inspect import getgeneratorlocals
episode = random_explore(
    seed=None,
    n_steps=256,
    auto=False,
    copy=True,
    fps=0.01,
)


# advance the generator to its first `yield`, so that its locals exist, and
#  keep only the glyphs of the first observation: every element of the list
#  is thus a glyph array, rather than a whole observation followed by arrays
glyphs = [next(episode)['glyphs']]
# dng = getgeneratorlocals(episode).get('dng')

# run the remainder of the episode, collecting the glyphs of each observation
glyphs.extend(obs['glyphs'] for obs in episode)

In [None]:
# NOTE(review): unconditional failure -- presumably a deliberate stop, so
#  that running all cells does not proceed past this point
assert False

In [None]:
from scipy.special import softmax

def dstination_prob(lvl, pos):
    """Get the distribution over the destination tiles on the level.

    The logit of a tile is its Chebyshev (king's move) distance from `pos`,
    capped at five. The tiles, which are walkable neither by their `bg_tiles`
    glyph nor by their `stg_tiles` glyph, and the tile at `pos` itself get
    zero probability.

    Parameters
    ----------
    lvl :
        The level, which provides the `bg_tiles` and `stg_tiles` tile arrays.
    pos : tuple
        The (row, col) of the current position.

    Returns
    -------
    prob :
        The softmax taken jointly over all tiles.
    """
    row, col = pos
    tiles = lvl.bg_tiles

    # the number of king's moves from `pos` to each tile
    d_row = abs(tiles.rc.r - row)
    d_col = abs(tiles.rc.c - col)
    dist = np.maximum(d_row, d_col)

    # a tile is eligible if either of its glyphs is walkable
    eligible = is_walkable[tiles.glyph] | is_walkable[lvl.stg_tiles.glyph]
    eligible[row, col] = False  # mask the current position

    # cap the logits, so that all far away tiles are equally likely
    logits = np.where(eligible, dist, -np.inf)
    return softmax(np.minimum(logits, 5))

rng = np.random.default_rng()
# destination distribution as seen from the last position on the level's trace
prob = dstination_prob(dng.level, dng.level.trace[-1])
# unit cost on the tiles with non-zero probability, impassable elsewhere
# NOTE(review): the current position has zero probability, hence it ends up
#  with infinite cost as well -- confirm that `dij` is fine with that
cost = np.where(prob > 0, 1., float('inf'))

plt.imshow(prob)

In [None]:
# show the traversal cost map
plt.imshow(cost)

In [None]:
def backup(path, dest):
    """Trace the path from `dest` back to its root through the `path` links.

    `path[node]` is the predecessor of `node`, or `None` if `node` is
    the root. The generator yields `dest` first, then its predecessor, and
    so on, with the root being the last item.
    """
    # the predecessor is always looked up before its node is handed out
    node, parent = dest, path[dest]
    yield node

    while parent is not None:
        node, parent = parent, path[parent]
        yield node

#         (r0, c0), (r1, c1) = p0, p1
#         yield directions[r1-r0, c1-c0]
        

In [None]:
# shortest paths over `cost` from the last position on the level's trace
value, path = dij(cost, dng.level.trace[-1])

In [None]:
# work on a copy, so that `value` itself stays intact
val = value.copy()
# sample a destination according to `prob`
# NOTE(review): assumes each element of `bg_tiles.rc.flat` unpacks into
#  a (row, col) pair -- confirm
r, c = rng.choice(dng.level.bg_tiles.rc.flat, p=prob.flat)

fig, ax = plt.subplots(1, 1, dpi=300)
# zero out the cells along the path back from the destination
for i, j in backup(path, (r, c)):
    val[i, j] = 0.

# make the destination itself stand out
val[r, c] = np.inf

# show columns 10 to 39 only
ax.imshow(val[:, 10:40])

<br>

<br>

In [None]:
from einops import repeat, rearrange
from transformers import ViTModel, ViTConfig

from transformers.models.vit.modeling_vit import to_2tuple


class NLEViTEmbeddings(nn.Module):
    """Embedding layer for a ViT, which is fed already embedded windows.

    Prepends a learnt cls-token to the flattened window and adds a learnt
    positional embedding. The window has `2 * config.window + 1` cells on
    each side, and both parameters are initialized to zeros.
    """

    def __init__(self, config):
        height, width = to_2tuple(2 * config.window + 1)
        super().__init__()

        dim = config.hidden_size

        # a single cls-token, shared by all samples in a batch
        self.cls = nn.Parameter(torch.zeros(1, dim))

        # one embedding for the cls-token plus one per cell of the window
        self.posemb = nn.Parameter(torch.zeros(1 + height * width, dim))

    def forward(self, input, **ignore):
        """Turn a `B x D x H x W` input into a `B x (1 + H W) x D` sequence.

        Any extra keyword arguments are accepted and ignored.
        """
        # flatten the spatial dims into a sequence of D-dim tokens
        tokens = rearrange(input, 'B D H W -> B (H W) D')

        # cls-token and positional embedding
        cls = repeat(self.cls.unsqueeze(0), '() N D -> B N D', B=len(tokens))
        sequence = torch.cat((cls, tokens), dim=1)
        return sequence + self.posemb


# a tiny single-layer ViT encoder
# NOTE `window` is not a ViT setting: it is the custom attribute that
#  `NLEViTEmbeddings` reads off the config to size its positional embedding
model = ViTModel(ViTConfig(
    hidden_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    intermediate_size=512,
    window=2,
    # image_size,
    # patch_size,  # are ignored
))

# swap the stock patch embeddings for the window-based ones
model.embeddings = NLEViTEmbeddings(model.config)


# obs, msk = input.obs
# gly = agent.features.obs[0](obs)

# win = gly['vicinity']

# size = dict(zip("TBCHW", win.shape[:3]))
# x = rearrange(win, 'T B C H W -> (T B) C H W')

# out = rearrange(model(x).pooler_output, '(T B) C -> T B C', **size)

In [None]:
# display the model
model

<br>