In [1]:
import jax
import tax
import clu
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl
import brax
import tqdm
import functools

from brax import envs
from brax.io import html
from jax import jit
from functools import partial
from mbrl.algs.rs import trajectory_search, forecast, score, plan
from IPython.display import HTML, IFrame, display, clear_output 

def visualize(sys, qps):
  """Build a notebook-displayable 3D view of `sys` for the states `qps`.

  The markup is produced by `brax.io.html.render` and wrapped in an
  `IPython.display.HTML` object so that a cell can display it directly.
  """
  markup = html.render(sys, qps)
  return HTML(markup)


# Global notebook setup. Statement order matters: the platform call runs
# before the first JAX call below.
# NOTE(review): presumably selects the JAX backend — confirm against `tax`.
tax.set_platform('cpu')

# Root PRNG key; every later cell starts from this key.
rng = jax.random.PRNGKey(42)

# Build the Brax environment: `envs.create_fn` returns a constructor, which is
# called once to obtain the environment instance used by all cells below.
name = 'halfcheetah'
envf = envs.create_fn(name)
env = envf()
env_state = env.reset(rng=rng)
# Sizes read off the environment; `action_size` is passed to the planner later.
action_size = env.action_size
observation_size = env.observation_size

@jit
def step(carry, t):
    """Apply the `t`-th action of a fixed action sequence to the environment.

    The signature has the `(carry, x) -> (carry, y)` shape of a
    `jax.lax.scan` body.

    Args:
        carry: Tuple `(rng, env_state, action_trajectory)`. `rng` is passed
            through unchanged; `action_trajectory` is indexed by `t`.
        t: Index of the action to apply at this step.

    Returns:
        A pair `(carry, info)`. `carry` holds the same `rng` and
        `action_trajectory` with the post-step environment state. `info` is a
        dict describing the transition (see keys below).
    """
    rng, env_state, action_trajectory = carry
    action = action_trajectory[t]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next, action_trajectory)

    info = dict(
        observation=env_state.obs,
        # Observation AFTER the step. Previously this read `env_state.obs`,
        # which made `observation_next` a copy of `observation`.
        observation_next=env_state_next.obs,
        reward=env_state_next.reward,
        # NOTE(review): despite the key name this is `1 - done`, i.e. 1 while
        # the episode continues and 0 once it is done. Value kept as-is because
        # downstream consumers are not visible here.
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

In [2]:
# `forecast` with its rollout settings pre-bound, so the planner can call it
# without repeating them.
# NOTE(review): `minval`/`maxval` presumably bound the sampled actions —
# confirm against `mbrl.algs.rs.forecast`.
forecast_ = functools.partial(
    forecast,
    step_fn=step,
    horizon=20,
    action_dim=action_size,
    minval=-1,
    maxval=1,
)

In [3]:
@jit
def one_step_interaction(carry, t):
    """Plan one action, apply it to the environment, and record the transition.

    The signature has the `(carry, x) -> (carry, y)` shape of a
    `jax.lax.scan` body; `t` is the scan index and is not used.

    Args:
        carry: Tuple `(rng, env_state)` with the current PRNG key and the
            current environment state.
        t: Step index supplied by the scan (unused).

    Returns:
        A pair `(carry, info)`. `carry` holds a fresh PRNG key and the
        post-step environment state. `info` is a dict describing the
        transition (see keys below).
    """
    rng, env_state = carry
    # Split the key so that each step plans with fresh randomness. Reusing the
    # same key would hand `plan` identical random numbers at every step.
    rng, rng_plan = jax.random.split(rng)
    # `[0][0]` takes the first element of the first item returned by `plan`.
    # NOTE(review): presumably the first action of the selected sequence —
    # confirm against `mbrl.algs.rs.plan`.
    action = plan(rng_plan, env_state, forecast_, score)[0][0]
    env_state_next = env.step(env_state, action)
    carry = (rng, env_state_next)

    info = dict(
        observation=env_state.obs,
        # Observation AFTER the step. Previously this read `env_state.obs`,
        # which made `observation_next` a copy of `observation`.
        observation_next=env_state_next.obs,
        reward=env_state_next.reward,
        # NOTE(review): despite the key name this is `1 - done`, i.e. 1 while
        # the episode continues and 0 once it is done.
        terminal=1 - env_state_next.done,
        action=action,
        env_state=env_state,
        env_state_next=env_state_next,
    )
    return carry, info

In [4]:
%%time
# Roll out one 1000-step episode with the planner in the loop. The first call
# also compiles `one_step_interaction`, so this cell is expected to be slow.
env_state = env.reset(rng)
init = (rng, env_state)
# `scan` returns (final_carry, per-step outputs); only the outputs are kept.
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000))  # First call is slow: it includes compilation.
print('Compilation Done.')

Compilation Done.
CPU times: user 1min 23s, sys: 143 ms, total: 1min 23s
Wall time: 1min 23s


In [None]:
# Same rollout from the same `init`; the final carry is discarded and `out`
# receives the per-step info returned by the scan.
_, out = jax.lax.scan(one_step_interaction, init, jnp.arange(1000)) 

# Training Loop (MBRL + Value Function)

In [None]:
# Exploration (=> Gather Data)
# `jax.lax.scan` returns `(final_carry, stacked_outputs)`. Unpack it so that
# `info` is the dict of per-step transition data rather than the whole tuple.
_, info = jax.lax.scan(one_step_interaction, init, jnp.arange(1000))

In [None]:
# Replay buffer that will receive the gathered transitions.
# NOTE(review): 100_000 is presumably the buffer capacity — confirm against
# `tax.ReplayBuffer`.
rb = tax.ReplayBuffer(100_000)

In [None]:
# Store the exploration rollout in the replay buffer.
# `jax.lax.scan` returns `(final_carry, stacked_outputs)`; if `info` holds that
# whole pair, keep only the stacked per-step outputs.
transitions = info[1] if isinstance(info, tuple) else info
# NOTE(review): assumes `rb.add` accepts a dict of stacked per-step arrays —
# confirm against `tax.ReplayBuffer.add`.
rb.add({
    'observation': transitions['observation'],
    'observation_next': transitions['observation_next'],
    'reward': transitions['reward'],
    # The rollout stores `1 - done` under the key 'terminal': 1 while the
    # episode continues, 0 once it is done. That mask is used as the discount.
    'discount': transitions['terminal'],
})