In [13]:
import collections
from functools import partial

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tax
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule
from jax import jit

import mbrl
from mbrl.envs.oracle._acrobot import env_params, dynamics, render
from mbrl.envs.oracle._acrobot import reset_fn, step_fn
from mbrl.algs.rs import forecast
from mbrl.algs.rs import plan
from mbrl.algs.rs import score

# Root PRNG key for the notebook; the cells below derive their keys from it.
rng = jax.random.PRNGKey(seed=42)

# Lightweight record bundling the environment's two entry points.
Environment = collections.namedtuple(
    typename='Environment',
    field_names=['step', 'reset'],
)

# Display the environment configuration imported above.
env_params

{'dt': 0.2,
 'LINK_LENGTH_1': 1.0,
 'LINK_LENGTH_2': 1.0,
 'LINK_MASS_1': 1.0,
 'LINK_MASS_2': 1.0,
 'LINK_COM_POS_1': 0.5,
 'LINK_COM_POS_2': 0.5,
 'LINK_MOI': 1.0,
 'MAX_VEL_1': 12.566370614359172,
 'MAX_VEL_2': 28.274333882308138,
 'AVAIL_TORQUE': DeviceArray([-1.,  0.,  1.], dtype=float32),
 'torque_noise_max': 0.0,
 'book_or_nips': 'book',
 'action_arrow': None,
 'domain_fig': None,
 'actions_num': 3,
 'action_size': 3,
 'action_type': 'discrete',
 'max_episode_steps': 500}

In [14]:
# Bind the static environment configuration once and JIT-compile the result.
#
# The wrappers are built from the module's original functions rather than from
# the notebook-level names being reassigned here, so this cell is idempotent.
# Wrapping an already-wrapped function would route `env_params` (which holds
# strings and None) through `jit` as a traced argument and fail on re-run.
_acrobot = mbrl.envs.oracle._acrobot
step_fn = jit(partial(_acrobot.step_fn, env_params=env_params))
reset_fn = jit(partial(_acrobot.reset_fn, env_params=env_params))
dynamics = jit(partial(_acrobot.dynamics, env_params=env_params))

# Bundle the compiled entry points for the planning and rollout cells.
env = Environment(step=step_fn, reset=reset_fn)

In [23]:
# Smoke-test the compiled environment functions from a freshly reset state.
# NOTE(review): `rng` is consumed directly here and is split again by later
# cells; consider deriving a dedicated key for this reset instead.
state_0, observation_0 = reset_fn(rng)
u = 1      # discrete action index: 0, 1, 2
# Only the last expression of a cell is displayed, so the `dynamics` result is
# discarded and the cell output is whatever `step_fn` returns.
dynamics(state_0, u)
step_fn(state_0, u)

(DeviceArray([-0.00265431,  0.04787853, -0.08390957, -0.1455917 ], dtype=float32),
 DeviceArray([ 0.9999965 , -0.00265431,  0.99885404,  0.04786024,
              -0.08390957, -0.1455917 ], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(0., dtype=float32),
 {})

In [37]:
@jit
def world(carry, t):
    """Advance the environment by one step of a pre-selected action sequence.

    Args:
        carry: Tuple ``(keys, (state, observation), trajectory)``. ``keys`` is
            passed through unchanged, and ``trajectory`` is indexed by ``t`` to
            obtain the action to apply.
        t: Index of the current step within ``trajectory``.

    Returns:
        A pair ``(carry, transition)``. ``carry`` has the same structure as the
        input with ``(state, observation)`` replaced by their successors.
        ``transition`` is a dict with the keys ``state``, ``state_next``,
        ``observation``, ``observation_next``, ``reward``, ``action`` and
        ``terminal``.

    Note:
        The ``"terminal"`` entry stores ``1. - done``, so it is 1 while the
        episode is still running and 0 once ``done`` is set.
    """
    keys, (state, observation), trajectory = carry
    action = trajectory[t]
    state_next, observation_next, reward, done, info = env.step(state, action)
    reward = reward.astype(jnp.float32)
    # `keys` and `trajectory` are loop-invariant; only the env state advances.
    carry = keys, (state_next, observation_next), trajectory
    return carry, {
        "state": state,
        "state_next": state_next,
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1. - done,
    }

In [39]:
# JIT-compiled scoring function, reused by the rollout cell below.
score_ = jit(score)

# `forecast` specialised to this notebook's environment step (`world`).
forecast_ = partial(
    forecast,
    step_fn=world,
    horizon=250,
    action_type='discrete',
    action_dim=3,    # Number of discrete actions possible
    minval=None,
    maxval=None,
)

# Example call, left disabled:
#forecast_(rng, (state_0, observation_0))

In [40]:
# Derive separate keys for the reset and for the planner. The carried `rng` is
# only ever split, never consumed, so the next cell's split does not reuse a key.
rng, rng_reset, rng_plan = jax.random.split(rng, 3)
state_0, observation_0 = env.reset(rng_reset)
# Reuse the already-compiled `score_` rather than re-wrapping `score`, so this
# cell does not depend on what the name `score` is bound to when it is re-run.
action, _ = plan(rng_plan, (state_0, observation_0), jit(forecast_), score_)
# First element of the action output returned by the planner.
action[0]

DeviceArray(0, dtype=int32)

In [44]:
%%time
# RS:Model.
score = 0
rng, rng_reset = jax.random.split(rng, 2)
state, observation = env.reset(rng_reset)
list_states = []
for _ in tqdm.notebook.trange(200):
    rng, rng_plan = jax.random.split(rng, 2)
    list_states.append(state)
    action = plan(rng_plan, (state, observation), forecast_, score_,  population=5000)[0][0]
    state, observation, reward, terminal, info = env.step(state, action)
    score += reward
print(score)

  0%|          | 0/200 [00:00<?, ?it/s]

-200.0
CPU times: user 8.94 s, sys: 418 ms, total: 9.36 s
Wall time: 8.46 s


In [45]:
# Render the first recorded state, passing an empty dict as the second argument.
# The second return value is kept as `info` for the render loop in the next cell.
# NOTE(review): presumably `info` carries figure/artist handles — confirm in `render`.
_, info = render(list_states[0], {})

In [46]:
# Render every recorded state in order, passing the `info` value obtained from
# the initial render call each time.
for recorded_state in list_states:
    render(recorded_state, info)