In [1]:
import gym
import jax
import tax
import tree
import optax
import jax.numpy as jnp
import numpy as np
import haiku as hk

from a2c.a2c import Data
from a2c.a2c import State
from a2c.a2c import Batch
from a2c.a2c import process_data
from a2c.a2c import loss_fn
from a2c.common.utils import evaluation
from a2c.common.nn import mlp_categorical
from a2c.common.nn import mlp_deterministic
from jax import jit
from jax import vmap
from gym.vector import AsyncVectorEnv
from functools import partial

tax.set_platform('cpu')
rng = jax.random.PRNGKey(42)


# TODO Add discount, lambda reward_scaling -> `process_data_to_batch`
def init_nn(
    observation_size, action_size, seed=42,
    policy_opt = 'adabelief',
    policy_opt_kwargs = dict(learning_rate=5e-3),
    value_opt = 'adabelief',
    value_opt_kwargs = dict(learning_rate=5e-3),
    policy_kwargs: dict = dict(), value_kwargs: dict = dict(),
):
    rng = jax.random.PRNGKey(seed)
    dummy_action = jnp.zeros((action_size))
    dummy_observation = jnp.zeros((observation_size))
    
    policy_def = lambda x: mlp_categorical(action_size, **policy_kwargs)(x)
    policy_def = hk.transform(policy_def)
    policy_def = hk.without_apply_rng(policy_def)
    policy_opt = getattr(optax, policy_opt)(**policy_opt_kwargs)
    
    value_def = lambda x: mlp_deterministic(1, **value_kwargs)(x).squeeze(-1)
    value_def = hk.transform(value_def)
    value_def = hk.without_apply_rng(value_def)
    value_opt = getattr(optax, value_opt)(**value_opt_kwargs)
    
    rng, rng_policy, rng_value = jax.random.split(rng, 3)
    value_params = value_def.init(rng_policy, dummy_observation)
    value_opt_state = value_opt.init(value_params)
    policy_params = policy_def.init(rng_policy, dummy_observation)
    policy_opt_state = policy_opt.init(policy_params)

    params = {'policy': policy_params, 'value': value_params}
    opt_state = {'policy': policy_opt_state, 'value': value_opt_state}
    
    process_data_to_batch = partial(
        process_data, value_apply=value_def.apply,    
    )
    
    loss = partial(loss_fn, value_apply=value_def.apply, policy_apply=policy_def.apply)

    
    @partial(jit, static_argnums=(2, 3,))
    def update_fn(state, inputs, policy_opt, value_opt):
        """ Generic Update function """
        g, metrics = jax.grad(loss, has_aux=True)(state.params, inputs)

        updates, value_opt_state = value_opt(g['value'], state.opt_state['value'])
        value_params = jax.tree_multimap(lambda p, u: p + u, state.params['value'], updates)

        updates, policy_opt_state = policy_opt(g['policy'], state.opt_state['policy'])
        policy_params = jax.tree_multimap(lambda p, u: p + u, state.params['policy'], updates)

        params = state.params
        params = dict(policy=policy_params, value=value_params)
        opt_state = state.opt_state
        opt_state = dict(policy=policy_opt_state, value=value_opt_state)
        state = state.replace(params=params, opt_state=opt_state)
        return state, metrics

    update = jit(partial(
        update_fn, 
        policy_opt=policy_opt.update, 
        value_opt=value_opt.update))

    
    def make_policy(state):
        fn = lambda rng, x: policy_def.apply(state.params['policy'], x).sample(seed=rng) 
        return jit(fn)
    
    return State(key=rng, params=params, opt_state=opt_state), {
        'process_data': jit(process_data_to_batch),
        'make_policy': make_policy,
        'loss': jit(loss), 'update': jit(update),
        'full_update': jit(lambda state, data: update(state, process_data_to_batch(state.params, data)))
    }


In [2]:
NEnvs = 8
make_env = lambda: gym.make('CartPole-v0')
env = AsyncVectorEnv([make_env for _ in range(NEnvs)])
env_test = gym.make('CartPole-v0')
action_size = env_test.action_space.n
observation_size = env_test.observation_space.shape[0]
state, a2c = init_nn(observation_size, action_size, 42, policy_kwargs={}, value_kwargs={})

In [3]:
def interaction(env, horizon: int = 10, seed: int = 42):
    # -- Initialization
    rng = jax.random.PRNGKey(seed)
    observation, buf = env.reset(), []
    policy = yield
    
    # -- Interaction Loop.
    while True:
        for _ in range(horizon):
            rng, rng_action = jax.random.split(rng)
            action = np.array(policy(rng_action, observation))
            observation_next, reward, done, info = env.step(action)
            buf.append({
                'observation': observation,
                'reward': reward,
                'terminal': 1. - done,
                'action': action
            })
            observation = observation_next.copy()
            
        data = jit(tax.reduce)(buf)
        data['last_observation'] = observation
        policy = yield data
        buf = []

In [4]:
interaction_step = interaction(env, 10)
interaction_step.send(None)
policy = a2c['make_policy'](state) 

In [5]:
for e in range(1000):
    data = interaction_step.send(policy)
    state, info_fit = a2c['full_update'](state, data)
    policy = a2c['make_policy'](state) 
    interaction_step.send(policy)
    if e % 100 == 0:
        info_eval = evaluation(rng, env_test, policy)
        info = {}
        info.update(**info_fit, **info_eval)
        info = tree.map_structure(lambda v: float(v), info)
        print(info)

{'H_loss': -6.931200186954811e-05, 'a2c_loss': 9.696889877319336, 'policy_loss': 3.2070348262786865, 'value_loss': 12.979849815368652, 'eval/score': 19.8, 'eval/score_std': 8.034923770640267}
{'H_loss': -5.35979779670015e-05, 'a2c_loss': 3.5517690181732178, 'policy_loss': 1.2522631883621216, 'value_loss': 4.599118709564209, 'eval/score': 141.2, 'eval/score_std': 49.272304593960286}
{'H_loss': -3.089511665166356e-05, 'a2c_loss': 3.459258556365967, 'policy_loss': 0.6166175007820129, 'value_loss': 5.685344219207764, 'eval/score': 160.2, 'eval/score_std': 6.013318551349163}
{'H_loss': -2.937376848421991e-05, 'a2c_loss': 4.029115676879883, 'policy_loss': -0.27538248896598816, 'value_loss': 8.609055519104004, 'eval/score': 54.8, 'eval/score_std': 19.823218709382186}
{'H_loss': -2.5307897885795683e-05, 'a2c_loss': 4.361700534820557, 'policy_loss': 1.0481122732162476, 'value_loss': 6.627227306365967, 'eval/score': 110.4, 'eval/score_std': 5.1613951602255765}


KeyboardInterrupt: 

In [None]:
evaluation(rng, env_test, policy)

In [None]:
policy(rng, env.reset())

In [None]:
state.params['value']