In [1]:
import jax
import tax
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl
import brax
import tqdm
import functools

from brax import envs
from brax.io import html
from jax import jit
from functools import partial
from mbrl.algs.rs import trajectory_search, forecast, score, plan
from IPython.display import HTML, IFrame, display, clear_output 

def visualize(sys, qps):
  """Renders a 3D visualization of the environment."""
  return HTML(html.render(sys, qps))


tax.set_platform('cpu')

rng = jax.random.PRNGKey(42)


In [2]:
name = 'halfcheetah'
envf = envs.create_fn(name)
env = envf()
env_state = env.reset(rng=rng)
action_size = env.action_size
observation_size = env.observation_size

In [3]:
dummy_action = jax.random.uniform(rng, (action_size,), minval=-1, maxval=1)

In [4]:
env_state_next = env.step(env_state, dummy_action)

# World Model

In this section, we build the `step` or `world` variable necessary to plan

In [5]:
@jit
def step(carry, t):
    rng, env_state, action_trajectory = carry
    action = action_trajectory[t]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next, action_trajectory)
    
    info = dict(
        observation=env_state.obs,
        observation_next=env_state.obs,
        reward=env_state_next.reward,
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

# Planning Routine

In [6]:
forecast_ = partial(
    forecast, step_fn=step,
    horizon=20, action_dim=action_size,
    minval=-1, maxval=1,
)

In [7]:
env_state = env.reset(rng)
for _ in tqdm.notebook.trange(1000):
    action, _ = plan(rng, env_state, forecast_, score)
    # Slow...

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
@jit
def one_step_interaction(carry, t):
    rng, env_state = carry
    action = plan(rng, env_state, forecast_, score)[0][0]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next)
    
    info = dict(
        observation=env_state.obs,
        observation_next=env_state.obs,
        reward=env_state_next.reward,
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

In [9]:
%%time
env_state = env.reset(rng)
init = (rng, env_state)
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000))  # First should be long.
print(out['reward'])

[-0.59614486 -1.2205255   1.2944616   3.7836158   3.735303    2.26326
  2.8578355   3.3452642   3.2587044   2.0958436   2.9565184   3.467357
  2.4864757   3.0019763   2.6563258   2.6354477   2.9771435   3.0021613
  3.2320645   2.8083746   2.3565824   3.2244482   2.832795    2.2789793
  1.0663259   2.0100381   3.8732      4.962577    4.8896017   4.751386
  5.372682    3.982963    4.211172    4.4375396   4.6597533   1.8023925
  4.096928    3.603277    6.0439186   5.6798043   3.7516007   3.7363558
  5.845078    2.9866107   6.594893    5.51392     4.53895     4.727485
  2.478754    3.7850282   6.7433977   4.286942    5.453629    3.9796917
  4.7742586   4.477082    3.9253964   4.562143    5.278983    3.918555
  4.8979263   4.475938    5.0183496   3.5996912   3.5401297   6.581811
  5.6840706   5.7974296   4.4645104   4.36213     5.944598    4.81706
  6.19832     4.972478    6.132314    4.3315625   4.831616    3.1196222
  6.042688    6.9899635   6.106084    6.014324    6.654696    5.967949
  

In [10]:
%%time
rng, subrng = jax.random.split(rng)
env_state = env.reset(subrng)
init = (rng, env_state)
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000)) 
out['reward'].sum()

CPU times: user 141 ms, sys: 0 ns, total: 141 ms
Wall time: 133 ms


In [None]:
out['reward'].sum()