In [1]:
import jax
import tax
import clu
import chex
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt

import mbrl
from jax import jit
from jax import vmap
from functools import partial

from mbrl.envs.oracle.pendulum import render, step, reset, env_params
from mbrl.algs.cem import forecast, score
from mbrl.algs.cem import get_elite_stats
from mbrl.algs.cem import score
from mbrl.algs.cem import plan

Environment = collections.namedtuple('Environment', ['step', 'reset'])

In [2]:
rng = jax.random.PRNGKey(42)
env = Environment(
    jit(lambda state, u: step(env_params, state, u)), 
    jit(reset)
)
def world(carry, t):
    keys, (env_state, observation), trajectory = carry
    action = trajectory[t]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    carry = keys, (env_state_next, observation_next), trajectory
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [3]:
score_    = jit(score)
forecast_ = partial(forecast, 
                    step_fn=world, 
                    horizon=20, 
                    action_dim=1, 
                    minval=-2., 
                    maxval=2.)

In [4]:
action_dim, horizon = 1, 20
loc = jnp.zeros((horizon, action_dim))
scale = jnp.ones((horizon, action_dim))

In [5]:
env_state_0, ob_0 = env.reset(rng)
traj = forecast_(rng, (env_state_0, ob_0), loc, scale)

In [7]:
_ = plan(rng, (env_state_0, ob_0),  forecast_, score_, action_dim=1)

In [10]:
# Random
score = 0
env_state, observation = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    action = jax.random.uniform(key, (1,), minval=-2., maxval=2.)
    env_state, observation_next, reward, terminal, info = env.step(env_state, action)
    score += reward
    
print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

Random Score: -1354.366943359375


In [20]:
%%time
# CEM:Model.
score = 0
env_state, observation = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    action = plan(rng, (env_state, observation),  forecast_, score_, action_dim=1)[0][0]
    env_state, observation, reward, terminal, info = env.step(env_state, action)
    score += reward

print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

Random Score: -0.30772554874420166
CPU times: user 982 ms, sys: 40.2 ms, total: 1.02 s
Wall time: 817 ms


In [12]:
""" Entire Loop with scan"""

def one_step(carry, t):
    key, (env_state, observation)  = carry
    key, subkey = jax.random.split(key)
    action = plan(rng, (env_state, observation),  forecast_, score_, action_dim=1)[0][0]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    carry = key, (env_state_next, observation_next)
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [22]:
%%time
env_state, observation = env.reset(rng)
init = (rng, (env_state, observation))
_, out = jax.lax.scan(one_step, init, jnp.arange(200))

CPU times: user 490 ms, sys: 0 ns, total: 490 ms
Wall time: 488 ms


In [14]:
jnp.sum(out['reward'])

DeviceArray(-484.64685, dtype=float32)