In [1]:
import jax
import tax
import tqdm
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl

from jax import jit
from functools import partial
from deluca.envs.classic._cartpole import reset, env_params, dynamics, step, render
from deluca.envs.classic._cartpole import reward_fn
from mbrl.algs.rs import forecast
from mbrl.algs.rs import plan
from mbrl.algs.rs import score

rng = jax.random.PRNGKey(42)
Environment = collections.namedtuple('Environment', ['step', 'reset'])

In [2]:
step     = jit(partial(step, env_params=env_params))
reset    = jit(partial(reset, env_params=env_params))
dynamics = jit(partial(dynamics, env_params=env_params))
env      = Environment(step=step, reset=reset)

In [3]:
action_size = 1
observation_size = env_params['state_size']

In [4]:
rng, rng_reset = jax.random.split(rng, 2)
state  = env.reset(rng_reset)
state_next, reward, done, info = env.step(state, 1)

In [5]:
@jit
def world(carry, t):
    keys, state, trajectory = carry
    action = trajectory[t]
    state_next, reward, done, info = env.step(state, action)
    reward = reward.astype(jnp.float32)
    carry = keys, state_next, trajectory
    return carry, {
        "observation": state,
        "observation_next": state_next,
        "reward": reward, "action": action, "terminal": 1. - done,
    }

In [6]:
score_    = jit(score)
forecast_ = partial(forecast, 
                    step_fn=world, 
                    horizon=20, 
                    action_dim=None, 
                    minval=None, 
                    maxval=2,     # Number of discrete actions possible
                    action_type='discrete')

In [7]:
rng, rng_reset = jax.random.split(rng, 2)
state_0 = env.reset(rng_reset)
action, _ = plan(rng, state_0, jit(forecast_), jit(score))
action[0]

DeviceArray(1, dtype=int32)

In [8]:
%%time
# RS:Model.
score = 0
rng, rng_reset = jax.random.split(rng, 2)
state = env.reset(rng_reset)
list_states = []
for _ in tqdm.notebook.trange(200):
    rng, rng_plan = jax.random.split(rng, 2)
    list_states.append(state)
    action = plan(rng_plan, state, forecast_, score_)[0][0]
    state, reward, terminal, info = env.step(state, action)
    score += reward
print(score)

  0%|          | 0/200 [00:00<?, ?it/s]

179.0
CPU times: user 1.56 s, sys: 99.5 ms, total: 1.66 s
Wall time: 863 ms


In [9]:
_, info = render(list_states[0])

In [10]:
for s in list_states:
    render(s, **info)