In [1]:
import jax
import tax
import tqdm
import tree
import haiku as hk
import numpy as np
import collections 
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import mbrl
import deluca 

from jax import jit
from functools import partial
from deluca.envs import CartPole
from mbrl.algs.rs import trajectory_search, forecast, score, plan

key = jax.random.PRNGKey(42)


def filter_env_params(x):
    """Turn an environment attribute mapping into an immutable parameter tuple.

    The argument itself is never mutated. The ``'random'`` and ``'state'``
    entries are left out when present, and ``'kinematics_integrator'`` is
    always set to ``'euler'`` (overwriting any existing value in place, or
    appended as the last field when it was absent).

    Args:
        x: Mapping from attribute name to value. The keys become namedtuple
            field names, so they must be valid identifiers.

    Returns:
        An instance of a freshly created namedtuple type named ``T`` whose
        fields are the surviving keys in their original order.
    """
    dropped = ('random', 'state')
    kept = {name: value for name, value in x.items() if name not in dropped}
    kept['kinematics_integrator'] = 'euler'
    params_type = collections.namedtuple('T', kept.keys())
    return params_type(**kept)
    


@partial(jit, static_argnums=(3,))
def cartpole_rw(state, action, state_next, env_params):
    """Reward for one CartPole transition: 1 while in bounds, 0 otherwise.

    Only ``state_next`` and ``env_params`` are read; ``state`` and ``action``
    are accepted for signature compatibility but are not used.

    Args:
        state: Unused.
        action: Unused.
        state_next: Unpacked as ``(x, x_dot, theta, theta_dot)``. Each
            component may be a scalar or an array; the reward has the same
            shape as ``x`` / ``theta``.
        env_params: Object exposing ``x_threshold`` and
            ``theta_threshold_radians``.

    Returns:
        ``1 - done``, where ``done`` is true when ``|x|`` or ``|theta|``
        exceeds the magnitude of its threshold.
    """
    cart_pos, _, pole_angle, _ = state_next

    cart_out = jnp.abs(cart_pos) > jnp.abs(env_params.x_threshold)
    pole_out = jnp.abs(pole_angle) > jnp.abs(env_params.theta_threshold_radians)

    # An explicit logical OR replaces the former `bool + bool` fed through a
    # `lax.cond` that only echoed its predicate; the truth table is the same
    # and the scalar-predicate restriction of `lax.cond` no longer applies.
    done = jnp.logical_or(cart_out, pole_out)

    return 1 - done
    

@partial(jit, static_argnums=(2,))
def cartpole_dy(state, action, env_params):
    """Evaluate ``CartPole.dynamics`` for one ``(state, action)`` pair.

    ``CartPole.dynamics`` is looked up on the class and called with
    ``env_params`` as its first positional argument, followed by ``state``
    and ``action``.

    Args:
        state: State forwarded to ``CartPole.dynamics``.
        action: Action forwarded to ``CartPole.dynamics``.
        env_params: Parameter object passed in the first positional slot;
            marked static for ``jit`` when supplied positionally.

    Returns:
        Whatever ``CartPole.dynamics`` returns for these inputs.
    """
    dynamics_fn = CartPole.dynamics
    return dynamics_fn(env_params, state, action)




In [2]:
# Instantiate the real environment used for the rollouts below.
env = CartPole()


action_size = 1
# First entry of the observation space's shape.
observation_size = env.observation_space.shape[0]
# Snapshot the environment's `attrs_` mapping as a namedtuple
# (see `filter_env_params`) so it can be bound into the jitted model functions.
env_params = filter_env_params(env.__dict__['attrs_'])

# Model functions with `env_params` pre-bound by keyword:
#   cartpole_reward_fn(state, action, state_next)
#   cartpole_dynamics(state, action)
cartpole_reward_fn = partial(cartpole_rw, env_params=env_params)
cartpole_dynamics  = partial(cartpole_dy, env_params=env_params) 

In [3]:
# Take one real environment step (action 1) to obtain a reference transition
# for the sanity check in the next cell.
state  = env.reset()
state_next, reward, done, info = env.step(1)

In [4]:
# Sanity check: compare the model functions against the transition recorded
# from the real environment (`state`, `state_next`, `reward`).
# Each print shows "<environment value>:<model value>".
action = 1 

# -- Reward function (env_params pre-bound via partial).
rw = cartpole_reward_fn(state, action, state_next)
print(f'Reward: {reward}:{rw}')
# -- Dynamics function, called directly with env_params passed positionally.
state_next_model = cartpole_dy(state, action, env_params)
print(f'State: {state_next}:{state_next_model}')

# -- Dynamics function again, this time through the partial with env_params bound by keyword.
state_next_model = cartpole_dynamics(state, action)
print(f'State: {state_next}:{state_next_model}')


Reward: 1:1
State: [ 0.04598246  0.16744456  0.01289576 -0.30881795]:[ 0.04598246  0.16744456  0.01289576 -0.30881795]
State: [ 0.04598246  0.16744456  0.01289576 -0.30881795]:[ 0.04598246  0.16744456  0.01289576 -0.30881795]


In [11]:
@jit
def world(carry, t):
    """Advance the CartPole model by one step along a fixed action sequence.

    Args:
        carry: Tuple ``(keys, state, trajectory)``. ``keys`` and
            ``trajectory`` are passed through unchanged; ``trajectory`` is
            indexed with ``t`` to pick the action for this step.
        t: Index into ``trajectory``.

    Returns:
        A pair ``(new_carry, transition)``. ``new_carry`` is
        ``(keys, state_next, trajectory)``. ``transition`` is a dict with the
        keys ``observation``, ``observation_next``, ``reward``, ``action``
        and ``terminal``.
    """
    keys, obs, trajectory = carry
    act = trajectory[t]

    obs_next = cartpole_dynamics(obs, act)
    rew = cartpole_reward_fn(obs, act, obs_next).astype('float')

    transition = {
        "observation": obs,
        "observation_next": obs_next,
        "reward": rew,
        "action": act,
        # Constant: ``1 - int(False)`` always evaluates to 1.
        "terminal": 1 - int(False),
    }
    return (keys, obs_next, trajectory), transition

In [12]:
# Jit-compiled version of the `score` function imported from mbrl.algs.rs.
score_    = jit(score)
# `forecast` with the model step function and search settings pre-bound.
# NOTE(review): the meaning of horizon / minval / maxval / action_type is
# defined by mbrl.algs.rs.forecast; presumably discrete actions are drawn
# below `maxval` (i.e. from {0, 1}) over 20 steps — confirm against mbrl.
forecast_ = partial(forecast, 
                    step_fn=world, 
                    horizon=20, 
                    action_dim=None, 
                    minval=None, 
                    maxval=2,
                    action_type='discrete')

In [15]:
# Plan once from a fresh initial state; the second value returned by `plan`
# is discarded.
state_0 = env.reset()
action, _ = plan(key, state_0, forecast_, score_)
# Display the first element of the planned `action`.
action[0]

DeviceArray(0, dtype=int32)

In [21]:
%%time
# RS:Model.
score = 0
state = env.reset()
for _ in tqdm.notebook.trange(200):
    key, subkey = jax.random.split(key)
    action = plan(key, state, forecast_, score_)[0][0]
    state, reward, terminal, info = env.step(action)
    score += reward
print(score)

  0%|          | 0/200 [00:00<?, ?it/s]

200
CPU times: user 4.37 s, sys: 345 ms, total: 4.71 s
Wall time: 3.71 s
