# `rlplay`-ing around with Policy Gradients

In [None]:
import torch
import numpy

import matplotlib.pyplot as plt
%matplotlib inline

import gym

# hotfix for gym's unresponsive viz (spawns gl threads!)
import rlplay.utils.integration.gym

See example.ipynb for the overview of `rlplay`

<br>

## Sophisticated CartPole with PG

### The environment

The environment factory

In [None]:
from rlplay.zoo.env import NarrowPath


class FP32Observation(gym.ObservationWrapper):
    """Observation wrapper that casts to float32 and zeroes the first entry."""

    def observation(self, observation):
        """Return a float32 copy of `observation` with element 0 set to zero."""
        masked = observation.astype(numpy.float32)

        # mask the position info
        masked[0] = 0.
        return masked

#     def step(self, action):
#         obs, reward, done, info = super().step(action)
#         reward -= abs(obs[1]) / 10  # punish for non-zero speed
#         return obs, reward, done, info

class OneHotObservation(gym.ObservationWrapper):
    """Observation wrapper that turns an integer observation into a one-hot row."""

    def observation(self, observation):
        """Return a float32 vector of length `observation_space.n` with a one
        at index `observation`."""
        n_states = self.env.observation_space.n

        # a 1 x n matrix with ones on the k-th diagonal has its single one
        #  at column `k`
        row = numpy.eye(1, n_states, k=observation, dtype=numpy.float32)
        return row[0]

def base_factory(seed=None):
    """Create the environment used throughout this notebook.

    `seed` is accepted, but is currently not used by the body.
    """
    # alternatives that were tried:
    # return gym.make("LunarLander-v2")
    # return OneHotObservation(NarrowPath())
    raw_env = gym.make("CartPole-v0").unwrapped
    return FP32Observation(raw_env)

<br>

### the Actor

A procedure and a layer that convert integer input data into its
little-endian binary representation as float vectors in $\{0, 1\}^m$.

In [None]:
def onehotbits(input, n_bits=63, dtype=torch.float):
    """Encode integers to fixed-width binary floating point vectors

    A new trailing dimension of size `n_bits` is appended to `input`; its
    j-th entry is one if bit `j` (least significant first) of the integer
    is set, and zero otherwise.
    """
    assert not input.dtype.is_floating_point
    assert 0 < n_bits < 64  # torch.int64 is signed, so 64-1 bits max

    # n_bits = {torch.int64: 63, torch.int32: 31, torch.int16: 15, torch.int8 : 7}

    # single-bit masks 2^0, 2^1, ..., 2^{n_bits-1} on the input's device
    masks = torch.ones(n_bits, dtype=torch.int64) << torch.arange(n_bits)
    masks = masks.to(input.device)

    # a bit is set iff AND-ing with its mask leaves something non-zero
    is_set = (input.unsqueeze(-1) & masks).ne(0)

    # upcast bool to the requested dtype to get the {0, 1} encoding
    return is_set.to(dtype)


class OneHotBits(torch.nn.Module):
    """Layer form of `onehotbits` with a fixed bit width and output dtype."""

    def __init__(self, n_bits=63, dtype=torch.float):
        assert 1 <= n_bits < 64
        super().__init__()

        self.n_bits = n_bits
        self.dtype = dtype

    def forward(self, input):
        """Encode the integer tensor `input` into its binary representation."""
        return onehotbits(input, dtype=self.dtype, n_bits=self.n_bits)

A special module dictionary, which applies itself to the input dict of tensors

In [None]:
from typing import Optional, Mapping
from torch.nn import Module, ModuleDict as BaseModuleDict


class ModuleDict(BaseModuleDict):
    """A module dictionary that can be called on a dict of inputs.

    Each contained module is applied to the input stored under the same key,
    and the results are concatenated along `dim`.
    """

    def __init__(
        self,
        modules: Optional[Mapping[str, Module]] = None,
        dim: Optional[int] = -1,
    ) -> None:
        super().__init__(modules)
        self.dim = dim

    def forward(self, input):
        """Apply every sub-module to `input[key]` and concatenate the outputs."""
        # iterate over own items, so that the concatenation follows the order
        #  in which the modules were declared in `__init__`
        pieces = []
        for key, module in self.items():
            pieces.append(module(input[key]))

        return torch.cat(pieces, dim=self.dim)

An $\ell_2$ normalization layer.

In [None]:
from torch.nn.functional import normalize

class Normalize(torch.nn.Module):
    """Layer that rescales its input to unit $\\ell_2$ norm along `dim`."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        """Return `input` normalized along the configured dimension."""
        unit = normalize(input, dim=self.dim)
        return unit

A more sophisticated policy learner

In [None]:
from rlplay.engine import BaseActorModule
from rlplay.utils.common import multinomial

from torch.nn import Sequential, Linear, ReLU, LogSoftmax

class CartPoleActor(BaseActorModule):
    """A feed-forward actor with separate policy and baseline networks.

    Both networks map the 4-dimensional observation through one hidden layer
    of 20 units; the policy outputs log-probabilities over two actions, and
    the baseline outputs a scalar value estimate.
    """

    def __init__(self, lstm='none'):
        assert lstm in ('none', 'loop', 'cudnn')
        super().__init__()

        # this actor is never recurrent, whatever `lstm` says
        self.use_lstm = False
        self.use_cudnn = False

        # blend the policy with a uniform distribution, determined by
        #  the exploration epsilon. We update it in the actor clones via a buffer
        # self.register_buffer('epsilon', torch.tensor(epsilon))
        # XXX isn't the stochastic policy random enough by itself?

        # the state-value estimator
        self.baseline = Sequential(
            Linear(4, 20),
            ReLU(),
            Linear(20, 1),
        )

        # the log-probabilities of the two actions
        self.policy = Sequential(
            Linear(4, 20),
            ReLU(),
            Linear(20, 2),
            LogSoftmax(dim=-1),
        )

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Pick actions for `obs`, and report the value and the policy logits.

        Returns the actions, an empty recurrent state `()`, and a dict with
        the `value` and `logits` tensors. Only `obs` is used by this actor.
        """
        logits = self.policy(obs)

        # value must not have any trailing dims, i.e. T x B
        value = self.baseline(obs).squeeze(-1)

        if self.training:
            # sample from the policy's distribution
            actions = multinomial(logits.detach().exp())

        else:
            # act greedily
            actions = logits.argmax(dim=-1)

        return actions, (), {'value': value, 'logits': logits}

<br>

### PPO/GAE A2C and V-trace A2C algos

Service functions for the algorithms

In [None]:
from plyr import apply, suply, xgetitem


def timeshift(state, *, shift=1):
    """Get current and shifted slices of nested objects.

    Every leaf of the nested `state` is sliced along its leading dimension:
    `curr` drops the last `shift` entries, and `next` drops the first `shift`
    entries, so that `next[t]` is aligned with `curr[t]` shifted forward by
    `shift` steps. With `shift=0` both results cover the whole range.

    Returns the pair `(curr, next)`.
    """
    # use `xgetitem` to let None through
    # XXX `curr[t]` = (x_t, a_{t-1}, r_t, d_t), t=0..T-shift
    # `-0 == 0`, so a plain `slice(None, -shift)` would be EMPTY for a zero
    #  shift; `or None` keeps the full range in that case.
    curr = suply(xgetitem, state, index=slice(None, -shift or None))

    # XXX `next[t]` = (x_{t+shift}, a_{t+shift-1}, r_{t+shift}, d_{t+shift}),
    #  t=0..T-shift
    next_ = suply(xgetitem, state, index=slice(shift, None))

    return curr, next_

The Advantage Actor-Critic algo

In [None]:
import torch.nn.functional as F
from rlplay.algo.returns import pyt_gae, pyt_returns, pyt_multistep

# @torch.enable_grad()
def a2c(
    fragment, module, *, gamma=0.99, gae=1., ppo=0.,
    C_entropy=1e-2, C_value=0.5, c_rho=1.0, multistep=0,
):
    r"""The Advantage Actor-Critic algorithm (importance-weighted off-policy).

    Close to REINFORCE, but uses separate baseline value estimate to compute
    advantages in the policy gradient:
    $$
        \nabla_\theta J(s_t)
            = \mathbb{E}_{a \sim \beta(a\mid s_t)}
                \frac{\pi(a\mid s_t)}{\beta(a\mid s_t)}
                    \bigl( r_{t+1} + \gamma G_{t+1} - v(s_t) \bigr)
                \nabla_\theta \log \pi(a\mid s_t)
        \,, $$

    where the critic estimates the state's value under the current policy
    $$
        v(s_t)
            \approx \mathbb{E}_{\pi_{\geq t}}
                G_t(a_t, s_{t+1}, a_{t+1}, ... \mid s_t)
        \,. $$

    Arguments
    ---------
    fragment : the rollout data; the body reads `.state` (with `.obs`, `.act`,
        `.rew`, `.fin`, and `.stepno`), `.hx`, and the `.actor` dict with
        the 'value' and 'logits' recorded during the rollout.
    module : the actor, called as `module(obs, act, rew, fin, hx=, stepno=)`;
        its third output must be a dict with 'value' and 'logits'.
    gamma : the discount, forwarded to the return and advantage estimators.
    gae : if less than one, the advantages come from `pyt_gae` with `C=gae`,
        otherwise they are the returns minus the detached value estimates.
    ppo : if positive, the policy score is the clipped-ratio surrogate with
        the ratio clamped to `[1 - ppo, 1 + ppo]`; otherwise the
        importance-weighted log-likelihood score is used.
    C_entropy, C_value : the weights of the neg-entropy and the critic's loss
        terms in the objective.
    c_rho : the upper clamp on the importance weight (non-PPO branch only).
    multistep : if positive, the returns are computed by `pyt_multistep` with
        `n_lookahead=multistep`, otherwise by `pyt_returns`.

    Returns
    -------
    The objective to be minimized, and a dict of floats with the keys
    'entropy', 'policy_score', and 'value_loss'.
    """
    state, state_next = timeshift(fragment.state)

    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        state.obs, state.act, state.rew, state.fin,
        hx=fragment.hx, stepno=state.stepno)
    # info['value'] = V(`.state[t]`)
    #               <<-->> v(x_t)
    #               \approx \mathbb{E}( G_t \mid x_t)
    #               \approx \mathbb{E}( r_{t+1} + \gamma r_{t+2} + ... \mid x_t)
    #               <<-->> npv(`.state[t+1:]`)
    # info['logits'] = \log \pi(... | .state[t] )
    #                <<-->> \log \pi( \cdot \mid x_t)

    # `.actor[t]` is actor's extra info in reaction to `.state[t]`, t=0..T
    bootstrap = fragment.actor['value'][-1]
    #     `bootstrap` <<-->> `.value[-1]` = V(`.state[-1]`)

    # XXX post-mul by `1 - \gamma` fails to train, but seems appropriate
    # for the continuation/survival interpretation of the discount factor.
    #   <<-- but who says this is a good interpretation?
    # ret.mul_(1 - gamma)

    # \pi is the target policy, \mu is the behaviour policy
    # NOTE(review): `log_mu` comes from the rollout record, which, according
    #  to the comment above, has one more step than `log_pi` -- the `gather`
    #  calls below appear to rely on the index being shorter; confirm.
    log_pi, log_mu = info['logits'], fragment.actor['logits']

    # Future rewards after `.state[t]` are recorded in `.state[t+1:]`
    #  G_t <<-->> ret[t] = rew[t] + gamma * (1 - fin[t]) * (ret[t+1] or bootstrap)
    if multistep > 0:
        ret = pyt_multistep(state_next.rew, state_next.fin,
                            info['value'].detach(),
                            gamma=gamma, n_lookahead=multistep,
                            bootstrap=bootstrap.unsqueeze(0))

    else:
        ret = pyt_returns(state_next.rew, state_next.fin,
                          gamma=gamma, bootstrap=bootstrap)

    # the critic's mse score (min)
    #  \frac1{2 T} \sum_t (G_t - v(s_t))^2
    value = info['value']
    critic_mse = F.mse_loss(value, ret, reduction='mean') / 2
    # v(x_t) \approx \mathbb{E}( G_t \mid x_t )
    #        \approx G_t (one-point estimate)
    #        <<-->> ret[t]

    # compute the advantages $G_t - v(s_t)$
    #  or GAE [Schulman et al. (2016)](http://arxiv.org/abs/1506.02438)
    # XXX sec 6.1 in the GAE paper uses V from the `current` value
    #  network, not the one used during the rollout.
    # value = fragment.actor['value'][:-1]
    if gae < 1.:
        # the positional arguments are $r_{t+1}$, $d_{t+1}$, and $v(s_t)$,
        #  respectively, for $t=0..T-1$. The bootstrap is $v(S_T)$ from
        #  the rollout.
        adv = pyt_gae(state_next.rew, state_next.fin, value.detach(),
                      gamma=gamma, C=gae, bootstrap=bootstrap)

    else:
        adv = ret.sub(value.detach())

    # adv.sub_(adv.mean())
    # adv.div_(adv.std(dim=0))

    # Assume `.act` is unstructured: `act[t]` = a_{t+1} -->> T x B x 1
    act = state_next.act.unsqueeze(-1)  # actions taken during the rollout

    # the importance weights
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)

    # the policy surrogate score (max)
    if ppo > 0:
        # the PPO loss is the properly clipped rho times the advantage
        ratio = log_pi_a.sub(log_mu_a).exp()
        a2c_score = torch.min(
            ratio * adv,
            ratio.clamp(1. - ppo, 1. + ppo) * adv
        ).mean()

    else:
        # \exp{- ( \log \mu - \log \pi )}, evaluated at $a_t \mid z_t$
        # the in-place ops overwrite `log_mu_a`, which is the fresh output
        #  of the `gather` above and is not used afterwards
        rho = log_mu_a.sub_(log_pi_a.detach()).neg_()\
                      .exp_().clamp_(max=c_rho)

        # \frac1T \sum_t \rho_t (G_t - v_t) \log \pi(a_t \mid z_t)
        a2c_score = log_pi_a.mul(adv.mul_(rho)).mean()

    # the policy's neg-entropy score (min)
    #   - H(\pi(•\mid s)) = - (-1) \sum_a \pi(a\mid s) \log \pi(a\mid s)
    # clamp the log-probabilities from below by the smallest finite float,
    #  so that a `-inf` log-prob times zero prob does not yield a NaN
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # breakpoint()

    # maximize the entropy and the reinforce score, minimize the critic loss
    objective = C_entropy * negentropy + C_value * critic_mse - a2c_score
    return objective.mean(), dict(
        entropy=-float(negentropy),
        policy_score=float(a2c_score),
        value_loss=float(critic_mse),
    )

A few remarks:
* a2c is on-policy and no importance weight could change this!
* L72-80: [stable_baselines3](./common/on_policy_algorithm.py#L183-192)
  and [rlpyt](./algos/pg/base.py#L49-58) use rollout data, when computing the GAE

* L61-62: [stable_baselines3](./stable_baselines3/a2c/a2c.py#L147-156) uses `vf_coef=0.5`,
  and **unhalved** `F.mse_loss`, while [rlpyt](./rlpyt/rlpyt/algos/pg/a2c.py#L93-94)
  uses `value_loss_coeff=0.5`, and **halved** $\ell_2$ loss!

The off-policy actor-critic algorithm for the learner, called V-trace,
from [Espeholt et al. (2018)](http://proceedings.mlr.press/v80/espeholt18a.html).

In [None]:
from rlplay.algo.returns import pyt_vtrace

# @torch.enable_grad()
def vtrace(fragment, module, *, gamma=0.99, C_entropy=1e-2, C_value=0.5):
    """The actor-critic objective with V-trace targets.

    Arguments
    ---------
    fragment : the rollout data; the body reads `.state` (with `.obs`, `.act`,
        `.rew`, `.fin`, and `.stepno`), `.hx`, and the `.actor` dict with
        the 'value' and 'logits' recorded during the rollout.
    module : the actor, called as `module(obs, act, rew, fin, hx=, stepno=)`;
        its third output must be a dict with 'value' and 'logits'.
    gamma : the discount, used in `pyt_vtrace` and in the advantages.
    C_entropy, C_value : the weights of the neg-entropy and the critic's loss
        terms in the objective.

    Returns
    -------
    The objective to be minimized, and a dict of floats with the keys
    'entropy', 'policy_score', and 'value_loss'.
    """
    # REACT: (state[t], h_t) \to (\hat{a}_t, h_{t+1}, \hat{A}_t)
    _, _, info = module(
        fragment.state.obs, fragment.state.act,
        fragment.state.rew, fragment.state.fin,
        hx=fragment.hx, stepno=fragment.state.stepno)

    # Assume `.act` is unstructured: `act[t]` = a_{t+1} -->> T x B x 1
    state, state_next = timeshift(fragment.state)
    act = state_next.act.unsqueeze(-1)  # actions taken during the rollout

    # \pi is the target policy, \mu is the behaviour policy (T+1 x B x ...)
    log_pi, log_mu = info['logits'], fragment.actor['logits']

    # the importance weights
    # NOTE(review): `act` has one step fewer than the logits -- the `gather`
    #  calls appear to rely on the index being shorter; confirm.
    log_pi_a = log_pi.gather(-1, act).squeeze(-1)
    log_mu_a = log_mu.gather(-1, act).squeeze(-1)
    log_rho = log_mu_a.sub_(log_pi_a.detach()).neg_()

    # `.actor[t]` is actor's extra info in reaction to `.state[t]`, t=0..T
    val = fragment.actor['value']  # info['value'].detach()
    # XXX Although Espeholt et al. (2018, sec.~4.2) use the value estimate of
    # the rollout policy for the V-trace target in eq. (1), it makes more sense
    # to use the estimates of the current policy, as has been done in monobeast.
    #  https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
    val, bootstrap = val[:-1], val[-1]
    target = pyt_vtrace(state_next.rew, state_next.fin, val,
                        gamma=gamma, bootstrap=bootstrap,
                        omega=log_rho, r_bar=1., c_bar=1.)

    # the critic's mse score against v-trace targets (min)
    critic_mse = F.mse_loss(info['value'][:-1], target, reduction='mean') / 2

    # \delta_t = r_{t+1} + \gamma \nu(s_{t+1}) 1_{\neg d_{t+1}} - v(s_t)
    # start with the bootstrap value in every row, so that the last step,
    #  which has no next v-trace target, falls back to it
    adv = torch.empty_like(state_next.rew).copy_(bootstrap)
    adv[:-1].copy_(target[1:])  # copy the v-trace targets \nu(s_{t+1})
    adv.masked_fill_(state_next.fin, 0.).mul_(gamma)
    adv.add_(state_next.rew).sub_(val)
    # XXX note `val` here, not `target`! see sec.~4.2 in (Espeholt et al.; 2018)

    # the policy surrogate score (max)
    # \rho_t = \min\{ \bar{\rho}, \frac{\pi_t(a_t)}{\mu_t(a_t)} \}
    # `exp_` overwrites `log_rho` in-place, after it was given to `pyt_vtrace`
    rho = log_rho.exp_().clamp_(max=1.)
    vtrace_score = log_pi_a.mul(adv.mul_(rho)).mean()

    # the policy's neg-entropy score (min)
    # clamp the log-probabilities from below by the smallest finite float,
    #  so that a `-inf` log-prob times zero prob does not yield a NaN
    f_min = torch.finfo(log_pi.dtype).min
    negentropy = log_pi.exp().mul(log_pi.clamp(min=f_min)).sum(dim=-1).mean()

    # maximize the entropy and the reinforce score, minimize the critic loss
    objective = C_entropy * negentropy + C_value * critic_mse - vtrace_score
    return objective.mean(), dict(
        entropy=-float(negentropy),
        policy_score=float(vtrace_score),
        value_loss=float(critic_mse),
    )

<br>

### Run!

Initialize the learner and the environment factories

In [None]:
from functools import partial


# both factories are `partial`-s of `base_factory` without any bound
#  arguments, i.e. they currently behave exactly like `base_factory`
factory_eval = partial(base_factory)
factory = partial(base_factory)

learner = CartPoleActor(lstm='none')

# put the learner into the training mode, and onto the chosen device
learner.train()
device_ = torch.device('cpu')  # torch.device('cuda:0')
learner.to(device=device_)

optim = torch.optim.Adam(learner.parameters(), lr=1e-3)

Initialize the sampler

In [None]:
# `T` is passed as `n_steps` and `B` as `n_per_actor` to the rollout sampler
T, B = 25, 4

# the sampler's `sticky` flag follows the learner's `use_cudnn` attribute
sticky = learner.use_cudnn

In [None]:
from rlplay.engine.rollout import multi

# a generator of rollout batches for the training loop (closed explicitly
#  once the training is over)
# NOTE(review): the exact meaning of the keyword arguments is defined by
#  `rlplay.engine.rollout.multi.rollout` and is not visible here -- confirm
batchit = multi.rollout(
    factory,
    learner,
    n_steps=T,
    n_actors=6,
    n_per_actor=B,
    n_buffers=15,
    n_per_batch=2,
    sticky=sticky,
    pinned=False,
    clone=True,
    close=False,
    device=device_,
    start_method='fork',  # fork in notebook for macos, spawn in linux
)

A generator of evaluation rewards

In [None]:
from rlplay.engine.rollout.evaluate import evaluate

# a generator of evaluation results; the training loop takes one item
#  from it after every epoch
# NOTE(review): presumably `clone=False` makes the evaluator use the learner
#  itself rather than a copy -- confirm against `rlplay`
test_it = evaluate(factory_eval, learner, n_envs=4, n_steps=500,
                   clone=False, device=device_, start_method='fork')

Implement your favourite training method

In [None]:
# the number of outer iterations of the training loop
n_epochs = 100
# selects `vtrace` over `a2c` as the loss in the training loop
use_vtrace = True

# `gamma` is used by both algorithms, whereas `gae`, `ppo`, and `multistep`
#  are passed to `a2c` only
# gamma, gae, ppo = 0.99, 0.92, 0.2
gamma, gae, ppo, multistep = 0.99, 1., 0.2, 0

In [None]:
import tqdm
from torch.nn.utils import clip_grad_norm_

# NOTE(review): presumably limits the learner to a single thread, so that
#  it does not compete with the rollout processes -- confirm
torch.set_num_threads(1)

losses, rewards = [], []
for epoch in tqdm.tqdm(range(n_epochs)):
    # `range(100)` comes first in the `zip`, hence the inner loop stops after
    #  100 updates without drawing an extra batch from `batchit`
    for j, batch in zip(range(100), batchit):
        if use_vtrace:
            loss, info = vtrace(batch, learner, gamma=gamma)

        else:
            loss, info = a2c(batch, learner, gamma=gamma, gae=gae, ppo=ppo, multistep=multistep)


        # one gradient step with the global grad norm clipped at 1.0
        optim.zero_grad()
        loss.backward()
        grad = clip_grad_norm_(learner.parameters(), max_norm=1.0)
        optim.step()

        # keep the scalar diagnostics of every update for the plots below
        losses.append(dict(
            loss=float(loss), grad=float(grad), **info
        ))

    # fetch the evaluation results lagged by one inner loop!
    rewards.append(next(test_it))

In [None]:
# close the generators: both the rollout sampler and the evaluator were
#  created with `start_method='fork'`
# NOTE(review): presumably closing them shuts the worker processes down --
#  confirm against `rlplay`
batchit.close()
test_it.close()

<br>

In [None]:
def collate(records):
    """Collate identically keyed dicts into a dict of lists.

    Arguments
    ---------
    records : an iterable of dicts.

    Returns
    -------
    A dict mapping each key to the list of its values, in the order of
    appearance in `records`. An empty iterable yields an empty dict.
    """
    out = {}
    for record in records:
        for k, v in record.items():
            # create the list on the first encounter of the key
            out.setdefault(k, []).append(v)

    return out


# turn every collated list of per-update diagnostics into a numpy array
data = {k: numpy.array(v) for k, v in collate(losses).items()}

In [None]:
# the critic's loss per update, on the log scale
if 'value_loss' in data:
    plt.semilogy(data['value_loss'])

In [None]:
# the policy's entropy per update
if 'entropy' in data:
    plt.plot(data['entropy'])

In [None]:
# the policy surrogate score per update
if 'policy_score' in data:
    plt.plot(data['policy_score'])

In [None]:
# the value returned by `clip_grad_norm_` at every update, on the log scale
plt.semilogy(data['grad'])

In [None]:
# stack the per-epoch evaluation results along a new leading axis; note that
#  this rebinds `rewards` from a list to an array
rewards = numpy.stack(rewards, axis=0)

In [None]:
# display the stacked evaluation results
rewards

In [None]:
# the median and the standard deviation over the last axis; these are only
#  referenced by the commented-out band plot in the figure cell
m, s = numpy.median(rewards, axis=-1), rewards.std(axis=-1)

In [None]:
fi, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)

# summary statistics of the evaluation rewards over the last axis, one point
#  per training epoch; the labels feed the legend, since otherwise the four
#  curves cannot be told apart
ax.plot(numpy.mean(rewards, axis=-1), label='mean')
ax.plot(numpy.median(rewards, axis=-1), label='median')
ax.plot(numpy.min(rewards, axis=-1), label='min')
ax.plot(numpy.std(rewards, axis=-1), label='std')
# ax.plot(m+s * 1.96)
# ax.plot(m-s * 1.96)

ax.set_xlabel('epoch')
ax.set_ylabel('evaluation reward')
ax.legend(fontsize='xx-small')

plt.show()

<br>

The ultimate evaluation run

In [None]:
from rlplay.engine import core

# run the learner in a single rendered environment
with factory_eval() as env:
    # switch the learner to the evaluation mode
    learner.eval()
    # NOTE(review): `n_steps` is given as the float `1e4` -- confirm that
    #  `core.evaluate` accepts a non-integer step limit
    eval_rewards, info = core.evaluate([
        env
    ], learner, render=True, n_steps=1e4, device=device_)

# the total reward collected during the run
print(sum(eval_rewards))

<br>

Let's analyze the performance

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# the one-step TD target and error along the evaluation run; no termination
#  mask is applied here
# NOTE(review): assumes `info['value']` has one more step than
#  `eval_rewards` -- confirm against `core.evaluate`
td_target = eval_rewards + gamma * info['value'][1:]
td_error = td_target - info['value'][:-1]
# td_error = npy_deltas(
#     eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool), info['value'][:-1],
#     gamma=gamma, bootstrap=info['value'][-1])

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
# the magnitude of the error relative to the magnitude of the target
ax.semilogy(abs(td_error) / abs(td_target))
ax.set_title('relative td(1)-error');

In [None]:
from rlplay.algo.returns import npy_returns, npy_deltas

# plt.plot(
#     npy_returns(eval_rewards, numpy.zeros_like(eval_rewards, dtype=bool),
#                 gamma=gamma, bootstrap=info['value'][-1]))
# the value estimates along the evaluation run
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.plot(info['value'])
# the reference line is the discounted sum of an endless stream of unit
#  rewards: 1 + gamma + gamma^2 + ... = 1 / (1 - gamma)
ax.axhline(1 / (1 - gamma), c='k', alpha=0.5, lw=1);

In [None]:
import math
from scipy.special import softmax, expit, entr

# the size of the last axis of the logits is the number of actions
*head, n_actions = info['logits'].shape
# turn the logits into probabilities over the actions
proba = softmax(info['logits'], axis=-1)

fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
# `entr(p) = -p log p` elementwise, so the sum over the actions is the entropy
# NOTE(review): `[:, 0]` presumably selects the only environment -- confirm
ax.plot(entr(proba).sum(-1)[:, 0])
# the reference line is the entropy of the uniform policy
ax.axhline(math.log(n_actions), c='k', alpha=0.5, lw=1);

In [None]:
# the histogram of the difference between the logits of action 1 and action 0
fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=300)
ax.hist(info['logits'][..., 1] - info['logits'][..., 0], bins=51);  # log-ratio

<br>

In [None]:
# NOTE(review): presumably a deliberate stop, so that `Run All` halts before
#  the experimental cells below, which redefine `CartPoleActor` -- confirm
assert False

<br>

### Other agents

An agent that uses other inputs, beside `obs`.

In [None]:
class CartPoleActor(BaseActorModule):
    """An actor that uses `obs`, `act`, and `stepno`, with an optional LSTM core.

    The inputs are embedded separately, concatenated, and passed through
    either a feed-forward or an LSTM core, the output of which is shared by
    the policy and the baseline heads.

    Arguments
    ---------
    epsilon : the weight of the uniform distribution blended into the policy,
        when sampling actions in the training mode. Kept in a buffer.
    lstm : one of 'none' (feed-forward core), 'loop' (LSTM stepped through
        the sequence with the state reset by `fin`), or 'cudnn' (LSTM run on
        packed sequences).
    """

    def __init__(self, epsilon=0.1, lstm='none'):
        # imported locally, since `Embedding` and `LSTM` are not among the
        #  names imported in the notebook's import cells
        from torch.nn import Embedding, LSTM

        assert lstm in ('none', 'loop', 'cudnn')
        super().__init__()

        self.use_lstm = (lstm != 'none')
        self.use_cudnn = (lstm == 'cudnn')

        # for updating the exploration epsilon in the actor clones
        self.register_buffer('epsilon', torch.tensor(epsilon))

        # the features
        n_output_dim = dict(obs=64, act=8, stepno=0)
        self.features = torch.nn.Sequential(
            ModuleDict(dict(
                obs=Linear(4, n_output_dim['obs']),

                act=Embedding(2, n_output_dim['act']),

                stepno=Sequential(
                    OneHotBits(32),
                    Linear(32, n_output_dim['stepno']),
                ),
            )),
            ReLU(),
        )

        # the core
        n_features = sum(n_output_dim.values())
        if self.use_lstm:
            self.core = LSTM(n_features, 64, 1)

        else:
            self.core = Sequential(
                Linear(n_features, 64, bias=True),
                ReLU(),
            )

        # the rest of the actor's model
        self.baseline = Linear(64, 1)
        self.policy = Sequential(
            Linear(64, 2),
            LogSoftmax(dim=-1),
        )

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Pick actions, and report the new recurrent state, value and logits.

        Returns the actions, the recurrent state (an empty tuple for the
        feed-forward core), and a dict with the `value` and `logits` tensors.
        """
        # imported locally, since these are not among the names imported in
        #  the notebook's import cells
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

        # Everything is  [T x B x ...]
        # pass exactly the inputs consumed by the feature modules, instead of
        #  the whole `locals()`
        input = self.features(dict(obs=obs, act=act, stepno=stepno))

        # `input` is T x B x F, `hx` is either `None`, or a proper recurrent state
        n_steps = fin.shape[0]
        if not self.use_lstm:
            # update `hx` into an empty container
            out, hx = self.core(input), ()

        elif not self.use_cudnn:
            outputs = []
            for x, m in zip(input.unsqueeze(1), ~fin.unsqueeze(-1)):
                # `m` indicates if NO reset took place, otherwise
                #  multiply by zero to stop the grads
                if hx is not None:
                    hx = suply(m.mul, hx)

                # one LSTM step [1 x B x ...]
                output, hx = self.core(x, hx)
                outputs.append(output)

            # compile the output
            out = torch.cat(outputs, dim=0)

        else:
            # sequence padding (MUST have sampling with `sticky=True`)
            if n_steps > 1:
                lengths = 1 + (~fin[1:]).sum(0).cpu()
                input = pack_padded_sequence(input, lengths, enforce_sorted=False)

            out, hx = self.core(input, hx)
            if n_steps > 1:
                out, lens = pad_packed_sequence(
                    out, batch_first=False, total_length=n_steps)

        # get the policy from the output of the core
        logits = self.policy(out)

        # value must not have any trailing dims, i.e. T x B
        value = self.baseline(out).squeeze(-1)

        if not self.training:
            actions = logits.argmax(dim=-1)

        else:
            # blend the policy with a uniform distribution
            prob = logits.detach().exp().mul_(1 - self.epsilon)
            prob.add_(self.epsilon / logits.shape[-1])

            actions = multinomial(prob)

        return actions, hx, dict(value=value, logits=logits)

A non-recurrent actor with features shared between the policy and the baseline.

In [None]:
class CartPoleActor(BaseActorModule):
    """A feed-forward actor, in which the policy and the baseline heads share
    one hidden feature layer."""

    def __init__(self, epsilon=0.1, lstm='none'):
        assert lstm in ('none', 'loop', 'cudnn')
        super().__init__()

        # this actor is never recurrent, whatever `lstm` says
        self.use_lstm = False
        self.use_cudnn = False

        # for updating the exploration epsilon in the actor clones
        self.register_buffer('epsilon', torch.tensor(epsilon))

        # the shared features
        self.features = Sequential(Linear(4, 20), ReLU())

        # the heads: a scalar value, and the log-probabilities of two actions
        self.baseline = Linear(20, 1)
        self.policy = Sequential(Linear(20, 2), LogSoftmax(dim=-1))

    def forward(self, obs, act, rew, fin, *, hx=None, stepno=None, virtual=False):
        """Pick actions for `obs`, and report the value and the policy logits.

        Returns the actions, an empty recurrent state `()`, and a dict with
        the `value` and `logits` tensors. Only `obs` is used by this actor.
        """
        hidden = self.features(obs)
        logits = self.policy(hidden)

        # value must not have any trailing dims, i.e. T x B
        value = self.baseline(hidden).squeeze(-1)

        if self.training:
            # blend the policy with a uniform distribution
            uniform = self.epsilon / logits.shape[-1]
            prob = logits.detach().exp().mul_(1 - self.epsilon).add_(uniform)

            actions = multinomial(prob)

        else:
            # act greedily
            actions = logits.argmax(dim=-1)

        return actions, (), {'value': value, 'logits': logits}

<br>

In [None]:
# probe the step-counter features on a synthetic range of step numbers,
#  instead of the ones recorded in a batch
# stepno = batch.state.stepno
stepno = torch.arange(256)

In [None]:
# the response of the core's linear layer to the step-counter features
# NOTE(review): with the actor definitions visible in this notebook
#  `learner.core[1]` is a `ReLU`, which has no `.weight`, and the `-8:`
#  slice assumes eight `stepno` features, while `n_output_dim['stepno']` is
#  zero -- this cell looks stale, confirm before running
with torch.no_grad():
    out = learner.features[0]['stepno'](stepno)

    out = F.linear(F.relu(out), learner.core[1].weight[:, -8:],
                       bias=learner.core[1].bias)
#     out = F.linear(F.relu(out), learner.core.weight_ih_l0[:, -8:],
#                        bias=learner.core.bias_ih_l0)
#     out = F.relu(out)

In [None]:
# a 3 x 3 grid of panels with shared axes
fig, axes = plt.subplots(3, 3, figsize=(8, 8), dpi=200,
                         sharex=True, sharey=True)

# one panel per output column; the `zip` stops at the nine available panels
for j, ax in zip(range(out.shape[1]), axes.flat):
    ax.plot(out[:, j], lw=1)

fig.tight_layout(pad=0, h_pad=0, w_pad=0)

In [None]:
# the magnitudes of the core's weights on the last eight input features
# NOTE(review): with the actor definitions visible in this notebook
#  `learner.core[1]` is a `ReLU`, which has no `.weight` -- confirm which
#  layer was meant
with torch.no_grad():
    plt.imshow(abs(learner.core[1].weight[:, -8:]).T)

In [None]:
# the linear layer applied after the binary encoding of the step counter;
#  `features` is a `Sequential`, whose first item is the `ModuleDict` that
#  holds the `stepno` sub-network, hence the `[0]['stepno']` access path
lin = learner.features[0]['stepno'][1]

In [None]:
# the magnitudes of the weights of the `lin` layer as an image
with torch.no_grad():
    plt.imshow(abs(lin.weight))