In [1]:
import jax
import tax
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl
import brax
import tqdm
import functools

from brax import envs
from brax.io import html
from jax import jit
from functools import partial
from mbrl.algs.rs import trajectory_search, forecast, score, plan
from IPython.display import HTML, IFrame, display, clear_output 

def visualize(sys, qps):
  """Renders a 3D visualization of the environment."""
  return HTML(html.render(sys, qps))


tax.set_platform('cpu')

rng = jax.random.PRNGKey(42)


In [2]:
name = 'halfcheetah'
envf = envs.create_fn(name)
env = envf()
env_state = env.reset(rng=rng)
action_size = env.action_size
observation_size = env.observation_size

In [3]:
@jit
def one_step_interaction(carry, t):
    rng, env_state = carry
    rng, key = jax.random.split(rng)
    action = jax.random.uniform(rng, (action_size,), minval=-1, maxval=1)
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next)
    
    info = dict(
        observation=env_state.obs,
        observation_next=env_state.obs,
        reward=env_state_next.reward,
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

In [4]:
env_state = env.reset(rng)
init = (rng, env_state)

In [5]:
%%time
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000))  # First should be long.

CPU times: user 19.4 s, sys: 6.18 ms, total: 19.4 s
Wall time: 19.4 s


In [14]:
%%time
rng, subrng = jax.random.split(rng)
env_state = env.reset(subrng)
init = (rng, env_state)
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000)) 
out['reward']

CPU times: user 97.4 ms, sys: 7.96 ms, total: 105 ms
Wall time: 97.1 ms


DeviceArray([ 8.32524970e-02,  1.10118568e+00,  4.70390886e-01,
             -1.04250371e-01, -7.90085018e-01,  6.64394975e-01,
              1.88233167e-01, -3.75960410e-01,  4.71359432e-01,
             -6.71074033e-01, -8.34923267e-01, -1.60202587e+00,
             -5.15150309e-01, -3.28432679e-01,  1.23119451e-01,
             -4.51283574e-01,  1.58793077e-01,  9.36416388e-01,
             -1.79180786e-01, -4.22292262e-01, -1.33853662e+00,
             -7.29121566e-01, -8.05391073e-01, -9.14524794e-01,
             -5.66791952e-01,  1.73758306e-02, -1.48041308e+00,
             -1.09095883e+00,  4.13704067e-01, -1.12802136e+00,
             -6.86109841e-01, -3.86011362e-01, -6.79790258e-01,
             -8.22908759e-01, -1.76198602e+00, -4.73119736e-01,
             -6.16440654e-01, -1.37297082e+00,  3.94858122e-02,
             -1.13182855e+00, -7.04586029e-01, -9.57287014e-01,
             -3.14601898e-01,  4.04236972e-01, -2.38825083e-01,
              2.31088072e-01,  1.2285625