# Neuro proto

In [None]:
import time
import gym
import nle

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Remove gym's attribute forwarding (`gym.Wrapper.__getattr__` delegates
#  unknown attributes to the wrapped `self.env`), so that each wrapper only
#  exposes what it defines itself — NOTE(review): presumably the intent.
# The guard makes the cell safe to re-run: a second bare `del` would raise
#  `AttributeError`, since the attribute is already gone from the class.
if "__getattr__" in vars(gym.Wrapper):
    del gym.Wrapper.__getattr__

We hide the NLE under several layers of wrappers. From the core to the shell:
1. `ReplayToFile` saves the seeds and the taken actions to a file for later inspection and replay.
2. `NLEAtoN` maps ascii actions to the opaque actions accepted by the NLE.
3. `NLEObservationPatches` patches the tty screens, which are botched by the CR-LF misconfiguration of the NLE's tty terminal emulator and NetHack's display (which emits LF only).
4. `NLEFeatureExtractor` adds extra features generated on-the-fly from the current NLE observation (it is imported, but not used by the `factory` below).

In [None]:
from nle_toolbox.utils.replay import ReplayToFile
from nle_toolbox.utils.env.wrappers import (
    NLEObservationPatches,
    NLEAtoN,
    NLEFeatureExtractor,
)


def factory():
    """Build the wrapped NetHack challenge env for the GUI experiments.

    The wrappers are applied from the core outwards: replay recording into
    `./replays` (with `save_on="done"`), the ascii-to-action mapping, and
    finally the tty observation patches.
    """
    env = gym.make("NetHackChallenge-v0")
    env = ReplayToFile(env, save_on="done", sticky=True, folder="./replays")
    env = NLEAtoN(env)
    return NLEObservationPatches(env)

## Basic GUI Handling

NetHack's gui is not as intricate as in some other games. We need to deal
with menus, text prompts, messages and y/n questions. In order to analyze
the interface details and player's journey through the UI, we first implement
a simple command evaluator.

In [None]:
from collections import deque


def gui_run(env, *commands):
    """Reset `env` and feed it the keypresses of each command in turn.

    Every command is an iterable of actions (e.g. a string of ascii keys),
    which are sent to `env.step` one at a time. After a command has been
    fully consumed, the latest observation is yielded. Once the episode is
    done, the run stops right after yielding; any keys left over from the
    interrupted command are discarded.

    Yields
    ------
    obs : object
        The observation after each command (the reset observation if the
        very first command is empty).
    """
    pending = deque()
    obs = env.reset()
    done = False
    for keys in commands:
        pending.extend(keys)
        while pending:
            if done:
                break
            obs, _, done, _ = env.step(pending.popleft())

        yield obs
        if done:
            return

A renderer

In [None]:
import pprint as pp
from nle_toolbox.utils.env.render import render as tty_render


def ipynb_render(obs, clear=True):
    """Print the rendered tty screen of `obs` into the notebook's output.

    Parameters
    ----------
    obs : dict
        An NLE observation, unpacked as keyword arguments into `tty_render`.
    clear : bool, default=True
        Whether to clear the cell's previous output first (with `wait=True`,
        i.e. deferred until the new output arrives).
    """
    from IPython.display import clear_output

    if not clear:
        print(tty_render(**obs))
        return

    clear_output(wait=True)
    print(tty_render(**obs))


def dump(env, obs):
    """Print the screen of `obs` and a summary of the GUI state of `env`.

    Parameters
    ----------
    env : object
        A wrapper exposing the GUI state attributes `messages`, `menu`,
        `prompt`, `in_menu`, `xwaitingforspace`, `in_yn_function` and
        `in_getlin` (e.g. a `Chassis`).
    obs : dict
        The current NLE observation; its `blstats` supply the game time.
    """
    # imported here rather than relying on the notebook-level import, which
    #  happens in a later cell than this definition
    from nle.nethack import NLE_BL_TIME

    ipynb_render(obs, clear=False)

    # one 'T'/'F' letter per mode flag, in the order: in_menu,
    #  xwaitingforspace, in_yn_function, in_getlin
    modes = (env.in_menu, env.xwaitingforspace, env.in_yn_function, env.in_getlin)
    flags = "".join("FT"[flag] for flag in modes)

    pp.pprint(
        (
            env.messages,
            env.menu,
            env.prompt,
            flags,
            obs["blstats"][NLE_BL_TIME],
            # to also list the inventory, add:
            # {
            #     chr(let): itm
            #     for let, itm in zip(obs['inv_letters'], obs['inv_strs'].view('S80')[:, 0])
            #     if let > 0
            # }
        ),
        width=120,
    )

### Menus

There are two types of menus in NetHack: single-page and multi-page. Single-page
menus pop up in the middle of the terminal on top of the dungeon map (and
are sort of `dirty`, meaning that they have arbitrary symbols around them),
while multi-page menus take up the entire screen after clearing it. Overlaid
menu regions appear to be right-justified, while their contents' text is
left-justified. All menus are modal, i.e. they capture the keyboard input until
exited. Some menus are static, i.e. display some information, while others
are interactive, i.e. allow item selection with letters or punctuation. However,
both kinds share two special control keys. The space `\x20` (`\040`, 32,
`<SPACE>`) advances to the next page, or closes the menu if the page was
the last or the only one. The escape `\x1b` (`\033`, 27, `^[`) immediately
exits any menu.

In [None]:
# TODO: placeholder cell for the menu-handling experiments described above
pass

The following detects the type of the menu (overlay/fullscreen), its number
of pages, and extracts its content.

In [None]:
# TODO: placeholder for the menu type (overlay/fullscreen), page-count and
#  content detector described above
pass

The following function extracts raw data from a menu and enumerates all
items, which can be interacted with.

In [None]:
# TODO: placeholder for the menu data extractor and item enumerator
pass

## Top Line Messages

The game reports events and displays status or information in the top two lines
of the screen. The NLE also provides the raw data in the `message` field of
the observation. NetHack generally makes its announcements in the top line;
however, if it wants to communicate a single message longer than `80` characters,
the game lets it spill over to the second line and appends a `--More--` suffix
to it. The game does the same if it has several short messages to announce.
In both cases NetHack's GUI expects the user to confirm or dismiss each message
by pressing Space, Enter or Escape.

Some helper functions to fetch and detect multi-part messages.

In [None]:
# TODO: placeholder for the multi-part top-line message helpers
pass

<br>

## Putting it all together

Below we import `Chassis`, a wrapper that handles menus (unless an interaction
is required) and fetches all consecutive messages.

In [None]:
from nle_toolbox.bot.chassis import Chassis

<br>

## Testing

Let's test it in bulk.

In [None]:
import pprint as pp
from nle.nethack import NLE_BL_TIME

# pick the seed pair to replay; the commented-out alternatives are kept for
#  reference (the `multi`/`single` notes presumably refer to the menu kind)
# seed = None
# seed = 12513325507677477210, 18325590921330101247  # multi
# seed = 1251332550767747710, 18325590921330101247  # single
# seed = 125133255076774710, 18325590921330101247  # single
# seed = 13765371332493407478, 12246923801353953927
# seed = 12301533412141513004, 11519511065143048485
# seed = 1632082041122464284, 11609152793318129379
seed = 12604736832047991440, 12469632217503715839  # an aspirant

# play each scripted key sequence, dumping the screen and GUI state after it
with Chassis(factory(), split=False) as env:
    seed = env.seed(seed)
    for obs in gui_run(
        env,
        #         ';j:',   # a paragraph about a cat
        #         'acy'      # break wand and blow up
        "Zbyyy,",  # cast a sleep ray at a newt and pick up its corpse
        #         # FIXME the inventory seems interactive, but it is not
        #         'i',       # open the inventory
        #         # drop a quarterstaff
        #         '\033d*a',
        # open inventory
        "i><><",
    ):
        dump(env, obs)

In [None]:
# deliberately stop "Run All" here: the cells below are separate experiments
assert False

In [None]:
from nle_toolbox.bot.chassis import InteractiveWrapper, Wrapper, get_wrapper

# log of GUI snapshots, appended to by `Snitch.update`
history = []


class Snitch(InteractiveWrapper):
    """Record the GUI state of the wrapped chassis after every action.

    Each entry appended to the module-level `history` list is a tuple of
    the escaped ascii key last sent, followed by the `messages`, `menu`,
    `prompt`, `in_menu`, `in_yn_function`, `in_getlin` and
    `xwaitingforspace` attributes of `self.env.env`.
    """

    def __init__(self, env):
        super().__init__(env)
        codes = self.unwrapped.actions
        # translate between ascii characters and action indices, both ways
        self.ctoa = {chr(code): index for index, code in enumerate(codes)}
        self.atoc = {index: chr(code) for index, code in enumerate(codes)}

    def reset(self):
        # pretend the last key was `.`, so that `update` has one to log
        self._action = self.ctoa["."]
        return super().reset()

    def step(self, action):
        # remember the action, so that `update` can log it
        self._action = action
        return super().step(action)

    def update(self, obs, rew=0.0, done=False, info=None):
        # hook, presumably invoked by `InteractiveWrapper` after reset/step
        key = self.atoc[self._action]
        # strip the `b'...'` wrapper off the repr of the escaped bytes
        escaped = str(key.encode("unicode-escape"))[2:-1]
        chassis = self.env.env
        history.append(
            (
                escaped,
                chassis.messages,
                chassis.menu,
                chassis.prompt,
                chassis.in_menu,
                chassis.in_yn_function,
                chassis.in_getlin,
                chassis.xwaitingforspace,
            )
        )
        # update level representation
        return obs, rew, done, info

In [None]:
from nle_toolbox.bot.chassis import ActionMasker
from nle_toolbox.utils.seeding import set_seed


def factory(seed=None):
    """Build a recorded, patched, GUI-handling env for random play.

    The stack, from the core out: `ReplayToFile` (saving into `./replays` on
    `"done,close"`), `NLEObservationPatches`, `Chassis` (given the index of
    the space key), `ActionMasker`, and finally `Snitch`, which logs the GUI
    state into `history`.
    """
    recorder = ReplayToFile(
        gym.make("NetHackChallenge-v0"),
        save_on="done,close",
        sticky=True,
        folder="./replays",
    )
    # we force seed the main challenge task, since we are training here!
    # XXX this has one-off effect until the next reset
    recorder.seed(seed=seed)

    ascii_to_index = {
        chr(code): index for index, code in enumerate(recorder.unwrapped.actions)
    }

    # the chassis sits atop the tty patches
    chassis = Chassis(
        NLEObservationPatches(recorder), space=ascii_to_index[" "], split=False
    )
    return Snitch(ActionMasker(chassis))

In [None]:
import torch

from rlplay.utils.common import multinomial
from rlplay.engine.core import BaseActorModule


class RandomActor(BaseActorModule):
    """An actor that picks a random legal action in every env.

    Parameters
    ----------
    env : gym.Env
        Supplies the discrete `action_space`; its size `n` is the number of
        action scores drawn per env.
    """

    def __init__(self, env):
        super().__init__()
        self.action_space = env.action_space

    def step(
        self,
        stepno,
        obs,
        act,
        rew,
        fin,
        *,
        hx=None,
        virtual=False,
    ):
        """Sample new actions, unless this is a `virtual` step.

        `obs` is expected to be a pair `(observation, mask)`, where a nonzero
        `mask` entry marks an illegal action. Returns the actions, an empty
        recurrent state and an empty info dict.
        """
        if not virtual:
            # i.i.d. scores for every action: after masking, the softmax plus
            #  sampling gives each legal action the same chance (by symmetry)
            raw = torch.randn(*act.shape, self.action_space.n, dtype=float)

            # mask illegal actions: `masked_fill` is documented to take a
            #  boolean mask (non-bool masks are deprecated or rejected,
            #  depending on the PyTorch version); nonzero means illegal
            _, mask = obs
            raw = raw.masked_fill(mask.to(torch.bool), -float("inf"))
            prb = raw.softmax(-1)
            act = multinomial(prb)

        return act, (), {}

In [None]:
from rlplay.engine.rollout import same, multi, single

# start a fresh GUI log for this rollout
history.clear()
# seed = 12604736832047991440, 12469632217503715839
# seed = [*b'i9u48y7548']
seed = None
# collect a single 250-step fragment of random play from one env
with factory(seed) as env:
    actor = RandomActor(env)
    rndgen = same.rollout(
        [
            env,
            # factory(seed),
            # factory(seed),
            # factory(seed),
        ],
        actor,
        n_steps=250,
    )
    fragment = next(rndgen)

* `&` results in `What command?`

Building a network

In [None]:
import torch
import torch.nn.functional as F

from torch import nn
from copy import deepcopy

import plyr

In [None]:
from gym import spaces
from nle_toolbox.bot.legacy.option import OptionWrapper


class CompositeActions(OptionWrapper):
    """Extend the action space with a halting flag and macro actions.

    The new action is a pair `(hlt, {"macro": ..., "micro": ...})`. A truthy
    `hlt` plays the `quit` key sequence. Otherwise, macro `0` forwards the
    original (`micro`) action as a one-step option, and macro `k > 0` runs
    the `k`-th user composite `actions[k - 1](obs, micro)`, which must be a
    generator of actions.

    The base `OptionWrapper` is created with `reduce=sum` and
    `allow_empty=True`.
    """

    def __init__(self, env, *actions, quit=(65, 7)):
        super().__init__(env, reduce=sum, allow_empty=True)

        # macro 0 is the pass-through, followed by the user's composites
        self.actions = (self.forward, *actions)

        halting = spaces.Discrete(2)
        heads = dict(
            # the macro action head
            macro=spaces.Discrete(len(self.actions)),
            # the original actions
            micro=self.action_space,
        )
        self.action_space = spaces.Tuple((halting, spaces.Dict(heads)))

        # the key sequence, which is played upon halting
        self.quit = quit

    def reset(self):
        self.obs = super().reset()
        return self.obs

    def step(self, action):
        halt, heads = action
        option = self.dispatch(halt, **heads)
        obs, rew, done, info = super().step(option)
        # keep the latest observation for the next composite
        self.obs = obs
        return obs, rew, done, info

    def dispatch(self, hlt, macro, micro):
        """Generate the actions of the option selected by `hlt` and `macro`."""
        if not hlt:
            yield from self.actions[macro](self.obs, micro)
            return

        # omg lul, win by quitting: execute '\xf1y' (shorthand for `#quit\015y`)
        yield self.quit

    def forward(self, obs, micro):
        """Play the original action `micro` as a one-step option."""
        yield (micro,)

A default recipe

In [None]:
# The policy network's hyper-parameters, grouped by the network's parts.
recipe = dict(
    features=dict(
        glyphs=dict(
            embedding_dim=64,
            # the vicinity patch is (2 + 1 + 2) x (2 + 1 + 2) (see `sizes`),
            #  presumably `window` cells on either side of the centre
            window=2,
        ),
        bls=dict(
            n_vitals=64,
            n_build=32,
        ),
        # widths of the feature layers: the first is the flattened input of
        #  64-dim embeddings of the 5 x 5 vicinity and of 55 inventory entries
        #  (presumably the inventory slots), plus the `n_vitals + n_build`
        #  bottom-line features; the last width feeds the core
        sizes=[
            # flattened vicinity, inventory and bls
            64 * ((2 + 1 + 2) * (2 + 1 + 2) + 55) + (64 + 32),
            2048,
            512,
        ],
    ),
    # the recurrent core: `input_size` matches the last feature width above
    core=dict(
        input_size=512,
        hidden_size=256,
        num_layers=1,
        bias=True,
        dropout=0.0,
        kind="lstm",
    ),
    # the output heads: `n_features` matches the core's `hidden_size`
    head=dict(
        n_features=256,
        heads=dict(
            macro=2,  # determined by the Composite action wrapper
            # the underlying discrete action space
            micro=len(ActionMasker._raw_nethack_actions),
        ),
    ),
)

In [None]:
# display the recipe
recipe

Let's test it

In [None]:
def scripted(ctoa, actions):
    """Make an option, which replays a fixed script of keypresses.

    Parameters
    ----------
    ctoa : dict
        Maps an ascii character to the index of its action.
    actions : iterable of str
        The script: each element is a group of characters, which is translated
        into a tuple of action indices and yielded as one step. A plain string
        is therefore a script of single-character steps.

    Returns
    -------
    generator : callable
        A generator function `generator(obs=None, act=None)`, matching the
        option signature `(obs, micro)`; both arguments are ignored.

    Raises
    ------
    KeyError
        If the script contains a character missing from `ctoa` (such keys
        used to silently become `None` actions).
    """
    # translate eagerly, so that a typo in the script fails here rather than
    #  producing a `None` action deep inside an env rollout; this also lets
    #  a one-shot iterable script be replayed by every call of the option
    steps = []
    for group in actions:
        missing = [c for c in group if c not in ctoa]
        if missing:
            raise KeyError(f"characters {missing!r} are not valid actions")
        steps.append(tuple(ctoa[c] for c in group))

    def generator(obs=None, act=None):
        # `act` (the micro action) is deliberately ignored by the script
        yield from steps

    return generator

In [None]:
def factory(seed=None):
    """Build the env with composite actions for testing the network.

    The stack, from the core out: `NLEObservationPatches`, `Chassis` (given
    the index of the space key), `ActionMasker`, and `CompositeActions` with
    one scripted macro `32s` and halting mapped to the meta-`q` key (`#quit`)
    followed by `n`.

    NOTE(review): `seed` is accepted but unused, replay recording is disabled,
    and answering `n` presumably declines the quit confirmation — confirm.
    """
    # formerly the raw env was wrapped in
    #  ReplayToFile(..., save_on='done,close', sticky=True, folder='./replays')
    patched = NLEObservationPatches(gym.make("NetHackChallenge-v0"))
    ascii_to_index = {
        chr(code): index for index, code in enumerate(patched.unwrapped.actions)
    }

    masked = ActionMasker(
        Chassis(patched, space=ascii_to_index[" "], split=False)
    )
    return CompositeActions(
        masked,
        scripted(ascii_to_index, "32s"),  # search!
        quit=(ascii_to_index["\xf1"], ascii_to_index["n"]),
    )

In [None]:
from nle_toolbox.bot.model.network import Network
from nle_toolbox.bot.model.utils import NeuralActorModule

# instantiate the policy network from `recipe` and wrap it into an actor
module = Network(recipe)
actor = NeuralActorModule(module)
# XXX deal with the pesky warning!

In [None]:
import tqdm

n_epochs = 100
# collect `n_epochs` rollout fragments of `n_steps` each from two envs; both
#  are opened as context managers, so that both get closed afterwards (the
#  second env used to be created inline and was never closed)
with factory() as env, factory() as env_extra:
    rndgen = same.rollout(
        [
            env,
            env_extra,
            # add more envs here, each through its own `with` target
        ],
        actor,
        n_steps=25,
    )

    # `zip` already draws one fragment per epoch from `rndgen`: an extra
    #  `next(rndgen)` in the body would silently skip every other fragment
    for ep, fragment in zip(tqdm.tqdm(range(n_epochs)), rndgen):
        pass

In [None]:
# embed the glyph observations of the last fragment and show the shapes of
#  the resulting feature tensors
obs, mask = fragment.state.obs
gl = module.features.glyphs(obs)

{k: v.shape for k, v in gl.items()}

In [None]:
from rlplay.utils.plotting.grid import make_grid

# plot the vicinity embeddings at index -1 along dim 1 as a grid of images
#  NOTE(review): assumes the leading dims flatten into a batch of 2d maps,
#  which `make_grid` tiles with a 4:3 aspect — confirm the layout
x = gl["vicinity"][:, -1].detach()
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=160)
ax.imshow(
    make_grid(
        x.flatten(0, 1),
        aspect=(4, 3),
        normalize=False,
    ).numpy(),
    cmap=plt.cm.bone,
)

## Playground

In [None]:
import torch
from torch.optim.optimizer import Optimizer, required


class EMA(Optimizer):
    r"""Exponential moving average parameter tracker.

    Details
    -------
    Update the tracked parameter value $\hat{\theta}$ with the current
    values $\theta$ according to

    $$
        \hat{\theta}
            \longleftarrow (1 - \eta) \hat{\theta} + \eta \theta
        \,, $$

    with $\eta \in [-1, +1]$. The tracked value of a parameter `p` is the
    tensor `self.state[p]`, initialized with a copy of `p` on its first step.
    Only parameters that have a `.grad` are tracked.
    """

    def __init__(self, params, lr=required):
        if lr is not required and abs(lr) > 1:
            raise ValueError(f"Invalid learning rate: {lr}")

        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single moving average step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss, as in the standard
            `torch.optim.Optimizer.step` protocol. It is called with autograd
            enabled, before the averaging step.

        Returns
        -------
        loss : object or None
            Whatever `closure` returned, or `None` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                # we update diff-able params only
                if p.grad is None:
                    continue

                if p not in self.state:
                    # lazy init buffers by simply copying current value
                    self.state[p] = p.clone()

                else:
                    # otherwise lerp the buffer with the current param
                    self.state[p].lerp_(p, lr)

        return loss

    def zero_grad(self, set_to_none: bool = False):
        # the tracker never reads gradient values, so clearing them through
        #  it is deliberately unsupported
        raise NotImplementedError

In [None]:
from torch.utils.data import TensorDataset


@torch.no_grad()
def gen(module, **sizes):
    """Generate labelled toy datasets, one per keyword argument.

    Each dataset holds `n` standard-normal 16-dim inputs, labelled by the
    argmax of `module`'s output on them (computed without autograd).

    Returns
    -------
    datasets : dict of TensorDataset
        Keyed like `sizes`, in the same order.
    """

    def make(n):
        inputs = torch.randn(n, 16)
        labels = module(inputs).argmax(-1)
        return TensorDataset(inputs, labels)

    return {name: make(count) for name, count in sizes.items()}

In [None]:
# dope!
# dope! one-letter shorthands for quickly assembling toy networks
L = nn.Linear
S = nn.Sequential
A = nn.ReLU


class R(S):
    """A residual block: adds the input to the output of its first child."""

    def forward(self, x):
        return self[0](x) + x

In [None]:
# a toy classification task: labels are the argmax of a random linear
#  reference model over its 2 outputs, evaluated on gaussian inputs
ref = nn.Sequential(
    nn.Linear(16, 2),
)

datasets = gen(ref, train=1024, test=256)

# shuffled mini-batches for training, a fixed order for evaluation
feeds = dict(
    train=torch.utils.data.DataLoader(datasets["train"], batch_size=32, shuffle=True),
    test=torch.utils.data.DataLoader(datasets["test"], batch_size=64, shuffle=False),
)

In [None]:
# the model: an input layer, `range(1)` (i.e. one) residual block computing
#  `x + ReLU(Linear(x))`, and a 2-class log-softmax output
module = S(
    L(16, 32),
    A(),
    *[
        R(
            S(
                L(32, 32),
                A(),
            )
        )
        for _ in range(1)
    ],
    L(32, 2),
    nn.LogSoftmax(dim=-1),
)

optim = torch.optim.Adam(module.parameters(), lr=1e-3)

In [None]:
import tqdm
from copy import deepcopy

n_epochs = 500
# per-epoch mean train loss, and the test NLL of the EMA and of the raw model
losses, ema_tracked, tracked = [], [], []

# init ema
ema = deepcopy(module)
# NOTE: `state_dict()` returns detached tensors sharing storage with `ema`'s
#  parameters (PyTorch's default), so the in-place `lerp_` below updates `ema`
state = ema.state_dict()

for ep in tqdm.tqdm(range(n_epochs)):
    ep_losses = []
    for bx, by in feeds["train"]:
        loss = F.nll_loss(module(bx), by, reduction="mean")
        optim.zero_grad()
        loss.backward()
        optim.step()

        # update ema: ema <- ema + 0.1 * (module - ema) after every optimizer
        #  step (`plyr.suply` presumably applies `lerp_` leaf-wise)
        with torch.no_grad():
            plyr.suply(
                torch.Tensor.lerp_,
                state,
                module.state_dict(),
                weight=0.1,  # 1 / float(ep + 1),
            )

        ep_losses.append(float(loss))

    losses.append(np.mean(ep_losses))

    # mean test NLL of the EMA weights
    with torch.no_grad():
        nll = torch.cat(
            [F.nll_loss(ema(tx), ty, reduction="none") for tx, ty in feeds["test"]],
            dim=0,
        )

    ema_tracked.append(float(nll.mean(-1)))

    # mean test NLL of the raw weights
    with torch.no_grad():
        nll = torch.cat(
            [F.nll_loss(module(tx), ty, reduction="none") for tx, ty in feeds["test"]],
            dim=0,
        )

    tracked.append(float(nll.mean(-1)))

In [None]:
# training loss per epoch, then the test NLL of the raw and of the EMA weights
plt.semilogy(losses)

plt.semilogy(tracked)
plt.semilogy(ema_tracked)

In [None]:
# index every leaf of the fragment with `[:, -1]` (presumably via `plyr.getitem`)
plyr.suply(plyr.getitem, fragment, index=(slice(None), -1))

In [None]:
from nle.nethack import NLE_BL_TIME

In [None]:
# the game time entry of the bottom-line stats across the fragment
fragment.state.obs[0]["blstats"][..., NLE_BL_TIME]

In [None]:
from nle_toolbox.utils.env.defs import glyphlut

In [None]:
def prepare_from_example(
    obs,
    act,
    rew,
    fin,
    actor,
    env,
    hx,
    /,
    n_steps,
    n_envs,
    *,
    pinned=False,
    shared=False,
    device=None,
):
    """Allocate rollout buffers shaped after a single example transition.

    The example (`obs`, `act`, `rew`, `fin`, the `actor`'s and `env`'s info
    and the recurrent state `hx`) comes from one time step in one environment
    and is replicated across time and environments by `torchify`.

    Returns
    -------
    fragment : Fragment
        Its `state` and `actor` buffers are `(1 + n_steps) x n_envs x ...`,
        `env` is `n_steps x n_envs x ...`, and `hx` is not replicated.

    NOTE(review): `device` is currently unused; `torchify`, `State` and
    `Fragment` must be in scope (presumably all from `rlplay`).
    """
    # we get an example data for one time-step and a single
    #  environment and replicate it across time and envs

    # ensure float32 and bool data types for `rew_` and `fin_`, respectively,
    # while leaving `obs_` and `act_` intact as they are nested containers of
    # numpy arrays or scalars with proper dtypes. (`np`, not `numpy`, is the
    # name this notebook imports numpy under.)
    rew_, fin_, stepno_ = np.float32(rew), bool(fin), np.int64(0)

    # the buffers for the actor's info and state are `(1 + n_steps) x n_envs x ...`
    #  while for the env we allocate an `n_steps x n_envs x ...` buffer
    state, actor = torchify(
        (State(stepno_, obs, act, rew_, fin_), actor),
        1 + n_steps,
        n_envs,
        shared=shared,
        pinned=pinned,
    )
    env = torchify(env, n_steps, n_envs, shared=shared, pinned=pinned)

    # the hidden recurrent state has to be replicated manually
    hx = torchify(hx, shared=shared, pinned=pinned)

    return Fragment(state=state, actor=actor, env=env, hx=hx)

In [None]:
# NOTE(review): `mod` is not defined anywhere in the visible notebook —
#  presumably a network instance from an earlier session; confirm
out, hx = mod(fragment.state.obs)  # , fin=fragment.state.fin, hx=None)

In [None]:
# NOTE(review): `postproc` is not defined anywhere in the visible notebook
result = postproc(fragment.state.obs, hx, **out)

In [None]:
# inspect the second element of the post-processed result
result[1]

<br>

### Basic automatic actions

In [None]:
from helper import render, get_logger

# presumably a logger writing into `debug.txt` (see the local `helper` module)
logger = get_logger("debug.txt")

Detect in-game events based on the response to the user

In [None]:
def detect(obs):
    """Detect and dispatch events.

    Currently only dismisses menus: sends a space for as long as `obs` shows
    a menu. `obs` is refreshed with every transition received back through
    `send`, and the loop also stops once the episode is done.
    """
    done = False
    while not done and is_menu(obs):
        # the transition comes back via `send`; update `obs` with it (the
        #  old code bound it to a typo'd `obx`, so `obs` never changed and a
        #  menu would be "dismissed" forever)
        obs, rew, done, info = yield " "

Ensure that the game does not prompt for anything.

In [None]:
def flush(obs):
    """Cancel whatever the game may be prompting for by pressing escape thrice."""
    # `obs` is unused: the escapes are sent unconditionally
    keys = "\033\033\033"
    yield keys

Search around and pass in-game time.

In [None]:
def search(obs=None, n=10):
    """Search around the player's current location for several turns.

    Parameters
    ----------
    obs : dict, optional
        The current observation.
    n : int, default=10
        The number of turns; a count prefix is sent only when `n > 1`.

    Any menus or messages the search brings up are handed to `detect` with
    the observation received back through `send` (not the stale initial
    `obs`), unless the episode has ended.
    """
    # issuing a search command is an atomic operation
    obs, rew, done, info = yield (f"{n:d}s" if n > 1 else "s")
    if done:
        return

    # which however may cause the game to bring up some state messages
    yield from detect(obs)

Take a random walk of single steps (and thereby pass in-game time).

In [None]:
from numpy.random import default_rng


def randomstep(obs=None, rng=default_rng(42)):
    """Take a random walk of 100 single steps, then handle any menus.

    Parameters
    ----------
    obs : dict, optional
        The current observation.
    rng : numpy.random.Generator
        The source of directions. NOTE: the default generator is created once
        at definition time and shared by all calls, so consecutive calls
        continue one random stream rather than repeating the same walk.

    The walk ends early, without calling `detect`, once the episode is done.
    """
    dirs = rng.integers(9, size=100)
    path = "".join("ykuh.lbjn"[j] for j in dirs)

    # assume a good path that does not bump into walls
    logger.info(f"randomstep:: `{path}`")

    # yield singleton actions
    for s in path:
        obs, rew, done, info = yield s
        if done:
            # the episode is over: further keys would step a finished game
            return

    yield from detect(obs)

An end-of-everything action.

In [None]:
def quit(obs):
    """End the game: flush any pending prompt, then quit and confirm."""
    # `#quit`, Enter (`\015`) and `y` to confirm
    command = "#quit\015y"
    # first get rid of whatever prompt or menu may be pending
    yield from flush(obs)
    yield command

The startup character analyzer

In [None]:
class Startup:
    """The startup option: record the game's initial screens.

    Every analyzed screen is appended to `self.data` as a pair of the decoded
    tty screen and top-line message, and both are logged.
    """

    def __init__(self):
        self.data = []

    def analyze(self, obs):
        """Decode, store and log the tty screen and the message of `obs`."""
        # one byte string per 80-column tty row (numpy drops trailing NULs)
        screen = b"\n".join(obs["tty_chars"].view("S80")[:, 0]).decode("ascii")
        message = obs["message"].view("S256")[0].decode("ascii")
        self.data.append((screen, message))
        logger.info(f"Startup:: screen\n{' tty ':=^80}\n{screen}\n{'':=<80}")
        logger.info(f"Startup:: message `{message}`")

    def __call__(self, obs):
        """Run the startup option from the initial observation `obs`."""
        # initial screen (hopefully)
        self.analyze(obs)
        # disable autopickup and request character overview
        obs, rew, done, info = yield "@\x18"
        # page through the overview, unless the episode has already ended
        while not done and is_menu(obs):
            self.analyze(obs)
            obs, rew, done, info = yield " "

In [None]:
class ActionInspector(gym.Wrapper):
    """Record every action passed through to the wrapped env.

    The history is cleared on `reset`. `repr` shows the recorded actions
    concatenated between backticks, so the actions are expected to be
    strings.
    """

    def __init__(self, env):
        super().__init__(env)
        # so that `repr` and `step` work even before the first `reset`
        self.history = []

    def reset(self):
        self.history = []
        return super().reset()

    def step(self, action):
        self.history.append(action)
        return super().step(action)

    def __repr__(self):
        # this wrapper's own history (it used to read a global `env` instead)
        return "`" + "".join(self.history) + "`"

    __str__ = __repr__

In [None]:
from nle_toolbox.bot.legacy.genfun import is_suspended
from nle_toolbox.bot.legacy.option import Continue, Preempt

# the schedulable options: name -> callable(obs) returning an option generator
options = dict(
    startup=Startup(),
    #     quit=quit,
    search=search,
    detect=detect,
    randomstep=randomstep,
)


def preempt(obs, current):
    """Tell if the `current` option must be abandoned (currently never)."""
    return False


def ok(obs, current):
    """Tell if the `current` option may be interjected (currently always)."""
    return True


def supervisor(obs, *, rng):
    """Schedule options, dismiss menus in between, and handle preemption.

    A generator of commands for an option wrapper: it yields a new option
    generator, `Preempt(...)` to interject a menu-handling option, or
    `Continue` to keep the current option running. It expects the transition
    `(obs, rew, done, info)` back through `send`.

    The first option is always `startup`. The following ones are drawn by
    `rng` from the remaining options: re-running `startup` mid-game would
    send `@` again and toggle autopickup back.
    """
    current, done = "startup", False
    # `startup` runs exactly once, at the beginning
    pool = [name for name in options if name != "startup"]
    while not done:
        # schedule an option
        option = options[current](obs)
        obs, rew, done, info = yield option

        # preemption loop: runs while the option is suspended mid-way
        while not done and is_suspended(option):
            # if it is a menu and it is ok to interject the current action
            if is_menu(obs) and ok(obs, current):
                # interject with a menu-gobbling option: unlike sending a new
                # option, interjecting suspends the current option and delegates
                # control to the alternative option until it terminates.
                obs, rew, done, info = yield Preempt(detect(obs))

            elif not preempt(obs, current):
                # there are no emergencies to deal with, so we can continue
                # with the current option
                obs, rew, done, info = yield Continue

            else:
                # an emergency: abandon the current option and reschedule,
                #  instead of spinning in this loop forever without yielding
                break

        # pick the next option, since none is currently running
        current = rng.choice(pool)

In [None]:
import time
from numpy.random import default_rng

# NOTE(review): the second import rebinds `OptionWrapper`, so the
#  interruptible variant is never used, although `supervisor` yields
#  `Preempt`/`Continue` commands — confirm which wrapper is intended
from nle_toolbox.bot.legacy.option import InterruptibleOptionWrapper as OptionWrapper
from nle_toolbox.bot.legacy.option import OptionWrapper

seed = (12513325507677477210, 18325590921330101247)
# `ActionInspector` records the actions for later inspection; the name used
#  here before, `Inspector`, is not defined anywhere in the visible notebook
with OptionWrapper(ActionInspector(factory()), reduce=sum, allow_empty=True) as env:
    seed = env.seed(seed)
    logger.info("python -m nle_toolbox.utils.play.one --seed" f" {seed[0]} {seed[1]}")

    rewards = []

    tick = time.monotonic_ns()

    # `tx` is the last transition, sent back into the supervisor
    tx = None
    obs, rew, done, info = env.reset(), 0.0, False, {}

    sup = supervisor(obs, rng=default_rng(14426138988763310091))
    while render(env, obs) and not done:
        tx = obs, rew, done, info = env.step(sup.send(tx))
        rewards.append(rew)

    tock = time.monotonic_ns()

# elapsed wall-clock time in seconds
print((tock - tick) / 1e9)

In [None]:
# the number of actions recorded by the inspector, i.e. the wrapper at `env.env`
len(env.env.history)

In [None]:
# render the last observation
render(env, obs)

In [None]:
# the total reward of the run
sum(rewards)

In [None]:
# NOTE(review): `map(sum, rewards)` treats every recorded reward as a sequence,
#  whereas `reduce=sum` in the wrapper suggests scalars — confirm the format
set(map(type, map(sum, rewards)))

In [None]:
# indices of the steps with an empty reward record (presumably empty options,
#  which the wrapper permits via `allow_empty=True`)
[j for j, l in enumerate(map(len, rewards)) if l == 0]

In [None]:
# the first ten recorded rewards
rewards[:10]

<br>

A proto option policy (neural)

In [None]:
def option(core, obs, lookahead=False):
    r"""An option `core` starting at `obs`.

    Details
    -------
    Sutton, R.S., Precup, D., Singh, S. (1999) suggest the following:

      At s_t the option (\pi, \beta), given the opportunity, executes
      the transition
       s_t, a_t -->> s_{t+1}, x_{t+1}, r_{t+1}
      with a_t = \pi(x_t) and then decides to halt w. prob.
      \beta_{t+1} = \beta(x_{t+1}).

    this corresponds to `lookahead = False`.

    The option `core` is a callable returning the action a_t, the halting
    decision (truthy to halt, e.g. sampled with probability \beta_t), and
    the next context `hx` for the observation x_t and current context `hx`.
    The core is designed in such a way as to keep all the relevant runtime
    data in an external context `hx`.

    The generator yields actions and expects the transition
    `(obs, rew, done, info)` back through `send`; the received observation
    is what the core sees on the next step. It also stops once `done`.

    They also mention interrupted options, which are a work in progress here.
    """

    # init the flags and the core's context `hx`
    done, halt, hx = False, False, None
    if lookahead:
        # get a_t, \beta_t = \pi(x_t) and execute a transition
        #  s_t, a_t -->> s_{t+1}, x_{t+1}, r_{t+1} w. prob. \beta_t
        # and halt otherwise.
        # XXX potentially empty options
        act, halt, hx = core(obs, hx=hx)  # a_0, \beta_0 = \pi(x_0)
        while not (done or halt):  # for t \geq 0
            obs, rew, done, info = yield act  # send a_t, recv x_{t+1}
            act, halt, hx = core(obs, hx=hx)  # a_{t+1}, \beta_{t+1} = \pi(x_{t+1})
        # XXX a_t is enacted w. prob. \beta_t

    else:
        # get a_t, \beta_t = \pi(x_t), execute a transition
        #    s_t, a_t -->> s_{t+1}, x_{t+1}, r_{t+1},
        # then halt w. prob. \beta_t, regardless of the new x_{t+1}.
        while not (done or halt):  # for t \geq 0
            act, halt, hx = core(obs, hx=hx)  # a_t, \beta_t = \pi(x_t)
            obs, rew, done, info = yield act  # send a_t, recv x_{t+1}
        # XXX a_t is enacted w. prob. \beta_{t-1}

    # nothing to comm through StopIteration
    return

A simple recurrent neural net for random walks, which also decides when to halt.

In [None]:
import torch
from rlplay.engine.utils.shared import torchify
from rlplay.utils.common import multinomial

from nle_toolbox.bot.model.blstats import BLStatsVitalsEmbedding


class SirvivalWalk(torch.nn.Module):
    """A recurrent random-walk policy with a learnable halting decision.

    The embedded bottom-line stats (`vitals` must produce 128 features for
    the GRU core) are mapped to logits over the nine keys in `_actions` and
    a halting logit. `run` plays the walk as an option generator and stores
    the experienced transitions in `self.buffer`.
    """

    _actions = "ykuh.lbjn"

    def __init__(self, vitals):
        super().__init__()
        self.vitals = vitals

        self.core = torch.nn.GRU(128, 64, 2)
        self.control = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(64, 9 + 1),
        )

        # our experience buffer
        self.buffer = []

    def forward(self, obs, hx=None):
        """Sample a move and a halting decision from `obs["blstats"]`.

        Returns the sampled move index, the halting decision (true with
        probability `sigmoid(l_halt)`), the new GRU state, and a dict with
        the move logits `l_act` and the halting logit `l_halt`.
        """
        out, hx = self.core(self.vitals(obs["blstats"]), hx=hx)

        logits = self.control(out)
        l_act, l_halt = torch.split(logits, (9, 1), dim=-1)

        # non-diffable stuff
        act = multinomial(l_act.softmax(-1))
        tau = torch.rand_like(l_halt).logit_()  # logistic r.v.
        return (
            act,
            l_halt.ge(tau),
            hx,
            dict(
                l_act=l_act,
                l_halt=l_halt,
            ),
        )

    @torch.no_grad()
    def run(self, obs, hx=None):
        """Walk until halting or the end of the episode (an option generator).

        Yields the keys of the sampled moves, expects `(obs, rew, done, info)`
        back through `send`, and appends `(blstats, move, next blstats)` to
        `self.buffer` after every step.
        """
        logger.info(f"SirvivalWalk:: start")
        done, halt, probs = False, False, []
        while not (done or bool(halt)):
            act, halt, hx, a_info = self(torchify(obs, 1, 1), hx=hx)

            act = int(act)
            obs_, rew, done, e_info = yield self._actions[act]

            p_halt = a_info["l_halt"].sigmoid()
            probs.append(float(p_halt))
            self.buffer.append((obs["blstats"], act, obs_["blstats"]))

            obs = obs_

        # nothing to comm through StopIteration
        logger.info(f"SirvivalWalk:: {probs}")
        return

    def train(self, mode=True):
        """Set the training mode, exactly like `torch.nn.Module.train`.

        This used to be an argument-less placeholder for a training step,
        which shadowed `nn.Module.train(mode)` and thereby broke `.eval()`
        (implemented as `self.train(False)`). A future update step over the
        experience buffer should get its own name.
        """
        return super().train(mode)

<br>

In [None]:
seeds = [
    #     (5009195464289726085, 12625175316870653325),
    #     (7002570039340100249, 14426138988763310091),
    #     (14278027783296323177, 11038440290352864458),
    #     (14046273391210721807, 3865099148830813988),
    #     (18386314338156462112, 4255575630009817530)
    None
]


vitals = BLStatsVitalsEmbedding(128)
surv = SirvivalWalk(vitals)

# NOTE(review): `Driver` and `schedule` are not defined in the visible notebook
for seed in seeds:
    with Driver(
        factory(),
        startup=Startup(),
        quit=quit,
        search=search,
        #         randomstep=surv.run
    ) as env:
        seed = env.seed(seed)
        logger.info(
            "python -m nle_toolbox.utils.play.one --seed" f" {seed[0]} {seed[1]}"
        )

        # no option is running at the start, so the first step must schedule
        #  one (`current` used to be read here before it was ever assigned)
        current = None
        obs, done, rewards = env.reset(), False, []
        while render(env, obs) and not done:
            if current is None or preempt(obs, current):
                action = current = schedule(obs)
            else:
                action = env.Continue

            obs, rew, done, info = env.step(action, n_slice=10)
            rewards.append(rew)
            time.sleep(0.1)

In [None]:
# the `(blstats, move, next blstats)` transitions collected by `surv.run`;
#  NOTE(review): stays empty unless `randomstep=surv.run` is enabled above
surv.buffer

In [None]:
import pdb

# post-mortem debugging of the last uncaught exception
pdb.pm()

In [None]:
class Network(nn.Module):
    """Glyph and bottom-line features, an optional GRU core, linear heads.

    NOTE(review): running this cell shadows the `Network` imported from
    `nle_toolbox.bot.model.network`. This version expects a flat recipe with
    the keys `glyphs`, `bls`, `features` (a list of layer widths), `heads` and
    an optional `core`, unlike the nested `recipe` defined earlier.

    Parameters
    ----------
    recipe : dict
        The hyper-parameters; deep-copied, so the caller's dict is untouched.
    """

    def __init__(self, recipe):
        recipe = deepcopy(recipe)
        super().__init__()

        # the embedders
        self.glyphs = GlyphFeatures(**recipe["glyphs"])
        self.bls = BLStatsEmbedding(**recipe["bls"])

        # build the feature network: flatten the embeddings (the custom
        #  `ModuleDict(..., dim=-1)` presumably concatenates them), followed
        #  by linear + relu layers with the widths listed in `features`
        n_features = recipe["features"][-1]
        layers = [
            ModuleDict(
                dict(
                    vicinity=nn.Flatten(-3, -1),
                    inventory=nn.Flatten(-2, -1),
                    bls=nn.Identity(),
                ),
                dim=-1,
            )
        ]
        for n, m in zip(recipe["features"], recipe["features"][1:]):
            layers.append(nn.Linear(n, m, bias=True))
            layers.append(nn.ReLU())

        self.features = nn.Sequential(*layers)

        # the core is either absent or a GRU (`nn.Identity` ignores its args)
        core = nn.GRU if "core" in recipe else nn.Identity
        self.core = core(n_features, n_features, 1, bias=True, batch_first=False)

        # construct h0 for the core: a single tensor, a list of tensors, or none
        #  (`torch.Size` is a tuple, hence it is checked first)
        shape, h0 = hx_shape(self.core), None
        if isinstance(shape, torch.Size):
            h0 = torch.nn.Parameter(torch.zeros(*shape))

        elif isinstance(shape, tuple):
            h0 = torch.nn.ParameterList(
                [torch.nn.Parameter(torch.zeros(*s)) for s in shape]
            )

        # `register_parameter` accepts only a `Parameter` or `None` and raises
        #  a TypeError for a `ParameterList`, which is a module
        if isinstance(h0, nn.Module):
            self.add_module("h0", h0)
        else:
            self.register_parameter("h0", h0)

        # halting probability logit, action prelogits, and the critic's value
        self.head = LinearSplitter(n_features, dict(val=1, hlt=1, **recipe["heads"]))

    def forward(self, obs, fin=None, hx=None):
        """Run the network over a `T x B` batch of observation sequences.

        Parameters
        ----------
        obs : dict
            The observations, fed to the glyph and bottom-line embedders.
        fin : tensor, optional
            A `T x B` mask: where it is set, the core's hidden state is reset
            to `h0` before that step. Defaults to never resetting.
        hx : optional
            The core's initial hidden state; defaults to `h0` broadcast over
            the batch.

        Returns
        -------
        The head's outputs and the core's final hidden state.
        """
        # embed and compute features
        features = self.features(
            {
                **self.glyphs(obs),
                "bls": self.bls(obs),
            }
        )

        # prepare the hiddens: we keep an explicit `h0`, in the case
        # it is diff-able and learnable
        n_seq, n_batch = features.shape[:2]
        h0 = hx_broadcast(self.h0, n_batch)
        hx = h0 if hx is None else hx

        # run the core
        out = features
        if not isinstance(self.core, nn.Identity):
            outputs = []
            # in case fin is missing, create a never-resetting mask
            if fin is None:
                fin = torch.zeros(n_seq, n_batch, dtype=bool, device=features.device)

            # make sure the termination mask is on-device and numeric
            fin = fin.unsqueeze(-1).to(features)
            # step through time one length-1 sequence at a time
            for x, f in zip(features.unsqueeze(1), fin.unsqueeze(1)):
                # reset the hiddens by lerping with `fin`: hx = hx * (1 - f) + h0 * f
                out, hx = self.core(x, hx=plyr.suply(torch.lerp, hx, h0, weight=f))
                outputs.append(out)

            out = torch.cat(outputs, dim=0)

        return self.head(out), hx

<br>