In [1]:
import jax
import tax
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl

from jax import jit
from functools import partial
from deluca.envs.classic._mountain_car import reset, env_params, dynamics, step
from deluca.envs.classic._mountain_car import render

from mbrl.algs.rs import forecast
from mbrl.algs.rs import plan
from mbrl.algs.rs import score

rng = jax.random.PRNGKey(42)
Environment = collections.namedtuple('Environment', ['step', 'reset'])

In [2]:
step     = jit(partial(step, env_params=env_params))
reset    = jit(partial(reset, env_params=env_params))
dynamics = jit(partial(dynamics, env_params=env_params))
env      = Environment(step, reset)

In [3]:
state0 = reset(rng)
u = jax.random.uniform(rng, shape=(1,))

In [4]:
dynamics(state0, u)

DeviceArray([-0.17480217, -0.00152968], dtype=float32)

In [5]:
step(state0, u)

(DeviceArray([-0.17480217, -0.00152968], dtype=float32),
 DeviceArray(-0.20355515, dtype=float32),
 DeviceArray(0., dtype=float32),
 {})

In [6]:
env_params

{'min_action': -1.0,
 'max_action': 1.0,
 'min_position': -1.2,
 'max_position': 0.6,
 'max_speed': 0.07,
 'goal_position': 0.45,
 'goal_velocity': 0.0,
 'power': 0.0015,
 'H': 50,
 'action_dim': 1}

In [7]:
def world(carry, t):
    keys, env_state, trajectory = carry
    action = trajectory[t]
    env_state_next , reward, terminal, info = \
        env.step(env_state, action)
    reward = reward.astype('float')
    carry = keys, env_state_next, trajectory
    return carry, {
        "observation": env_state,
        "observation_next": env_state_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [8]:
forecast_fn = partial(forecast, step_fn=world, horizon=100, minval=-1., maxval=1.0, action_dim=1)
score_fn    = partial(score, discount=0.99, terminal_reward_fn=None)

In [9]:
%%time
# RS:Model.
score = 0
list_states = []
env_state = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    list_states.append(env_state)
    action = plan(key, env_state, forecast_fn, score_fn, population=50_000)[0][0]
    env_state, reward, terminal, info = env.step(env_state, action)
    score += reward

print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

  lax._check_user_dtype_supported(dtype, "astype")


Random Score: -15.95876407623291
CPU times: user 5.54 s, sys: 289 ms, total: 5.83 s
Wall time: 5.48 s


In [10]:
_, kwargs = render(list_states[0], {})

In [11]:
for s in list_states:
    render(s, kwargs)