In [1]:
import gym
import jax
import tax
import tree
import tqdm
import optax
import jax.numpy as jnp
import numpy as np
import haiku as hk

from a2c.a2c import Data
from a2c.a2c import State
from a2c.a2c import Batch
from a2c.a2c import update_a2c
from jax import jit
from jax import vmap
from gym.vector import AsyncVectorEnv
from functools import partial
from common import gym_evaluation
from common import gym_interaction

tax.set_platform('cpu')
rng = jax.random.PRNGKey(42)

In [3]:
NENVS    = 8
name     = 'CartPole-v0'
make_env = lambda: gym.make(name)

env              = AsyncVectorEnv([make_env for _ in range(NENVS)])
env_test         = gym.make(name)
action_size      = env_test.action_space.n
observation_size = env_test.observation_space.shape[0]

In [7]:
rng = jax.random.PRNGKey(42)

dummy_action      = jnp.zeros((action_size,))
dummy_observation = jnp.zeros((observation_size,))

policy_def = lambda x: tax.mlp_categorical(action_size)(x)
policy_def = hk.transform(policy_def)
policy_def = hk.without_apply_rng(policy_def)
policy_opt = getattr(optax, 'adabelief')(learning_rate=5e-4)
value_def  = lambda x: tax.mlp_deterministic(1)(x).squeeze(-1)
value_def  = hk.transform(value_def)
value_def  = hk.without_apply_rng(value_def)
value_opt  = getattr(optax, 'adabelief')(learning_rate=5e-4)

rng, rng_policy, rng_value = jax.random.split(rng, 3)
value_params               = value_def.init(rng_policy, dummy_observation)
value_opt_state            = value_opt.init(value_params)
policy_params              = policy_def.init(rng_policy, dummy_observation)
policy_opt_state           = policy_opt.init(policy_params)

params    = {'policy': policy_params, 'value': value_params}
opt_state = {'policy': policy_opt_state, 'value': value_opt_state}
state     = State(params=params, opt_state=opt_state, key=rng)

policy_apply = jit(policy_def.apply)
value_apply  = jit(value_def.apply) 

loss_kwargs = {    
    'policy_apply': policy_apply,
    'value_apply': value_apply,
}

process_kwargs = {
    'value_apply': value_apply,
}

loss_kwargs = hk.data_structures.to_immutable_dict(loss_kwargs)
process_kwargs = hk.data_structures.to_immutable_dict(process_kwargs)
update = partial(update_a2c,     
    policy_opt=policy_opt.update, 
    value_opt=value_opt.update, 
    loss_kwargs=loss_kwargs, 
    process_data_kwargs=process_kwargs, 
    max_grad_norm=-1.0)
update = jit(update)


In [6]:
interaction = gym_interaction(env, jit(policy_def.apply))
interaction.send(None)

In [8]:
S = tax.Store()
for _ in tqdm.notebook.trange(2000):
    for _ in range(100):
        data  = interaction.send(state.params['policy'])
        state, update_info = update(state, data)
        S.add(**update_info)
    
    eval_info = gym_evaluation(rng, env_test, state.params['policy'], policy_apply)
    S.add(**eval_info)
    info = S.get()
    print(info)

  0%|          | 0/2000 [00:00<?, ?it/s]

{'H_loss': -9.9999976e-05, 'a2c_loss': 7.823637, 'policy_loss': 2.272276, 'value_loss': 11.102856, 'eval/score': 30.2, 'eval/score_std': 12.3515}
{'H_loss': -9.9999976e-05, 'a2c_loss': 12.699322, 'policy_loss': 0.86379796, 'value_loss': 23.671152, 'eval/score': 159.4, 'eval/score_std': 30.6046}
{'H_loss': -9.899998e-05, 'a2c_loss': 21.738735, 'policy_loss': 0.07244401, 'value_loss': 43.3327, 'eval/score': 200.0, 'eval/score_std': 0.0}
{'H_loss': -9.799998e-05, 'a2c_loss': 16.992573, 'policy_loss': 0.024600983, 'value_loss': 33.93605, 'eval/score': 154.2, 'eval/score_std': 38.9173}
{'H_loss': -9.299999e-05, 'a2c_loss': 7.946374, 'policy_loss': 0.176344, 'value_loss': 15.540167, 'eval/score': 200.0, 'eval/score_std': 0.0}


KeyboardInterrupt: 