In [1]:

from IPython.display import HTML, IFrame, display, clear_output 

import os 
import jax
import tax
import brax
import jax.numpy as jnp
import functools
import matplotlib as mpl
import matplotlib.pyplot as plt

from brax import envs
from brax.training import ppo, sac
from brax.io import html

# NOTE(review): `tax` is the separate module imported above, not `jax`.
# Presumably this pins computation to the CPU backend — confirm that the
# package is installed and that this is the intended call.
tax.set_platform('cpu')

def visualize(sys, qps):
  """Wrap the rendering of `qps` for `sys` in a displayable HTML object.

  The markup is produced by `html.render(sys, qps)` (from `brax.io`); this
  helper only packages it in `IPython.display.HTML` so a notebook cell can
  show it.
  """
  markup = html.render(sys, qps)
  return HTML(markup)

# `Visualization of the Environment`

In [2]:
# Test basic random control of the agent.
from jax import jit 
from jax import vmap
from brax.envs.ant import Ant
from brax.envs.wrappers import GymWrapper

In [3]:
# Notebook-level PRNG key; later cells pass it to env.reset and the rollouts.
rng = jax.random.PRNGKey(42)
env = Ant()
# Rebind step/reset on this instance to jit-wrapped versions so that every
# later call through `env` goes through jit.
env.step = jit(env.step)
env.reset = jit(env.reset)

In [4]:
# Initial environment state; used as the starting carry of the short scan
# in the timing cell further down.
state = env.reset(rng)

In [8]:
@jit
def step(carry, xs):
    """Advance the environment by one step using a uniformly random action.

    Written as a `jax.lax.scan` body.

    Args:
        carry: two-element sequence `[rng, state]` holding the PRNG key and
            the current environment state.
        xs: per-iteration scan input; not used.

    Returns:
        A pair `(carry, info)`: `carry` is the list `[rng, new_state]` with
        the advanced key, and `info` is a dict with the keys 'reward',
        'observation', 'observation_next', 'terminal' and 'steps'.
    """
    key, prev_state = carry
    # Keep one half of the split as the carried key; spend the other on
    # this step's action.
    key, action_key = jax.random.split(key)

    action = jax.random.uniform(
        action_key,
        shape=(env.action_size,),
        minval=-1.0,
        maxval=1.0,
    )
    next_state = env.step(prev_state, action)

    # NOTE(review): 'terminal' stores 1.0 - done, so it is 0 where `done`
    # is 1. The name suggests the opposite — confirm what consumers expect.
    transition = {
        'reward': next_state.reward,
        'observation': prev_state.obs,
        'observation_next': next_state.obs,
        'terminal': 1.0 - next_state.done,
        'steps': next_state.steps,
    }
    return [key, next_state], transition

@jit
def episode(rng):
    """Roll out one 1000-step episode of uniformly random actions.

    Args:
        rng: PRNG key for the episode. It is split once so that the reset
            and the rollout each get their own key.

    Returns:
        The per-step `info` dict produced by `step`, as collected by
        `jax.lax.scan` over the 1000 iterations.
    """
    # Do not hand the same key to env.reset and to the rollout carry: `step`
    # splits its carried key again, and reusing a key that has already been
    # consumed goes against JAX's PRNG guidance.
    rng_reset, rng_rollout = jax.random.split(rng)
    state = env.reset(rng_reset)
    init, xs = [rng_rollout, state], jnp.arange(1000)
    _, info = jax.lax.scan(step, init, xs)
    return info

# Batched variant: maps `episode` over a leading axis of PRNG keys.
vepisode = jit(vmap(episode))

In [8]:
%%time
# Scan `step` for only 10 iterations, starting from the `state` created
# earlier by env.reset(rng).
# NOTE(review): when run top to bottom, the carry reuses the same `rng` that
# was passed to env.reset — confirm whether a fresh split is wanted here.
# The reported wall time is close to that of the 1000-step episode cell, so
# it is presumably dominated by compilation rather than by stepping.
init, xs = [rng, state], jnp.arange(10)
_, info = jax.lax.scan(step, init, xs)

CPU times: user 21.5 s, sys: 23.9 ms, total: 21.6 s
Wall time: 22 s


In [7]:
%%time
# Time one full random-action episode (1000 scan steps) driven by the
# notebook-level key; rebinds `info`.
info = episode(rng)

CPU times: user 23.7 s, sys: 63 ms, total: 23.7 s
Wall time: 23.7 s


In [26]:
%%time
# Advance the notebook-level key, then fan the sub-key out into 256 keys so
# that vepisode runs one episode per key.
rng, subrng = jax.random.split(rng)
brng = jax.random.split(subrng, 256)
binfo = vepisode(brng)
# Show the observations recorded for the first episode of the batch.
print(binfo['observation'][0])

[[0.5133401  1.         0.         ... 0.         0.         0.        ]
 [0.501194   0.99731606 0.00557384 ... 0.         0.         0.        ]
 [0.43697065 0.9974769  0.02101452 ... 0.         0.         0.        ]
 ...
 [0.5251519  0.9531934  0.05337911 ... 0.         0.         0.        ]
 [0.511115   0.9457168  0.06538638 ... 0.         0.         0.        ]
 [0.47696105 0.9131294  0.01408968 ... 0.         0.         0.        ]]
CPU times: user 6.58 s, sys: 96.1 ms, total: 6.67 s
Wall time: 6.57 s


In [27]:
# Expected layout: (episodes, scan steps, observation size), i.e. 256 keys,
# 1000 steps per episode, and one observation vector per step.
binfo['observation'].shape

(256, 1000, 87)

In [5]:
# Length of the action vector that step() samples for env.step.
env.action_size

8

In [6]:
# Size of one observation vector; presumably equal to the last axis of
# binfo['observation'] — confirm against the shape shown above.
env.observation_size

87