In [1]:
import gym
import jax
import tax
import tree
import tqdm
import optax
import collections
import jax.numpy as jnp
import numpy as np
import haiku as hk

from ppo.ppo import Data
from ppo.ppo import State
from ppo.ppo import Batch
from ppo.ppo import update_ppo
from jax import jit
from jax import vmap
from gym.vector import AsyncVectorEnv
from functools import partial
from common import gym_evaluation
from common import gym_interaction

tax.set_platform('cpu')


from mbrl.envs.oracle._cartpole import env_params
from mbrl.envs.oracle._cartpole import step_fn
from mbrl.envs.oracle._cartpole import reset_fn
from mbrl.envs.oracle._cartpole import reward_fn
from mbrl.envs.oracle._cartpole import dynamics_fn
from mbrl.envs.oracle._cartpole import render

rng = jax.random.PRNGKey(42)
Environment = collections.namedtuple('Environment', ['step', 'reset'])

step             = jit(partial(step_fn, env_params=env_params))
reset            = jit(partial(reset_fn, env_params=env_params))
dynamics         = jit(partial(dynamics_fn, env_params=env_params))
env              = Environment(step=step, reset=reset)
action_size      = 2
observation_size = env_params['state_size']

In [2]:
rng, rng_reset = jax.random.split(rng, 2)
env_state, observation  = env.reset(rng_reset)
env_state_next, observation_next, reward, done, info = env.step(env_state, 1)

In [3]:

dummy_action      = jnp.zeros((action_size,))
dummy_observation = jnp.zeros((observation_size,))

policy_def = lambda x: tax.mlp_categorical(action_size)(x)
policy_def = hk.transform(policy_def)
policy_def = hk.without_apply_rng(policy_def)
policy_opt = getattr(optax, 'adabelief')(learning_rate=1e-4)
value_def  = lambda x: tax.mlp_deterministic(1)(x).squeeze(-1)
value_def  = hk.transform(value_def)
value_def  = hk.without_apply_rng(value_def)
value_opt  = getattr(optax, 'adabelief')(learning_rate=1e-4)

rng, rng_policy, rng_value = jax.random.split(rng, 3)
value_params               = value_def.init(rng_policy, dummy_observation)
value_opt_state            = value_opt.init(value_params)
policy_params              = policy_def.init(rng_policy, dummy_observation)
policy_opt_state           = policy_opt.init(policy_params)

params    = {'policy': policy_params, 'value': value_params}
opt_state = {'policy': policy_opt_state, 'value': value_opt_state}
state     = State(params=params, opt_state=opt_state, key=rng)

policy_apply = jit(policy_def.apply)
value_apply  = jit(value_def.apply) 

loss_kwargs = {    
    'policy_apply': policy_apply,
    'value_apply': value_apply,
}

process_kwargs = {
    'policy_apply': policy_apply,
    'value_apply': value_apply,
}

loss_kwargs = hk.data_structures.to_immutable_dict(loss_kwargs)
process_kwargs = hk.data_structures.to_immutable_dict(process_kwargs)
update = partial(update_ppo,     
    policy_opt=policy_opt.update, 
    value_opt=value_opt.update, 
    loss_kwargs=loss_kwargs, 
    process_data_kwargs=process_kwargs, 
    max_grad_norm=-1.0)

update = jit(update)



In [4]:

@jit
def _onestep(carry, xs):
    p, rng, (env_state, observation) = carry
    rng, rng_action, rng_reset = jax.random.split(rng, 3)
    action = policy_apply(p, observation).sample(seed=rng_action)

    state_next, observation_next, reward, done, info = \
        env.step(env_state, action)

    state_next, observation_next  = jax.lax.cond(
        done,
        lambda rng: env.reset(rng),
        lambda rng: (state_next, observation_next),
        operand=rng)

    info = {
        'observation': observation,
        'observation_next': observation_next,
        'env_state': state,
        'env_state_next': state_next,
        'rng': rng,
        'reward': reward,
        'action': action,
        'terminal':  1. - done,
    }
    
    rng, subrng = jax.random.split(rng)
    carry = p, rng, (state_next, observation_next)
    return carry, info


carry = (state.params['policy'], rng, (env_state, observation))
xs = jnp.arange(10)
_, data = jax.lax.scan(_onestep, carry, xs)


@jit
def evaluation(rng, p):
    rng, rng_reset = jax.random.split(rng, 2)
    env_state, observation = env.reset(rng_reset)
    carry = (p, rng, (env_state, observation))
    xs = jnp.arange(200) # HARDCODE
    carry, data = jax.lax.scan(_onestep, carry, xs)
    return data

@jit
def interaction_step(rng, p, env_state, observation):
    carry = (p, rng, (env_state, observation))
    xs = jnp.arange(10) # HARDCODE
    carry, data = jax.lax.scan(_onestep, carry, xs)
    (_, _, (state_next, observation_next)) = carry
    data['last_observation'] = observation_next
    return (state_next, observation_next), data

# Compilation
_ = interaction_step(rng, state.params['policy'], env_state, observation)
data = evaluation(rng, state.params['policy'])

In [5]:
vinteraction = vmap(interaction_step, (0, None, 0, 0))

vrng = jax.random.split(rng, 16)
venv_state, vobservation = vmap(env.reset)(vrng)
_, data = vinteraction(vrng, state.params['policy'], venv_state, vobservation)

In [6]:
@jit
def epoch(state, venv_state, vobservation):
    (venv_state, vobservation), data = vinteraction(
        vrng, state.params['policy'], venv_state, vobservation)
    
    d = {k: v for k, v in data.items() \
         if k in ['last_observation', 'observation', 
                  'observation_next', 'reward', 
                  'terminal', 'action']}
    
    dd = {}
    dd['last_observation'] = d['last_observation']
    dd['observation']      = d['observation'].transpose(1, 0, 2)
    dd['observation_next'] = d['observation_next'].transpose(1, 0, 2)
    dd['reward']           = d['reward'].transpose(1, 0)
    dd['terminal']         = d['terminal'].transpose(1, 0)
    dd['action']           = d['action'].transpose(1, 0)
    
    state, update_info = update(state, dd)
    
    return state, (venv_state, vobservation), update_info

In [7]:
for _ in tqdm.notebook.trange(100):
    for _ in tqdm.notebook.trange(100):
        state, (venv_state, vobservation), info = epoch(state, venv_state, vobservation)

    data = evaluation(rng, state.params['policy'])
    t = jnp.where(data['terminal'] == 0.)
    if len(t[0]) == 0:
        score = sum(data['reward'])
    else:
        score = sum(data['reward'][:t[0][0]])
    print(score)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

65.0


  0%|          | 0/100 [00:00<?, ?it/s]

133.0


  0%|          | 0/100 [00:00<?, ?it/s]

134.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

200.0


  0%|          | 0/100 [00:00<?, ?it/s]

13.0


  0%|          | 0/100 [00:00<?, ?it/s]

7.0


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
data['reward']

In [None]:
# TODO: Random Key should be change iterativeley
state