In [None]:
import numpy
import torch

import matplotlib.pyplot as plt

## Tinkering with envs, rendering, and UI control

In [None]:
import gym
from gym import ObservationWrapper

from numpy.random import default_rng

A human-mode renderer that measures the wall-clock time it took to draw.

In [None]:
import time

def render(env):
    """Render `env` in human mode and return the time it took, in ms.

    Returns zero if the renderer is closed, i.e. `env.render` reports a
    falsy value. For a live window the result is always strictly positive,
    since callers use `render(env) > 0` to test whether the window is open.
    """
    # `perf_counter` has the highest available resolution, unlike
    # `monotonic`, which can tick coarsely on some platforms
    mark = time.perf_counter()
    if not env.render(mode='human'):
        return 0.

    elapsed_ms = (time.perf_counter() - mark) * 1e3

    # a very fast draw must not be mistaken for a closed window
    return max(elapsed_ms, 1e-6)

Simple keyboard control for pyglet windows.

In [None]:
from pyglet.window import key, Window

class SimpleUIControl:
    """A bare-bones keyboard event handler for pyglet UI.

    State is exposed through plain attributes that the play loop reads:

    * `action` -- the mapped key symbol that was last pressed, or None; the
      consumer is expected to set it back to None once it has acted on it.
    * `pause` -- toggled by SPACE.
    * `dream` -- toggled by P.
    * `waiting` -- True while a mapped key is held down; further mapped key
      presses are ignored until that very key is released.
    """

    action, pause, waiting, dream = None, False, False, False
    # the mapped key that armed `action`/`waiting` (None when no key is held)
    held = None

    def __init__(self, keymap):
        # only membership of key symbols in `keymap` is checked here
        self.KEYMAP = keymap

    def on_key_press(self, symbol, modifiers):
        if symbol == key.SPACE:
            self.pause = not self.pause
            return

        if symbol == key.P:
            self.dream = not self.dream
            return

        if symbol in self.KEYMAP and not self.waiting:
            self.action, self.waiting, self.held = symbol, True, symbol
            return

    def on_key_release(self, symbol, modifiers):
        # react only to the release of the key that armed the action: a
        # different mapped key, pressed (and ignored) while this one was
        # held, must not cancel the pending action on its release
        if self.waiting and symbol == self.held:
            self.action, self.waiting, self.held = None, False, None
            return

    def register(self, window):
        """Attach the key handlers to the pyglet `window`; returns `self`."""
        assert isinstance(window, Window)

        window.push_handlers(
            self.on_key_press,
            self.on_key_release,
        )

        return self

### Random Disco Maze

In [None]:
from gym_discomaze.ext import ExploreRandomDiscoMaze

class DiscoMazeFactory:
    """Callable that builds a small random disco-maze environment.

    Every env it makes gets a `KEYMAP` attribute: a map from pyglet key
    symbols (and `None`) to the names of the maze's actions, as looked up
    in `env.named_actions` by the play loop.
    """

    KEYMAP = {
        None: 'stay',
        key.E: 'stay',
        key.A: 'west',
        key.S: 'south',
        key.D: 'east',
        key.W: 'north',
    }

    def __call__(self, seed=None):
        """Make a maze with `field=(2, 2)` and `n_colors=5`; `seed` is
        forwarded as the maze's `generator`."""
        maze = ExploreRandomDiscoMaze(
            field=(2, 2),
            n_colors=5,
            generator=seed,
        )
        maze.KEYMAP = self.KEYMAP
        return maze


factory = DiscoMazeFactory()

### Atari

### Gym-Minigrid

<br>

## Random Exploration

In [None]:
import tqdm

Collect data for the VAE by random exploration.

In [None]:
from rlplay.engine.rollout import same, multi, single
from rlplay.engine.core import BaseActorModule

class RandomActor(BaseActorModule):
    """An actor that ignores its inputs and draws uniformly random actions.

    Requires a discrete action space exposing its size as `.n`.
    """

    def __init__(self, env):
        super().__init__()
        # only the action space is kept: its `.n` bounds the random draws
        self.action_space = env.action_space

    def step(
        self,
        stepno,
        obs,
        act,
        rew,
        fin,
        *,
        hx=None,
        virtual=False,
    ):
        """Return fresh random actions shaped like `act`, plus `()` and `{}`.

        When `virtual` is set, `act` is returned unchanged instead. The
        trailing `()` and `{}` are presumably the (empty) recurrent state
        and info dict expected by rlplay -- confirm against
        `BaseActorModule`.
        """
        if virtual:
            return act, (), {}

        n_actions = self.action_space.n
        return torch.randint_like(act, n_actions), (), {}

In [None]:
env = factory()  # a template env: its action space is used by `RandomActor` and for `n_actions`

**Caveat: possible hang-ups**

A code block in `shared.py#L119-122` (the `.empty` call, also when it is
replaced with `.zeros`, or the `.copy_` call) may freeze while waiting
on a mutex inside torch's `parallel_for`. This can be avoided by calling
`torch.set_num_threads(1)` in the main process.

```python
torch.set_num_threads(1)  # in case of unexpected hangups
```
* Perhaps this behavior is caused by the `fork` start method used by the rollout below.

In [None]:
# an alternative rollout via `same` (kept for reference):
# rndgen = same.rollout(
#     [factory(_) for _ in range(12)],
#     RandomActor(env),
#     n_steps=1,
# )

# the number of discrete actions (used to size the dynamics model below)
n_actions = env.action_space.n

# random-exploration rollout: `factory` is passed as the environment
# builder, and worker processes are started with `fork` (see the caveat
# about hang-ups above)
rndgen = multi.rollout(
    factory,
    RandomActor(env),
    n_steps=8,      # n_steps=8,
    n_per_actor=5,  # minigrid: 2, discomaze: 5
    n_buffers=12,   # n_buffers=12,
    n_actors=4,     # n_actors=3,
    n_per_batch=2,  # minigrid: 1, discomaze: 2
    start_method='fork',
    entropy=None,
)

# T, B = 8, 6
# NOTE(review): the (T, B) above may be stale for the current settings --
# confirm against `frag.state.obs.shape`

In [None]:
def to_tensor(obs):
    """Convert channel-last frames into channel-first float images.

    Moves the trailing (channel) axis in front of the two spatial axes and
    divides by 255; any number of leading dims is supported. For 5-d input,
    presumably `(T, B, H, W, C)` rollout observations, this is the same as
    `.permute(0, 1, 4, 2, 3).div(255)`.
    """
    # `.div` makes a copy for sure, so the result never aliases `obs`
    # (e.g. the rollout fragment's buffers)
    return obs.movedim(-1, -3).div(255)

In [None]:

# a fixed reference set of observations from 5 random-exploration
# fragments; it is shown below and probed for its log-likelihood during
# training (it never enters the replay buffer)
ref = []
for frag, j in zip(rndgen, tqdm.tqdm(range(5))):
    # `to_tensor` copies, so `ref` does not alias the fragment's buffers
    obs = to_tensor(frag.state.obs)
    # merge the time and batch dims into a single sample dim
    ref.append(obs.flatten(0, 1))

# NOTE(review): `zip` pulls one more fragment from `rndgen` before it notices
# that the progress range is exhausted; that fragment is discarded
ref = torch.cat(ref)

In [None]:
# per-pixel standard deviation over the reference set: bright pixels are
# the ones that vary the most across observations
pixel_std = ref.std(0).permute(1, 2, 0)
plt.imshow(pixel_std)
plt.axis('off');

<br>

## Models

In [None]:
import torch.nn.functional as F

from torch.nn import Conv2d, Embedding, Linear, Sequential, LeakyReLU, Flatten, GRU
from torch.nn import ConvTranspose2d


In [None]:
from torch.nn import init

def init_random_eye_embed(layer):
    """Initialize an `Embedding` as a noisy identity, in place.

    All weights are first drawn from N(0, 0.1^2). The leading `n_embed`
    columns are then overwritten with the identity, so that embedding `i`
    starts out as the i-th unit vector followed by Gaussian noise in the
    remaining columns.
    """
    assert isinstance(layer, torch.nn.Embedding)  # guard
    weight = layer.weight
    n_embed, _ = weight.shape

    with torch.no_grad():
        weight.normal_(std=1e-1)

    init.eye_(weight[:, :n_embed])

Modules for variational Bayes (`vbayes`).

In [None]:
from torch.nn import ReLU
from rlplay.zoo.models.vae import Reshape

from rlplay.zoo.models.vae import AsIndependentGaussian
from rlplay.zoo.models.vae import AsIndependentBernoulli
from rlplay.zoo.models.vae import AsIndependentContinuousBernoulli

How to force the variance of the Gaussian decoder to one?
```python
dec = AsIndependentGaussian(
    ...,
    fn_loc=torch.sigmoid,
    fn_scale = torch.ones_like,
)
```

### DiscoMaze vae

In [None]:
# size of the latent code
n_latent = 128

# encoder: image -> diagonal Gaussian over the latent code
# NOTE(review): `Linear(128, ...)` right after `Flatten` assumes the two
# unpadded 3x3 convs reduce the frame to 1x1, i.e. 5x5 input frames --
# confirm against the maze's observation size
enc = AsIndependentGaussian(
    Sequential(
        Conv2d(3, 64, 3),
        ReLU(),
        Conv2d(64, 128, 3),
        ReLU(),
        Flatten(-3, -1),
        # presumably split into the location and scale parameters
        Linear(128, 2 * n_latent)
    ),
    n_dim_in=3,   # presumably the number of trailing event dims, (C, H, W)
    n_dim_out=1,
)

# decoder: latent code -> continuous Bernoulli over 3-channel pixels;
# the two 3x3 transposed convs grow a 1x1 map into a 5x5 frame
dec = AsIndependentContinuousBernoulli(
    Sequential(
        Linear(n_latent, 128),
        ReLU(),
        Reshape((128, 1, 1), start_dim=-1, end_dim=-1),
        ConvTranspose2d(128, 64, 3),
        ReLU(),
        ConvTranspose2d(64, 3, 3),
    ),
    n_dim_in=1,
    n_dim_out=3,
    # validate_args=False,
)

### Atari VAE

### minigrid VAE

In [None]:
class DynGRU(torch.nn.Module):
    """One-step latent dynamics: a GRU cell driven by embedded actions.

    The latent `z` is tiled twice to form the `2 * n_hid`-wide GRU state,
    the action is embedded and fed as the cell input, and the new state is
    passed through `tanh` and a linear layer, yielding `2 * n_hid` outputs
    (presumably the location and scale parameters consumed by
    `AsIndependentGaussian` -- confirm).

    `n_blocks` is accepted but not used.
    """

    def __init__(self, n_blocks=None, n_hid=32, n_act=3):
        super().__init__()
        n_embed, n_state = 4 * n_act, 2 * n_hid

        # actions -> noisy-identity codes, renormalized to at most unit norm
        self.embed = torch.nn.Embedding(n_act, n_embed, max_norm=1.)
        init_random_eye_embed(self.embed)

        self.transform = torch.nn.GRUCell(n_embed, n_state)
        self.output = Sequential(torch.nn.Tanh(), Linear(n_state, n_state))

    def forward(self, z, act):
        # the GRU state is twice as wide as the latent code
        state = torch.cat((z, z), dim=-1)
        return self.output(self.transform(self.embed(act), state))

In [None]:
# latent dynamics: a diagonal Gaussian over the next latent code, whose
# parameters `DynGRU` computes from the current code and the action
dyn = AsIndependentGaussian(
    DynGRU(
        n_blocks=1,       # not used by `DynGRU`
        n_hid=n_latent,   # must equal the latent size: `z` seeds the GRU state
        n_act=n_actions,
    ),
    n_dim_in=1,
    n_dim_out=1,
)

<br>

### Train

In [None]:
from plyr import apply, suply, xgetitem

def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects.

    Every leaf of the nested `state` is sliced along its leading (time)
    axis into `curr = leaf[:T-shift]` and `succ = leaf[shift:]`, so that
    `curr[t]` and `succ[t]` are `shift` steps apart. `shift=0` returns two
    full-length slices.

    Raises
    ------
    ValueError
        If `shift` is negative.
    """
    if shift < 0:
        raise ValueError(f"`shift` must be non-negative, got {shift}.")

    # `[:-0]` would be the empty slice `[:0]`, hence `-shift or None`
    stop = -shift or None

    # use `xgetitem` to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    curr = suply(xgetitem, state, index=slice(None, stop))

    # XXX `succ[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    succ = suply(xgetitem, state, index=slice(shift, None))

    return curr, succ

In [None]:
def get_data(fragment):
    """Extract flattened `(x_t, a_t, x_{t+1})` transitions from a fragment.

    The fragment's state is split into current and next-step slices; the
    observations are turned into float images and all three tensors have
    their leading two (time and batch) dims merged into one.

    NOTE(review): transitions that straddle an episode end are not masked
    out here -- confirm whether `x_{t+1}` is then a reset observation.
    """
    curr, succ = timeshift(fragment.state)
    triple = (
        to_tensor(curr.obs),
        # force a copy: `__getitem__` returns a view into the fragment
        succ.act.clone(),
        to_tensor(succ.obs),
    )
    return suply(torch.flatten, triple, start_dim=0, end_dim=1)

Prefetch: fill the experience replay buffer before training.

In [None]:
import random

def buffered_shuffle(buffer, *obs, n_capacity=None):
    """Push items into a bounded shuffle buffer, evicting random old ones.

    The positional `obs` sequences are zipped into per-item tuples. While
    `buffer` holds fewer than `n_capacity` items, each tuple is appended;
    after that, each incoming tuple replaces a uniformly random stored one,
    which is evicted.

    Parameters
    ----------
    buffer : list
        The storage, modified in place.
    *obs : sequences
        Parallel sequences; `zip` stops at the shortest one.
    n_capacity : int, optional
        The buffer capacity. A falsy value defaults to the current size of
        `buffer`. If the capacity is zero and the buffer is empty, every
        incoming tuple is passed straight through to the output.

    Returns
    -------
    evicted : list of tuple
        The tuples pushed out of the buffer, in eviction order.
    """
    n_capacity = n_capacity or len(buffer)

    evicted = []
    for item in zip(*obs):
        if len(buffer) < n_capacity:
            buffer.append(item)

        elif buffer:
            # sample over the whole buffer, so that items stored beyond
            # `n_capacity` (an oversized buffer) can be evicted too
            ix = random.randrange(len(buffer))
            evicted.append(buffer[ix])
            buffer[ix] = item

        else:
            # nothing stored to swap with (zero capacity): pass through
            evicted.append(item)

    return evicted


# experience replay buffer: keep pulling fragments and pushing their
# transitions until the first eviction, i.e. until `buffer` holds
# `n_capacity` transitions
buffer, n_capacity = [], 512
for frag in rndgen:
    # flatten the fragment's temporal and batch dims
    st, act, sn = get_data(frag)

    batch = buffered_shuffle(
        buffer,
        st, act, sn,
        n_capacity=n_capacity,
    )

    # abort on full buffer (the evicted `batch` itself is discarded)
    if batch:
        break

Train.

In [None]:
import torch.distributions as dist

from rlplay.zoo.models.vae import vbayes

In [None]:
# training schedule, in iterations:
# * `jn` -- the total number of updates
# * `j0`, `j1` -- the KL weight `beta` ramps linearly from 0 to 1 between them
# * `e0`, `e1`, `lo` -- the learning rate stays at its base value up to `e0`,
#   then decays linearly towards zero at `e1`, never dropping below `lo`
#   times the base value
jn, j0, j1 = 10000, 500, 3000
e0, e1, lo = 1500, 9500, 1e-2

# jn, j0, j1 = 1000, 50, 300
# e0, e1, lo = 150, 950, 1e-2
# jn, j0, j1 = 5000, 250, 1500
# e0, e1, lo = 750, 4750, 1e-2

# a single optimizer over the encoder, dynamics and decoder
optim = torch.optim.Adam([
    *enc.parameters(),
    *dyn.parameters(),
    *dec.parameters(),
], lr=1e-3, weight_decay=1e-5)

# sched = None

# multiplicative lr factor: 1 until `e0`, linear decay after, floored at `lo`
sched = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda e: min(1., max(lo, (e1 - e) / (e1 - e0))),
)

In [None]:
from math import isfinite, isnan
from torch.utils.data._utils.collate import default_collate

from torch.distributions import kl_divergence as KL

# which neg-ELBO to optimize: Embed2Control's, with the latent dynamics
# (True), or the classic single-frame VAE one (False)
use_e2c = True

losses = []
# iterate over the progress range first, so that `zip` stops without
# pulling (and discarding) one extra fragment from `rndgen` at the end
for j, frag in zip(tqdm.tqdm(range(jn)), rndgen):
    # push the fresh transitions into the replay buffer and train on the
    # ones it evicts
    st, act, sn = default_collate(
        buffered_shuffle(buffer, *get_data(frag))
    )

    # learn vae
    # XXX it is possible that torch optimizes away mul-by-zero autograd paths
    # KL annealing: `beta` ramps linearly from 0 at `j0` up to 1 at `j1`
    beta = min(1., max(0., (j - j0) / (j1 - j0)))
    C = 0.25  # the weight of the dynamics KL term

    if not use_e2c:
        # classic VAE neg-elbo via one-sample SGVB
        pi = enc.prior(st)  # \pi = p(z)
        qt = enc(st)        # q_t = q(z_t \mid x_t)
        zt = qt.rsample()   #   draw > z_t \sim q_t
        pt = dec(zt)        # p_t = p(x_t \mid z_t)

        llt = pt.log_prob(st).mean()
        klt = KL(qt, pi).mean()
        llh = 0.
        klh = 0.
#         loss = vbayes(
#             enc,
#             dec,
#             st,
#             beta=beta,
#             n_draws=1,
#             iwae=False,
#         )

    else:
        # neg-ELBO from Embed2Control
        # XXX don't forget that it is __R__sample, you eed-yot!
        # XXX see eq. (8) of Watter at al. (2015) for `qh`
        pi = enc.prior(st)  # \pi = p(z)
        qt = enc(st)        # q_t = q(z_t \mid x_t)
        zt = qt.rsample()   # draw -->> z_t \sim q_t
        pt = dec(zt)        # p_t = p(x_t \mid z_t)
        # XXX zt.detach() below?
        qh = dyn(zt, act)   # \hat{q}_{t+1} = r(z_{t+1} \mid z_t, a_t)
        zh = qh.rsample()   # draw -->> \hat{z}_{t+1} \sim \hat{q}_{t+1}
        ph = dec(zh)        # \hat{p}_{t+1} = p(x_{t+1} \mid \hat{z}_{t+1})
        qn = enc(sn)        # q_{t+1} = q(z_{t+1} \mid x_{t+1})

        # \log p_t(x_t) + \log \hat{p}_{t+1}(x_{t+1})
        llt = pt.log_prob(st).mean()
        llh = ph.log_prob(sn).mean()

        # KL(q_t \| \pi) + \lambda KL(\hat{q}_{t+1} \| q_{t+1})
        klt = KL(qt, pi).mean()
        klh = KL(qh, qn).mean()

    loss = beta * (klt + C * klh) - llh - llt

    # bail out BEFORE the update: backpropagating a NaN/inf loss would
    # corrupt the parameters of the encoder, dynamics and decoder
    if not isfinite(float(loss)):
        raise FloatingPointError(f'non-finite loss at iteration {j}')

    optim.zero_grad()
    loss.backward()
    optim.step()

    # track the mean log-likelihood of the reference set, reconstructed
    # through a posterior draw
    with torch.no_grad():
        loglik = dec(enc(ref).sample()).log_prob(ref).mean()

    losses.append((
        -float(loss),
        float(loglik),
        float(llt),
        float(llh),
        float(klt),
        float(klh),
    ))

    if sched is not None:
        sched.step()

In [None]:
# unpack the per-iteration training log into separate series
train, test, llt, llh, klt, klh = zip(*losses)

fig, ax = plt.subplots()
ax.plot(llt, label=r'$\log p(x_t \mid z_t)$ (batch)')
ax.plot(llh, label=r'$\log p(x_{t+1} \mid \hat{z}_{t+1})$ (batch)')
ax.plot(test, label=r'$\log p(x \mid z)$ (reference set)')
ax.set(xlabel='iteration', ylabel='mean log-likelihood')
ax.legend();

Reload an older model.

Display the reconstruction errors on the reference set.

In [None]:
from rlplay.utils.plotting.grid import make_grid as _make_grid

def make_grid(tensor):
    """Arrange a batch of images into one grid image via rlplay's `make_grid`.

    The layout options are fixed: `aspect` and `pixel` of (1, 1), zero
    padding, and no value normalization.
    """
    layout = dict(aspect=(1, 1), pixel=(1, 1), pad=(0, 0))
    return _make_grid(tensor, normalize=False, **layout)

In [None]:
# reconstruct the reference set: decoder mean at a posterior draw
with torch.no_grad():
    rec = dec(enc(ref).sample()).mean

# per-pixel squared error, averaged over the channel dim
err = F.mse_loss(rec, ref, reduction='none').mean(1)

# show the per-sample error maps as a single grid image
fig, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=240)
ax.imshow(make_grid(err))

In [None]:
# compare one reference frame with its reconstruction, side by side
ix = 3  # index of the reference sample to show
fig, ax = plt.subplots(1, 2, figsize=(3, 5), dpi=240, sharey=True, sharex=True)
for pane, img, title in zip(ax, (ref[ix], rec[ix]), ('reference', 'reconstruction')):
    pane.imshow(img.permute(1, 2, 0))
    pane.set_title(title)

<br>

A GUI play loop with an extra renderer for the reconstructions.

In [None]:
@torch.no_grad()
def get_iv0(obs, sample=True):
    """Encode a single channel-last frame into the initial latent vector.

    The frame is moved to channel-first layout, divided by 255 and given a
    leading batch dim of one before going through `enc`. Returns a draw
    from the encoder's distribution if `sample` is set, its mean otherwise.
    """
    frame = torch.tensor(obs).permute(2, 0, 1).div(255)
    posterior = enc(frame[None])
    if not sample:
        return posterior.mean
    return posterior.sample()

@torch.no_grad()
def next_iv(zed, act, sample=True):
    """Advance a latent vector by one step of the learned dynamics `dyn`.

    Returns a draw from the predicted next-latent distribution if `sample`
    is set, its mean otherwise.
    """
    predicted = dyn(zed, act)
    if sample:
        return predicted.sample()
    return predicted.mean

@torch.no_grad()
def _decode_image(zed):
    """Decode the first latent of the batch `zed` into a uint8 image array."""
    mean = dec(zed).mean[0]
    # clamp first: casting floats outside [0, 255] to uint8 is not well
    # defined, and a decoder mean need not lie in [0, 1] (e.g. Gaussian)
    pixels = mean.clamp(0., 1.).permute(1, 2, 0).mul(255).byte().squeeze(-1)
    return pixels.numpy()


def play(env, ctrl, viewer=None, sample=True):
    """Play one episode with the keyboard, showing decoded latent states.

    On every frame, the latent of the last real observation (`z0`) and the
    latent predicted by the dynamics (`zt`) are decoded. A `MultiViewer`
    shows them in its 'curr' and 'next' panes; any other viewer shows only
    the former. A key press is mapped to an action through `env.KEYMAP`.
    Unless `ctrl.dream` is on, the env is stepped and both latents are
    re-encoded from the new observation. In either case `zt` is then
    advanced one step with that action.

    Returns
    -------
    keep_going : bool
        False if the render window was closed, True once the episode has
        terminated and the user has unpaused.
    """
    obs, fin, act = env.reset(), False, env.action_space.sample()

    z0 = get_iv0(obs, sample=sample)
    zt = next_iv(z0, torch.tensor([act]), sample=sample)
    while not fin:
        if viewer is not None:
            # `MultiViewer` is imported in the cell that runs the loop
            if isinstance(viewer, MultiViewer):
                viewer['curr'].imshow(_decode_image(z0))
                viewer['next'].imshow(_decode_image(zt))
            else:
                viewer.imshow(_decode_image(z0))

        if ctrl.action is not None:  # makes atari turn-based!
            act = env.named_actions[env.KEYMAP[ctrl.action]]
            ctrl.action = None  # consume the key press

            if not ctrl.dream:
                obs, rew, fin, info = env.step(act)
                # reset the latent dynamics
                zt = z0 = get_iv0(obs, sample=sample)

            # advance the latent dynamics state
            zt = next_iv(zt, torch.tensor([act]), sample=sample)

        if fin:  # pause on termination
            ctrl.pause = True

        # rendering and UI event loop: keep redrawing while paused, move on
        # to the next frame as soon as unpaused
        while render(env) > 0:
            time.sleep(0.04)
            if not ctrl.pause:
                break

        else:
            # `render` returned 0: the window has been closed
            return False

    return True

The loop

In [None]:
# for gym_discomaze
env = factory(123)

ctrl = SimpleUIControl(env.KEYMAP)

render(env)  # sets up the viewer gui, so that the next line works
ctrl.register(env.unwrapped.viewer.window)

from rlplay.utils.plotting import ImageViewer, MultiViewer

with MultiViewer(scale=(5, 5)) as mvw:
    vw1 = mvw.get('curr', r'\hat{s}_t')
    vw2 = mvw.get('next', r'\hat{s}_{t+1}')
    while play(env, ctrl, mvw):
        pass

In [None]:
# keep redrawing the env window until it is closed (`render` turns falsy)
while env.render('human'):
    pass

In [None]:
assert False  # deliberate stop for "Run All": what follows are scratch notes

```python
start = torch.arange(1., 5.)
end = torch.empty(4).fill_(10)
torch.lerp(start, end, torch.full_like(start, 0.5))
```

$$
\mathbb{H}
    \mathcal{N}_d(\mu, \Sigma)
        = \frac12 \log \det 2 \pi e \Sigma
    \,. $$

$$
\operatorname{KL}\bigl(
    \mathcal{N}_d(\mu_0, \Sigma_0)
    \| \mathcal{N}_d(\mu_1, \Sigma_1)
\bigr)
    = \frac12 \biggl\{
        \operatorname{tr}\Bigl(
            \Sigma_1^{-1} \Sigma_0
            + \Sigma_1^{-1} ( \mu_1 - \mu_0) ( \mu_1 - \mu_0)^\top
        \Bigr)
        + \log \frac{ \det 2 \pi \Sigma_1 }{ \det 2 \pi e \Sigma_0 }
    \biggr\}
    \,. $$

$$
f(x)
    = \frac1{\sqrt{2\pi \sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}
    \,. $$

If $X \sim p_X$ and $Y = g(X)$ for a strictly increasing differentiable $g$, then
$$
p_Y(y)
    = \frac{d}{dy} \mathbb{P}(Y\leq y)
    = \frac{d}{dy} \mathbb{P}\bigl(X \leq g^{-1}(y)\bigr)
    = p_X(g^{-1}(y)) \frac{d}{dy} g^{-1}(y)
    = \frac1{g'(x)} p_X(x) \big\vert_{x=g^{-1}(y)}
    \,, $$
thus
$$
\log p_Y(y)
    = \log p_X(g^{-1}(y)) - \log{g'(g^{-1}(y))}
    \,. $$

If $X \sim p_X$ and $Y = g(X)$ for a diffeomorphism $g$, then
$$
p_Y(y) dy
    = p_X(g^{-1}(y)) \bigl\lvert \det J_{g^{-1}}(y) \bigr\rvert dy
    % = p_X(g^{-1}(y)) \bigl\lvert \det (J_g(g^{-1}(y)))^{-1} \bigr\rvert dy
    = p_X(g^{-1}(y)) \bigl\lvert \det J_g(g^{-1}(y)) \bigr\rvert^{-1} dy
    \,, $$
thus
$$
\log p_Y(y)
    = \log p_X(g^{-1}(y)) - \log \bigl\lvert \det J_g(g^{-1}(y)) \bigr\rvert
    \,. $$

Hence
$$
\int_A p_X(x) dx
    = \mathbb{P}(X \in A)
    = \mathbb{P}(Y \in g(A))
    = \int_{g(A)} p_Y(y) dy
    = \int_A p_Y(g(x)) \bigl\lvert \det J_g(x) \bigr\rvert dx
    \,. $$

Since this holds for every measurable $A$, almost everywhere
$$
p_X(x)
    = p_Y(g(x)) \bigl\lvert \det J_g(x) \bigr\rvert
    \,. $$

if $x_k = f_k(x_{k-1})$ and $f_k$ is a diffeomorphism on the support, then
(for $g=f_k^{-1}$)
$$
p_k(x_k)
    = p_{k-1}(x_{k-1}) \,
    \biggl\lvert
        \det \frac{\partial f_k(x_{k-1})}{\partial x_{k-1}}
    \biggr\rvert^{-1}
    \bigg\vert_{x_{k-1} = f_k^{-1}(x_k)}
    \,. $$
[more details on NF](http://akosiorek.github.io/ml/2018/04/03/norm_flows.html)

See also this [good VAE recap](https://www.borealisai.com/en/blog/tutorial-5-variational-auto-encoders/).

if $x_k = f_k(x_{k-1})$ and $f_k$ is a diffeomorphism on the support, then
(for $g=f_k^{-1}$)
$$
p_k(x_k)
    = p_{k-1}(x_{k-1}) \,
    \biggl\lvert
        \det \frac{\partial x_k}{\partial x_{k-1}}
    \biggr\rvert^{-1}
    \bigg\vert_{x_{k-1} = f_k^{-1}(x_k)}
    \,. $$
with $
    \frac{\partial x_k}{\partial x_{k-1}}
        \equiv \frac{\partial x_k^\top}{\partial x_{k-1}}
$.

<br>

<br>