# `rlplay`-ing with world models

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

import gym

# hotfix for gym's unresponsive viz (spawns gl threads!)
import rlplay.utils.integration.gym

See example.ipynb for the overview of `rlplay`

<br>

A base class for deep gaussian networks, taken from a prior project on Deep Weight Prior.

In [None]:
from torch.nn.functional import softplus
from torch.distributions import Normal, Independent


class BaseDeepIndependentGaussian(torch.nn.Module):
    """Base class for networks that output a factorized Gaussian.

    Subclasses must assign the `features` submodule. It receives the input
    with all leading batch dims flattened into one, and must return a tensor
    whose first event dim has twice the event size: `forward` splits it in
    half into the location and the pre-scale.

    Parameters
    ----------
    input_shape : sequence of int
        The trailing dims of the input, designated as features.
    event_shape : sequence of int
        The shape of a single draw from the returned distribution.
    """

    def __init__(self, input_shape, event_shape):
        super().__init__()
        self.input_shape = torch.Size(input_shape)
        self.event_shape = torch.Size(event_shape)

        # zero and one for the std Gaussian prior
        # XXX keep as a buffer for device sync
        self.register_buffer('nilone', torch.tensor([0., 1.]))

    @property
    def prior(self):
        """The factorized standard Gaussian prior over `event_shape`."""
        shape = self.event_shape

        # kl-std normal: factorized std gaussian prior
        return Independent(Normal(*self.nilone).expand(shape), len(shape))

    def forward(self, input):
        """Return the Gaussian with the batch dims of `input` and `event_shape` events."""
        n_dim_input, n_dim_event = len(self.input_shape), len(self.event_shape)
        assert input.shape[-n_dim_input:] == self.input_shape

        # flatten the composite batch dim, keeping feature dims intact
        # XXX `reshape` (unlike `flatten(0, -n_dim_input-1)`) also handles
        #  inputs with no batch dims at all, by creating a unit batch dim
        # XXX the final layer must have twice the number of channels/features for chunking
        batch_shape = input.shape[:-n_dim_input]
        output = self.features(input.reshape(-1, *self.input_shape))

        # get location and scale with original batch dims (doubled features)
        output = output.reshape(*batch_shape, *output.shape[1:])
        loc, prescale = torch.chunk(output, 2, dim=-n_dim_event)
        # assert self.event_shape == loc.shape[-n_dim_event:]

        # softplus maps the unconstrained pre-scale to a positive scale
        return Independent(Normal(loc, softplus(prescale)), n_dim_event)

A possible redesign of the independent Gaussian class: wrap an arbitrary network instead of inheriting from a base class.

In [None]:
# ENH wrap, not inherit!

# cast a dual-output network as a factorized gaussian
class AsIndependentGaussian(torch.nn.Module):
    """Cast a dual-output network as a factorized Gaussian distribution.

    The wrapped `module` is applied to the input with its leading batch dims
    flattened. Its output is split in half along the first event dim into
    the location and the pre-scale, the latter being passed through `scale_fn`.
    """

    def __init__(
        self,
        module,             # dims In: (B*)CS* -->> dims Out: \1FS*  # re syntax
        n_dim_in=1,         # determines the number of trailing dims designated as input features
        n_dim_out=1,        # the number of trailing dims allotted to a single random draw (event_size)
        prior=None,         # The prior associated with this Gaussian, standard if None
        batch_first=None,   # bool, the order of batch and sequence dims for recurrent nets
                            # XXX might be very awkward to implement...
        scale_fn=softplus,  # The transformation to apply to the output designated as scale
    ):
        assert n_dim_out >= 1
        assert batch_first is None

        super().__init__()
        self.n_dim_in, self.n_dim_out = n_dim_in, n_dim_out
        self.scale_fn, self.batch_first = scale_fn, batch_first

        # zero and one for the default prior, kept as a buffer for sync
        self.register_buffer('nilone', torch.tensor([0., 1.]))

        # `prior` is a property, hence a custom prior is kept in a private
        #  slot: assigning to the read-only property used to raise.
        # XXX a custom prior is a plain attribute: it does not follow
        #  `.to(device)` the way the `nilone` buffer does.
        self._prior = prior

        self.module = module

    @property
    def prior(self):
        """The custom prior if one was given, else the standard Gaussian."""
        if self._prior is not None:
            return self._prior

        # the default standard gaussian prior with broadcastable unit event dims
        nd = self.n_dim_out
        nilone = self.nilone.reshape(2, *(1,) * nd)
        return Independent(Normal(*nilone), nd)

    @prior.setter
    def prior(self, value):
        # setting to None reverts to the default standard Gaussian prior
        self._prior = value

    def forward(self, input):
        """Return the Gaussian for a tensor input, or for a ready (loc, scale) pair."""
        # If the input is not a tensor (or `n_dim_in` is None), then it is
        # unpacked as a ready loc-scale pair. Otherwise it is passed through
        # the network and the output is split in half along the correct
        # trailing dim.

        assert self.batch_first is None
        if isinstance(input, torch.Tensor) and self.n_dim_in is not None:
            # we need this to assign proper batch dims to the returned
            # distribution object, and to make sure not to confuse events dims.
            assert self.n_dim_in >= 1

            # flatten the batch dims, keeping feature dims intact, then undo it on the output
            out = self.module(input.reshape(-1, *input.shape[-self.n_dim_in:]))
            out = out.reshape(*input.shape[:-self.n_dim_in],
                              *out.shape[-self.n_dim_out:])

            # the output's correct dim must have even size!
            loc, scale_ = torch.chunk(out, 2, dim=-self.n_dim_out)

            # apply the +ve valued monotonic transformation to the prescale
            input = loc, self.scale_fn(scale_)

        # cannot auto-infer the event dim, rely on the `n_dim_out` parameter
        # assume the shapes are proper
        return Independent(Normal(*input), self.n_dim_out)

In [None]:
pi = AsIndependentGaussian(torch.nn.Linear(32, 8*2))

enc = AsIndependentGaussian(torch.nn.Linear(32, 8*2))

In [None]:
x = torch.randn(1, 2, 3, 32)

In [None]:
q = enc(x)

In [None]:
# `prior` is a property that returns a distribution object: it must not be called
pi = enc.prior

In [None]:
# KL(q || pi) for each batch element. Use the fully qualified name, since
# the `dist` alias is imported only in a later cell of this notebook.
torch.distributions.kl_divergence(q, pi)

A generic $
    (q(z \mid x), \pi(z), p(y \mid z))
$ loss for var-Bayes with an explicit prior and SGVB or IWAE

In [None]:
from math import log
import torch.distributions as dist

def vbayes(enc, dec, /, X, Y=None, *, prior=None, beta=1., n_draws=1, iwae=False):
    r"""Compute the SGVB or the IWAE objective.

    See the supplementary material of

        [Bachman and Precup (2015)](http://proceedings.mlr.press/v37/bachman15.html)

    for some brief but clear discussion of what turns out to be
    the idea below (variational transcoder if X \neq Y).

    Parameters
    ----------
    enc : callable
        The approximate posterior q(z \mid x): maps `X` to a distribution
        with `.rsample` and `.log_prob`. Its `.prior` attribute is read
        only if `prior` is None.
    dec : callable
        The approximate model p(y \mid z): maps latent draws to
        a distribution with `.log_prob`.
    X : tensor
        The input to the encoder.
    Y : tensor, optional
        The reconstruction target, `X` itself if None (auto-encoding).
    prior : distribution, optional
        The prior \pi(z), `enc.prior` if None.
    beta : float
        The weight of the KL term (or of the log prior-to-posterior ratio
        in the importance weights if IWAE is used).
    n_draws : int
        The number of latent draws per input.
    iwae : bool
        Use the importance weighted bound. It takes effect only if
        `n_draws > 1`, otherwise the SGVB objective is used.

    Returns
    -------
    loss : tensor
        The scalar loss to minimize.
    q : distribution
        The encoder's output `enc(X)`.
    ll : float
        The average reconstruction log-likelihood, detached as a python float.
    """
    pi = enc.prior if prior is None else prior
    Y = X if Y is None else Y  # auto-encode if Y is not given

    # `X=Y` is `*batch x *dec.event_shape`
    q = enc(X)  # q.batch_shape is `batch`

    # `Z` is `n_draws x *q.batch_shape x *q.event_shape`
    Z = q.rsample([n_draws])  # XXX diff-able sampling with (implicit) rep-trick!

    # `log_p` must broadcast against `q.log_prob(Z)`, which is
    #  `n_draws x *q.batch_shape`, for the importance weights below
    log_p = dec(Z).log_prob(Y)  # XXX may consume a lot of mem!

    ll = log_p.mean()  # dim=(0, 1)
    if iwae and n_draws > 1:
        # (iwae)_k = E_{x y} E_{S~q^k(z|x)} log E_{z~S} p(y|z) pi(z) / q(z|x)
        #  * like (sgvb)_k but E_z and log are interchanged
        #  * beta-anneal the pi(z) / q(z|x) ratio
        log_iw = log_p + (pi.log_prob(Z) - q.log_prob(Z)) * beta

        # -log (1/k) \sum_j w_j = log k - logsumexp_j log w_j
        loss = log(n_draws) - torch.logsumexp(log_iw, dim=0).mean()  # dim=0

    else:
        # (sgvb)_k = E_{x y} E_{S~q^k(z|x)} E_{z~S} log p(y|z) pi(z) / q(z|x)
        kl_q_pi = dist.kl_divergence(q, pi)
        loss = beta * kl_q_pi.mean() - ll  # dim=0

    return loss, q, float(ll)

<br>

In [None]:
# the standard Gaussian prior: zero location and unit scale, each with
# a singleton event dim, which broadcasts against any event size
nilone = torch.tensor([[0.], [1.]])
pi = Independent(Normal(loc=nilone[0], scale=nilone[1]), 1)

In [None]:
dist.kl_divergence(q, pi)

A simple function to collate a list of dicts into a dict of lists.

In [None]:
def collate(records):
    """Collate identically keyed dicts into a dict of lists.

    Parameters
    ----------
    records : iterable of dict
        The records to collate.

    Returns
    -------
    out : dict
        Maps each key to the list of its values in the order of `records`.
        Empty if `records` is empty. A key missing from some records just
        gets a shorter list.
    """
    out = {}
    for record in records:
        for k, v in record.items():
            out.setdefault(k, []).append(v)

    return out

## Sophisticated CartPole with PG

### The environment

The environment factory

In [None]:
from rlplay.zoo.env import NarrowPath


class FP32Observation(gym.ObservationWrapper):
    """Observation wrapper that casts the observations to float32."""
    def observation(self, observation):
        # relies on `.astype` of the observation (e.g. a numpy array)
        as_fp32 = observation.astype(numpy.float32)
        return as_fp32
#         obs = observation.astype(numpy.float32)
#         obs[0] = 0.  # mask the position info
#         return obs

#     def step(self, action):
#         obs, reward, done, info = super().step(action)
#         reward -= abs(obs[1]) / 10  # punish for non-zero speed
#         return obs, reward, done, info

class OneHotObservation(gym.ObservationWrapper):
    """Observation wrapper that one-hot encodes an integer observation."""
    def observation(self, observation):
        # the width of the code is the size `.n` of the wrapped env's space
        n_states = self.env.observation_space.n

        # a single-row matrix with a unit on the `observation`-th diagonal
        onehot = numpy.eye(1, n_states, k=observation, dtype=numpy.float32)
        return onehot[0]

def base_factory():
    """Create a fresh instance of the environment for the experiments."""
    env = gym.make("LunarLander-v2")
    return env
#     return FP32Observation(gym.make("CartPole-v0").unwrapped)
    # return OneHotObservation(NarrowPath())

<br>

### the Actor

A procedure and a layer, which converts the input integer data into its
little endian binary representation as float $\{0, 1\}^m$ vectors.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers to fixed-width binary vectors, lowest bit first.

    A new trailing dim of size `n_bits` is appended to `input`. Its j-th
    element is one if the j-th bit of the value is set, and zero otherwise.
    The result has the given `dtype`.
    """
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # n_bits = {torch.int64: 63, torch.int32: 31, torch.int16: 15, torch.int8 : 7}

    # single-bit masks 2^0, 2^1, ..., 2^{n_bits-1} on the input's device
    masks = torch.tensor([2 ** bit for bit in range(n_bits)]).to(input.device)

    # a bit is set iff AND-ing with its mask leaves a non-zero value
    is_set = torch.bitwise_and(input.unsqueeze(-1), masks) != 0

    # upcast bool to the requested dtype to get {0, 1}-valued vectors
    return is_set.to(dtype)


class OneHotBits(torch.nn.Module):
    """A layer that applies `onehotbits` with a fixed width and dtype."""
    def __init__(self, n_bits=63, dtype=torch.float):
        # validate the width before the module is initialized
        assert 1 <= n_bits < 64
        super().__init__()
        self.n_bits = n_bits
        self.dtype = dtype

    def forward(self, input):
        return onehotbits(input, dtype=self.dtype, n_bits=self.n_bits)

A special module dictionary, which applies itself to an input dict of tensors and concatenates the outputs.

In [None]:
from typing import Optional, Mapping
from torch.nn import Module, ModuleDict as BaseModuleDict


class ModuleDict(BaseModuleDict):
    """A ModuleDict that applies itself to input dicts.

    Each module is called on the input item under the module's own key,
    and the outputs are concatenated along `dim`.
    """
    def __init__(
        self,
        modules: Optional[Mapping[str, Module]] = None,
        dim: Optional[int]=-1
    ) -> None:
        super().__init__(modules)
        self.dim = dim

    def forward(self, input):
        # the outputs are gathered in the iteration order of the modules
        outputs = []
        for key, module in self.items():
            outputs.append(module(input[key]))

        return torch.cat(outputs, dim=self.dim)

A more sophisticated policy learner

In [None]:
from rlplay.engine import BaseActorModule
from rlplay.utils.common import multinomial

from torch.nn import Sequential, Linear, ReLU, LogSoftmax

class CartPoleActor(BaseActorModule):
    """A feed-forward actor-critic with separate policy and baseline heads.

    Both heads take 8-dim inputs. The policy outputs log-probabilities over
    4 actions and the baseline outputs a scalar value.
    """
    def __init__(self, lstm='none'):
        assert lstm in ('none', 'loop', 'cudnn')
        super().__init__()

        # no recurrent core is built here, whatever the `lstm` setting
        self.use_lstm = False
        self.use_cudnn = False

        # blend the policy with a uniform distribution, determined by
        #  the exploration epsilon. We update it in the actor clones via a buffer
        # self.register_buffer('epsilon', torch.tensor(epsilon))
        # XXX isn't the stochastic policy random enough by itself?

        z_dim, a_dim = 8, 4

        # the state-value head
        self.baseline = Sequential(Linear(z_dim, 128), ReLU(), Linear(128, 1))

        # the policy head, outputs log-probabilities
        self.policy = Sequential(
            Linear(z_dim, 128), ReLU(), Linear(128, a_dim), LogSoftmax(dim=-1))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Return the actions, an empty recurrent state, and the value and logits."""
        logits = self.policy(obs)

        # value must not have any trailing dims, i.e. T x B
        value = self.baseline(obs).squeeze(-1)

        if self.training:
            # draw from the policy when training
            actions = multinomial(logits.detach().exp())

        else:
            # act greedily in evaluation mode
            actions = logits.argmax(dim=-1)

        return actions, (), {'value': value, 'logits': logits}

<br>

Modules for the WM

In [None]:
class DeepThreeLayerGaussian(BaseDeepIndependentGaussian):
    """A factorized Gaussian with a three-layer dense ELU network."""
    def __init__(self, dim_in=4, dim_out=2, h_dim=32):
        super().__init__([dim_in], [dim_out])

        layers = [
            torch.nn.Linear(dim_in, h_dim),
            torch.nn.ELU(),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.ELU(),
            # twice the event size: one half for loc, the other for scale
            torch.nn.Linear(h_dim, 2 * dim_out),
        ]
        self.features = torch.nn.Sequential(*layers)

class Encoder:
    """A factory: `Encoder(...)` returns a `DeepThreeLayerGaussian` from x to z."""
    def __new__(cls, x_dim=4, z_dim=2, h_dim=32):
        return DeepThreeLayerGaussian(dim_in=x_dim, dim_out=z_dim, h_dim=h_dim)

class Decoder:
    """A factory: `Decoder(...)` returns a `DeepThreeLayerGaussian` from z to x."""
    def __new__(cls, z_dim=2, x_dim=4, h_dim=32):
        return DeepThreeLayerGaussian(dim_in=z_dim, dim_out=x_dim, h_dim=h_dim)

class Dynamics(DeepThreeLayerGaussian):
    """A Gaussian over the latent, given a latent and an embedded action."""
    def __init__(self, z_dim=2, a_dim=2, u_dim=4, h_dim=32):
        # the network consumes the latent joined with the action embedding
        super().__init__(z_dim + u_dim, z_dim, h_dim=h_dim)
        self.act = torch.nn.Embedding(a_dim, u_dim)  # hardcoded

    def forward(self, zed, act, fin=None, hx=None):
        """Return the distribution and an empty recurrent state update."""
        # `fin` and `hx` are accepted, but not used
        embedded = self.act(act)
        joint = torch.cat([zed, embedded], dim=-1)
        transition = super().forward(joint)
        return transition, ()

In [None]:
from math import log

class WMCartPoleActor(BaseActorModule):
    """An actor-critic whose heads act on draws from a variational encoder.

    The policy is the average, over `n_draws` latent draws, of the
    per-draw policies, and the value is the average of the per-draw values.
    """
    def __init__(self, lstm='none', n_draws=1):
        assert lstm in ('none', 'loop', 'cudnn')
        super().__init__()

        # no recurrent core is built here, whatever the `lstm` setting
        self.use_lstm = False
        self.use_cudnn = False

        x_dim, z_dim, a_dim = 8, 4, 4
        self.x_dim, self.z_dim, self.a_dim = x_dim, z_dim, a_dim
        self.n_draws = n_draws

        # the world model: the encoder, the decoder and (disabled) dynamics
        self.enc = Encoder(x_dim=x_dim, z_dim=z_dim)
        self.dec = Decoder(z_dim=z_dim, x_dim=x_dim)
#         self.dyn = Dynamics(z_dim=z_dim, a_dim=a_dim, u_dim=2)
        self.dyn = None

        # the state-value head on top of the latent
        self.baseline = Sequential(Linear(z_dim, 20), ReLU(), Linear(20, 1))

        # the policy head on top of the latent, outputs log-probabilities
        self.policy = Sequential(
            Linear(z_dim, 20), ReLU(), Linear(20, a_dim), LogSoftmax(dim=-1))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Return the actions, an empty recurrent state, and the value and logits."""
        # diff-able pass through the encoder to nudge it towards
        #  task-meaningful abstractions! non-diffable pass yield poor
        #  evaluation performance
        posterior = self.enc(obs)
        Z = posterior.rsample([self.n_draws])

        # NO single-step foresight
#         for zx in Z:  # for each draw run a shallow mcts
#             path = [Node(0., zx, )]
#             for a in range(self.a_dim):
#                 pass
#         # n_draws x T x B x ...
#         for a_hyp in range(self.a_dim):
#             Z_hyp = self.dyn(Z, torch.full(Z.shape[:3], a_hyp))
#             value = self.baseline(Z_hyp).squeeze(-1)

        # log of the average over the draws of the per-draw policies
        log_mix = self.policy(Z).logsumexp(dim=0)
        logits = log_mix - log(self.n_draws)

        # value must not have any trailing dims, i.e. T x B
        value = self.baseline(Z).mean(dim=0).squeeze(-1)

        # the actions are always sampled, regardless of `self.training`
        actions = multinomial(logits.detach().exp())

        return actions, (), {'value': value, 'logits': logits}

In [None]:
from collections import namedtuple

# A tree node record, referenced only by the commented-out lookahead
# sketch in `WMCartPoleActor.forward`.
# NOTE(review): the field name 'pior' looks like a typo for 'prior' -- confirm
#  before anything starts to rely on it.
Node = namedtuple('Node', 'pior, zx, children')

<br>

### A2C algo

Service functions for the algorithms

In [None]:
from rlplay.engine.utils.plyr import apply, suply, xgetitem


def timeshift(state, *, shift=1):
    """Get the current and the shifted slices of a nested object.

    Returns a pair: `state` without its last `shift` items along the leading
    dim, and `state` without its first `shift` items.
    """
    head, tail = slice(None, -shift), slice(shift, None)

    # use xgetitem to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-H
    curr = suply(xgetitem, state, index=head)

    # XXX `ahead[t]` = (x_{t+H}, a_{t+H-1}, r_{t+H}, d_{t+H}), t=0..T-H
    ahead = suply(xgetitem, state, index=tail)

    return curr, ahead

The Advantage Actor-Critic algo

In [None]:
import torch.nn.functional as F
from rlplay.algo.returns import pyt_vtrace

# @torch.enable_grad()
def vtrace(fragment, module, *, gamma=0.99, C_entropy=1e-2, C_value=0.5):
    """Compute the V-trace actor-critic objective on a trajectory fragment.

    Parameters
    ----------
    fragment : object
        The rollout fragment. It is read through `.state` (with `obs`,
        `act`, `rew`, `fin`, `stepno`), `.actor` (with 'logits' and 'value'
        recorded during the rollout) and `.hx`.
    module : callable
        The current actor. Its third output must be a dict with 'logits'
        and 'value'.
    gamma : float
        The discount factor.
    C_entropy : float
        The weight of the negative entropy term.
    C_value : float
        The weight of the critic's loss.

    Returns
    -------
    objective : tensor
        The scalar loss to minimize.
    info : dict
        The 'entropy', 'policy_score' and 'value_loss' as python floats.
    """
    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        fragment.state.obs, fragment.state.act,
        fragment.state.rew, fragment.state.fin,
        hx=fragment.hx, stepno=fragment.state.stepno)

    # Assume `.act` is unstructured: `act[t]` = a_{t+1} -->> T x B x 1
    state, state_next = timeshift(fragment.state)
    act = state_next.act.unsqueeze(-1)  # actions taken during the rollout

    # \pi is the target policy, \mu is the behaviour policy (T+1 x B x ...)
    log_pi, log_mu = info['logits'], fragment.actor['logits']

    # the importance weights
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)
    # in-place on the freshly gathered `log_mu_a`: -(log mu - log pi) = log pi - log mu
    log_rho = log_mu_a.sub_(log_pi_a.detach()).neg_()

    # `.actor[t]` is actor's extra info in reaction to `.state[t]`, t=0..T
    val = fragment.actor['value']  # info['value'].detach()
    # XXX Although Espeholt et al. (2018, sec.~4.2) use the value estimate of
    # the rollout policy for the V-trace target in eq. (1), it makes more sense
    # to use the estimates of the current policy, as has been done in monobeast.
    #  https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
    # NOTE(review): the code above takes `val` from the rollout's record
    #  (`fragment.actor`), not from the current `module` -- confirm which
    #  of the two is intended.
    val, bootstrap = val[:-1], val[-1]
    target = pyt_vtrace(state_next.rew, state_next.fin, val,
                        gamma=gamma, bootstrap=bootstrap,
                        omega=log_rho, r_bar=1., c_bar=1.)

    # the critic's mse score against v-trace targets (min)
    critic_mse = F.mse_loss(info['value'][:-1], target, reduction='mean') / 2

    # \delta_t = r_{t+1} + \gamma \nu(s_{t+1}) 1_{\neg d_{t+1}} - v(s_t)
    # `adv` starts filled with the bootstrap value, which remains only in
    #  its last step, and is then transformed in-place into \delta_t
    adv = torch.empty_like(state_next.rew).copy_(bootstrap)
    adv[:-1].copy_(target[1:])  # copy the v-trace targets \nu(s_{t+1})
    adv.masked_fill_(state_next.fin, 0.).mul_(gamma)
    adv.add_(state_next.rew).sub_(val)
    # XXX note `val` here, not `target`! see sec.~4.2 in (Espeholt et al.; 2018)

    # the policy surrogate score (max)
    # \rho_t = \min\{ \bar{\rho}, \frac{\pi_t(a_t)}{\mu_t(a_t)} \}
    # XXX `exp_` overwrites `log_rho` in-place, so this must stay after
    #  the `pyt_vtrace` call above, which takes `log_rho` as `omega`
    rho = log_rho.exp_().clamp_(max=1.)
    vtrace_score = log_pi_a.mul(adv.mul_(rho)).mean()

    # the policy's neg-entropy score (min)
    # clamp the log-probabilities at the lowest finite value, so that
    #  p log p is zero, and not nan, when log p is -inf
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # maximize the entropy and the reinforce score, minimize the critic loss
    objective = C_entropy * negentropy + C_value * critic_mse - vtrace_score
    return objective.mean(), dict(
        entropy=-float(negentropy),
        policy_score=float(vtrace_score),
        value_loss=float(critic_mse),
    )

The world model loss: the VAE loss and the dynamics prediction loss of
a World Model of [Ha and Schmidhuber (2018)](https://proceedings.neurips.cc/paper/2018/hash/2de5d16682c3c35007e4e92982f1a2ba-Abstract.html)
* formal losses may not be as in the paper!

In [None]:
from functools import partial

def wm_loss(enc, dec, dyn, /, fragment, *, beta=1., n_draws=10, iwae=False):
    """Compute the world model loss: the VAE loss plus the dynamics loss.

    Returns the total loss and a dict with the average log-likelihood of
    the VAE under 'vae', and of the dynamics under 'dyn' if `dyn` is given.
    """
    # ASSUME obs and act are UNSTRUCTURED
    _, state_next = timeshift(fragment.state)

    # get VAE loss and the encoding distribution q
    loss_vae, enc_q, ll_vae = vbayes(
        enc, dec, fragment.state.obs,
        beta=beta, n_draws=n_draws, iwae=iwae)
    info = {'vae': ll_vae}

    # without the dynamics model only the VAE loss is left
    if dyn is None:
        return loss_vae + 0., info

    # prepare the dynamics model (takes in action, mask, and recurrent state)
    # XXX dyn(...) returns a distribution and the hx update
    def dyn_(X):
        return dyn(X, state_next.act, state_next.fin, hx=fragment.hx)[0]

    # get r(z_{t+1} \mid z_t, a_t)
    Z = enc_q.sample()  # XXX non diff-able sampling!
    loss_dyn, _, info['dyn'] = vbayes(
        dyn_, dec, X=Z[:-1], Y=state_next.obs,
        prior=dyn.prior, beta=1., n_draws=n_draws, iwae=False)

    return loss_vae + loss_dyn, info

<br>

### Run!

Initialize the learner and the environment factories

In [None]:
from functools import partial


# the environment factories for the evaluation and for the rollout
# XXX `partial` binds no arguments here: both are equivalent to `base_factory`
factory_eval = partial(base_factory)
factory = partial(base_factory)

# learner = CartPoleActor(lstm='none')
learner = WMCartPoleActor(n_draws=1)

# put the learner in the training mode and on the device
learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

# a single optimizer for all parameters of the learner
optim = torch.optim.Adam(learner.parameters(), lr=1e-3)
sched = None  # torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', min_lr=1e-7)

Initialize the sampler

In [None]:
T, B = 25, 4

sticky = learner.use_cudnn

In [None]:
from rlplay.engine.rollout import multi

# the generator of trajectory fragments, consumed by the training loop
# NOTE(review): the meaning of the settings below is defined by
#  `rlplay.engine.rollout.multi.rollout` -- see its documentation
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=6,
    n_per_actor=B,
    n_buffers=24,
    n_per_batch=1,
    sticky=sticky,
    pinned=False,
    clone=True,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

A generator of evaluation rewards

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

# the generator of evaluation results: the training loop takes one item
# from it with `next` after each epoch
test_it = evaluate(factory_eval, learner, n_envs=4, n_steps=200,
                   clone=False, device=device_, start_method='fork')

Implement your favourite training method

In [None]:
gamma = 0.99
C_entropy, C_wm = 0.1, 0.1

In [None]:
import tqdm
from math import exp
from torch.nn.utils import clip_grad_norm_

torch.set_num_threads(1)

# `losses` gets one record per update, `rewards` one item per epoch
losses, rewards = [], []
for epoch in tqdm.tqdm(range(200)):
    # take at most 100 batches from the rollout generator per epoch
    for j, batch in zip(range(100), batchit):
        # the policy-gradient loss and the auxiliary world model loss
        loss, info = vtrace(batch, learner, gamma=gamma, C_value=1., C_entropy=C_entropy)
        loss_wm, info_ = wm_loss(
            learner.enc, learner.dec, learner.dyn, batch,
            beta=1., iwae=False, n_draws=1
        )
        info.update(info_)

        # one step on the joint loss with the grad norm clipped
        optim.zero_grad()
        (loss + loss_wm * C_wm).backward()
        grad = clip_grad_norm_(learner.parameters(), max_norm=1.0)
        optim.step()

        # XXX `loss` in the record is the v-trace loss only, without `loss_wm`
        losses.append(dict(
            **info, loss=float(loss), grad=float(grad),
            C_entropy=C_entropy,
            perplexity=exp(info['entropy']),
        ))
        
#         if info['entropy'] * 1.5 < ent_target:
#             C_entropy *= 2
        
#         elif info['entropy'] > ent_target * 1.5:
#             C_entropy /= 2

    # fetch the evaluation results lagged by one inner loop!
    rewards.append(next(test_it))
    if sched is not None:
        sched.step(rewards[-1].mean())

    # NOTE(review): this unconditional `break` stops the training after
    #  the very first epoch of the 200 -- confirm it is not a debugging leftover
    break

In [None]:
C_entropy

In [None]:
# close the generators
batchit.close()
test_it.close()

<br>

In [None]:
data = {k: numpy.array(v) for k, v in collate(losses).items()}

In [None]:
if 'value_loss' in data:
    plt.semilogy(data['value_loss'])

In [None]:
if 'entropy' in data:
    plt.plot(data['entropy'])

In [None]:
if 'policy_score' in data:
    plt.plot(data['policy_score'])

In [None]:
if 'grad' in data:
    plt.semilogy(data['grad'])

In [None]:
if 'vae' in data:
    plt.plot(data['vae'])

In [None]:
if 'dyn' in data:
    plt.plot(data['dyn'])

In [None]:
rewards = numpy.stack(rewards, axis=0)

In [None]:
rewards

In [None]:
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

# label each statistic, otherwise the four curves cannot be told apart
ax.plot(numpy.mean(rewards, axis=-1), label='mean')
ax.plot(numpy.median(rewards, axis=-1), label='median')
ax.plot(numpy.min(rewards, axis=-1), label='min')
ax.plot(numpy.std(rewards, axis=-1), label='std')
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)

ax.set_xlabel('epoch')
ax.set_ylabel('evaluation reward')
ax.legend(fontsize='xx-small')

plt.show()

<br>

The ultimate evaluation run

In [None]:
from rlplay.engine import core

# run one rendered evaluation of the learner in a fresh environment
with factory_eval() as env:
    # NOTE(review): `WMCartPoleActor.forward` samples the actions regardless
    #  of the train/eval mode, so the actor is not greedy here
    learner.eval()
    # NOTE(review): `env.env` strips the outermost wrapper of the env,
    #  and `n_steps` is passed as a float -- confirm both are intended
    eval_rewards, info = core.evaluate([
        env.env
    ], learner, render=True, n_steps=1e4, device=device_)

# the total reward collected during the evaluation
print(sum(eval_rewards))

<br>

Let's analyze the performance

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# one-step TD target and error from the values recorded during the evaluation
# XXX no termination mask is applied: the next value is always bootstrapped
td_target = eval_rewards + gamma * info['value'][1:]
td_error = td_target - info['value'][:-1]
# td_error = npy_deltas(
#     eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool), info['value'][:-1],
#     gamma=gamma, bootstrap=info['value'][-1])

# the magnitude of the error relative to the magnitude of the target
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.semilogy(abs(td_error) / abs(td_target))
ax.set_title('relative td(1)-error');

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# plt.plot(
#     npy_returns(eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool),
#                 gamma=gamma, bootstrap=info['value'][-1]))
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(info['value']);

In [None]:
import math
from scipy.special import softmax, expit, entr

*head, n_actions = info['logits'].shape
proba = softmax(info['logits'], axis=-1)

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(entr(proba).sum(-1)[:, 0])
ax.axhline(math.log(n_actions), c='k', alpha=0.5, lw=1);

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.hist(info['logits'][..., 1] - info['logits'][..., 0], bins=51);  # log-ratio

<br>

In [None]:
q = learner.enc(batch.state.obs)

In [None]:
log_pi_raw = learner.policy(q.rsample([100]))

In [None]:
log_pi = log_pi_raw.logsumexp(dim=0) - log(log_pi_raw.shape[0])

In [None]:
log_pi

In [None]:
q.entropy()

In [None]:
assert False

$$
y = \bigvee_k
  \neg \bigl(
    \bigvee_i (
      \neg x_i \wedge w_{i k}
    )
  \bigr) \wedge 1_k
  = \bigvee_k \bigwedge_i (
      x_i \vee \neg w_{i k}
    )
  \,.$$

<br>