In [13]:
import jax
import tax
import clu
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt

import mbrl
from jax import jit
from functools import partial

from mbrl.envs.oracle.pendulum import render, step, reset, env_params
from mbrl.algs.rs import trajectory_search, forecast, score, plan

Environment = collections.namedtuple('Environment', ['step', 'reset'])

In [14]:
rng = jax.random.PRNGKey(42)
env = Environment(
    jit(lambda state, u: step(env_params, state, u)), 
    jit(reset)
)

In [15]:
def world(carry, t):
    keys, (env_state, observation), trajectory = carry
    action = trajectory[t]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    carry = keys, (env_state_next, observation_next), trajectory
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [16]:
score_    = jit(score)
forecast_ = partial(forecast, 
                    step_fn=world, 
                    horizon=20, 
                    action_dim=1, 
                    minval=-2., 
                    maxval=2.)

In [21]:
env_state_0, ob_0 = env.reset(rng)
action, _ = plan(rng, (env_state_0, ob_0), forecast_, score_)
action[0]

DeviceArray([1.6705208], dtype=float32)

In [34]:
# Random
score = 0
env_state, observation = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    action = jax.random.uniform(key, (1,), minval=-2., maxval=2.)
    env_state, observation_next, reward, terminal, info = env.step(env_state, action)
    score += reward
    
print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

Random Score: -1474.5244140625


In [49]:
# RS:Model.
score = 0
env_state, observation = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    action = plan(key, (env_state, observation), forecast_, score_)[0][0]
    env_state, observation, reward, terminal, info = env.step(env_state, action)
    score += reward

print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

Random Score: -121.38945770263672


In [63]:
""" Entire Loop with scan"""

def one_step(carry, t):
    key, (env_state, observation)  = carry
    key, subkey = jax.random.split(key)
    action = plan(subkey, (env_state, observation), forecast_, score_)[0][0]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    carry = key, (env_state_next, observation_next)
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [64]:
env_state, observation = env.reset(rng)
init = (rng, (env_state, observation))
_, out = jax.lax.scan(one_step, init, jnp.arange(200))

In [65]:
jnp.sum(out['reward'])

DeviceArray(-115.51241, dtype=float32)