In [1]:
import jax
import tax
import clu
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl
import brax
import tqdm
import functools

from brax import envs
from brax.io import html
from jax import jit
from functools import partial
from mbrl.algs.rs import trajectory_search, forecast, score, plan
from IPython.display import HTML, IFrame, display, clear_output 

def visualize(sys, qps):
  """Renders a 3D visualization of the environment."""
  return HTML(html.render(sys, qps))


tax.set_platform('cpu')

rng = jax.random.PRNGKey(42)


In [2]:
name = 'halfcheetah'
envf = envs.create_fn(name)
env = envf()
env_state = env.reset(rng=rng)
action_size = env.action_size
observation_size = env.observation_size

In [3]:
dummy_action = jax.random.uniform(rng, (action_size,), minval=-1, maxval=1)

In [4]:
env_state_next = env.step(env_state, dummy_action)

# World Model

In this section, we build the `step` or `world` variable necessary to plan

In [8]:
@jit
def step(carry, t):
    rng, env_state, action_trajectory = carry
    action = action_trajectory[t]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next, action_trajectory)
    
    info = dict(
        observation=env_state.obs,
        observation_next=env_state.obs,
        reward=env_state_next.reward,
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

# Planning Routine

In [9]:
forecast_ = partial(
    forecast, step_fn=step,
    horizon=20, action_dim=action_size,
    minval=-1, maxval=1,
)

In [9]:
env_state = env.reset(rng)
for _ in tqdm.notebook.trange(1000):
    action, _ = plan(rng, env_state, forecast_, score)
    # Slow...

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [13]:
@jit
def one_step_interaction(carry, t):
    rng, env_state = carry
    action = plan(rng, env_state, forecast_, score)[0][0]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next)
    
    info = dict(
        observation=env_state.obs,
        observation_next=env_state.obs,
        reward=env_state_next.reward,
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

In [14]:
env_state = env.reset(rng)
init = (rng, env_state)


In [15]:
%%time
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000))  # First should be long.

CPU times: user 1min 23s, sys: 134 ms, total: 1min 23s
Wall time: 1min 23s


In [49]:
out.keys()

dict_keys(['action', 'env_state', 'env_state_next', 'observation', 'observation_next', 'reward', 'terminal'])

In [51]:
sum(out['reward'])

5864.84186822176

In [22]:
%%time
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000)) 

CPU times: user 3.55 ms, sys: 0 ns, total: 3.55 ms
Wall time: 2 ms


In [None]:
out.rewards