In [57]:
import jax
import gymnax
import flax.linen as nn

In [326]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
key = jax.random.PRNGKey(314)

In [3]:
rng, key_reset, key_act, key_step = jax.random.split(key, 4)

In [4]:
# Instantiate the environment & its settings.
env, env_params = gymnax.make("Breakout-MinAtar")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


## Could you learn?

In [6]:
class CNN(nn.Module):
    num_actions: int

    @nn.compact
    def __call__(self, x):
        x = x if len(x.shape) > 3 else x[None, :]
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        x = nn.Conv(features=16, kernel_size=(2, 2), padding="VALID")(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = nn.Dense(features=120)(x)
        x = nn.relu(x)
        x = nn.Dense(features=84)(x)
        x = nn.relu(x)
        # x = nn.Dense(20, name="last_layer")(x)
        # x = nn.relu(x)
        x = nn.Dense(self.num_actions, name="last_layer")(x)
        return x.squeeze()

model = CNN(num_actions=3)

In [7]:
import jax.numpy as jnp

In [8]:
params_init = model.init(key, n_obs)
jax.tree.map(jnp.shape, params_init)

{'params': {'Conv_0': {'bias': (6,), 'kernel': (5, 5, 4, 6)},
  'Conv_1': {'bias': (16,), 'kernel': (2, 2, 6, 16)},
  'Dense_0': {'bias': (120,), 'kernel': (64, 120)},
  'Dense_1': {'bias': (84,), 'kernel': (120, 84)},
  'last_layer': {'bias': (3,), 'kernel': (84, 3)}}}

In [19]:
from rebayes_mini.methods import low_rank_last_layer as onflow
from rebayes_mini.callbacks import get_null

In [20]:
def mean_fn(params, x):
    eta = model.apply(params, x)
    return jax.nn.softmax(eta)

def cov_fn(mean, eps=0.1):
    return jnp.diag(mean) - jnp.outer(mean, mean) + jnp.eye(len(mean)) * eps

In [21]:
agent = onflow.LowRankLastLayer(
    mean_fn,
    cov_fn,
    rank=5,
    dynamics_hidden=0.0,
    dynamics_last=0.0
)

In [22]:
bel_init = agent.init_bel(params_init)

In [23]:
bel_next, _ = agent.step(bel_init, action, n_obs, callback_fn=get_null)

In [24]:
agent.sample_predictive(key, bel_next, n_obs)

Array([0.6693919 , 0.04068932, 0.46270543], dtype=float32)

In [15]:
# Instantiate the environment & its settings.
env, env_params = gymnax.make("Breakout-MinAtar")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


## Gym lunar lander

In [667]:
%config InlineBackend.figure_format = "retina"

In [111]:
import gym

In [195]:
import jax
import jax.numpy as jnp

In [646]:
import matplotlib.pyplot as plt

In [682]:
from rebayes_mini.methods import low_rank_filter_revised as lrkf
from rebayes_mini.methods import base_filter

In [683]:
env = gym.make("LunarLander-v2")

In [684]:
class LinearPolicy(nn.Module):
    n_actions: int
    n_features: int
    
    @nn.compact
    def __call__(self, x):
        a = x[..., 0]
        x = x[..., 1:]
        embedding = nn.Embed(self.n_actions, self.n_features)(a.astype(int))
        x = jnp.einsum("...j,...j->...", embedding, x)
        return x

In [685]:
key = jax.random.PRNGKey(314)
X_init = jnp.ones(env.observation_space.shape[0] + 1)
model = LinearPolicy(n_actions=env.action_space.n, n_features=env.observation_space.shape[0])
params_init = model.init(key, X_init[None, :])
jax.tree.map(jnp.shape, params_init)

{'params': {'Embed_0': {'embedding': (4, 8)}}}

In [695]:
agent = lrkf.LowRankCovarianceFilter(
    model.apply,
    lambda x: jnp.eye(1) * 1e-3,
    dynamics_covariance=1e-3,
    rank=32
)


agent = base_filter.ExtendedFilter(
    model.apply,
    lambda x: jnp.eye(1) * 1.0,
    dynamics_covariance=1e-7
)

@jax.jit
def sample_predictives(key, bel, state):
    state = jnp.atleast_1d(state)
    actions = jnp.arange(4)
    keys = jax.random.split(key, 4)
    @jax.vmap
    def _sample(key, action):
        X = jnp.c_[action, state]
        sample = agent.sample_predictive(key, bel_init, X)
        return sample
    samples = _sample(keys, actions)
    return samples

In [696]:
from tqdm import tqdm

In [None]:
gamma = 0.999
state_prev, _ = env.reset(seed=314)
action_prev = 0
bel_init = agent.init_bel(params_init, cov=10.0)
bel = bel_init

steps = 5_000
episode_reward = 0.0
final_rewards = []
ewm_episode = 0.0
rewards = np.zeros(steps)

for s in (pbar := tqdm(range(steps))):
    key_t = jax.random.fold_in(key, s)
    action_value_samples = sample_predictives(key_t, bel, state_prev[None, :])
    action = action_value_samples.argmax().item()
    state_next, reward, terminated, truncated, info = env.step(action)
    
    if terminated:
        _, key_new = jax.random.split(key_t)
        state_prev, info = env.reset()
        action_value_samples = sample_predictives(key_t, bel, state_prev[None, :])
        action = action_value_samples.argmax().item()
        final_rewards.append(episode_reward)
        ewm_episode = 0.3 * episode_reward + (1 - 0.3) * ewm_episode
        pbar.set_description(f"EWM: {ewm_episode:0.2f}")
        episode_reward = 0.0
    else:
        rewards[s] = reward
        episode_reward += reward
    
    X_prev = jnp.insert(state_prev, 0, jnp.array([action_prev]))
    
    X_next = jnp.insert(state_next, 0, jnp.array([action]), 0)
    y = reward + gamma * agent.mean_fn(bel.mean, X_next)
    
    state_prev = state_next
    action_prev = action
    
    # bel = agent.update(bel, bel, y, X_next[None, :])
    bel = agent.update(bel, bel, y, X_next[None, :])

EWM: -118.27:  80%|████████████████████████████████████████████████████████████████████▌                 | 3988/5000 [01:39<00:24, 41.31it/s]

In [None]:
plt.plot(np.array(final_rewards))

In [618]:
action_value_samples = sample_predictives(key, bel, state_prev[None, :])
action = action_value_samples.argmax().item()

In [619]:
state_next, reward, terminated, truncated, info = env.step(action)

In [620]:
X_prev = jnp.insert(observation_next, 0, jnp.array([action_prev]))
yhat = agent.mean_fn(bel_init.mean, X_prev)

In [621]:
X_next = jnp.insert(observation_next, 0, jnp.array([action]), 0)
y = reward + gamma * agent.mean_fn(bel_init.mean, X_next)
state_prev = state_next

In [622]:
reward

np.float64(0.4749850519668257)