In [1]:
import jax
import tax
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl

from jax import jit
from functools import partial
from deluca.envs.classic._pendulum import step, reset, reward, env_params
from deluca.envs.classic._pendulum import render
from mbrl.algs.rs import forecast
from mbrl.algs.rs import plan
from mbrl.algs.rs import score

rng = jax.random.PRNGKey(42)
Environment = collections.namedtuple('Environment', ['step', 'reset'])

In [2]:
step = jit(partial(step, env_params=env_params))
reset = jit(partial(reset, env_params=env_params))
env = Environment(step, reset)

In [3]:
state0, observation0 = reset(rng)
u = jax.random.uniform(rng, (1,), minval=-2, maxval=2.)

In [4]:
step(state0, u)

(DeviceArray([0.3965062, 1.0385346], dtype=float32),
 DeviceArray([0.9224159 , 0.38619795, 1.0385346 ], dtype=float32),
 DeviceArray(-0.18756944, dtype=float32),
 DeviceArray(False, dtype=bool),
 {})

In [5]:
def world(carry, t):
    keys, (env_state, observation), trajectory = carry
    action = trajectory[t]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    reward = reward.astype('float')
    carry = keys, (env_state_next, observation_next), trajectory
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [6]:
forecast_fn = partial(forecast, step_fn=world, horizon=20, minval=-2., maxval=2.0, action_dim=1)
score_fn    = partial(score, discount=0.99, terminal_reward_fn=None)

In [7]:
%%time
# RS:Model.
score = 0
list_states = []
env_state, observation = env.reset(rng)
for _ in tqdm.notebook.trange(200):
    rng, key = jax.random.split(rng)
    list_states.append(env_state)
    action = plan(key, (env_state, observation), forecast_fn, score_fn)[0][0]
    env_state, observation, reward, terminal, info = env.step(env_state, action)
    score += reward

print(f'Random Score: {score}')

  0%|          | 0/200 [00:00<?, ?it/s]

  lax._check_user_dtype_supported(dtype, "astype")


Random Score: -125.24327087402344
CPU times: user 1.88 s, sys: 259 ms, total: 2.14 s
Wall time: 1.73 s


In [8]:
""" Entire Loop with scan"""

def one_step(carry, t):
    key, (env_state, observation)  = carry
    key, subkey = jax.random.split(key)
    action = plan(subkey, (env_state, observation), forecast_fn, score_fn)[0][0]
    env_state_next, observation_next, reward, terminal, info = \
        env.step(env_state, action)
    carry = key, (env_state_next, observation_next)
    return carry, {
        "observation": observation,
        "observation_next": observation_next,
        "reward": reward, "action": action, "terminal": 1 - terminal,
        "env_state": env_state, 'env_state_next': env_state_next
    }

In [9]:
%%time
env_state, observation = env.reset(rng)
init = (rng, (env_state, observation))
_, out = jax.lax.scan(one_step, init, jnp.arange(200))
jnp.sum(out['reward'])

CPU times: user 1.15 s, sys: 2.81 ms, total: 1.15 s
Wall time: 639 ms


DeviceArray(-243.81517, dtype=float32)

In [10]:
_, kwargs = render(list_states[0], {})

In [11]:
for s in list_states:
    render(s, kwargs)