In [1]:
import gym
import jax
import tax
import tree
import optax
import numpy as np
import haiku as hk
import jax.numpy as jnp

from ppo.ppo import Data
from ppo.ppo import State
from ppo.ppo import Batch
from ppo.ppo import process_data
from ppo.ppo import loss_ppo_def
from ppo.common.utils import evaluation
from jax import jit
from jax import vmap
from functools import partial
from gym.vector import AsyncVectorEnv

# NOTE(review): presumably selects the JAX backend used for every computation
# below -- confirm against the `tax` package.
tax.set_platform('cpu')
# Root PRNG key of the notebook; later cells split it to derive fresh keys.
rng = jax.random.PRNGKey(42)

# Number of environment copies placed in the vectorised training environment.
NENVS = 8

def evaluation(rng, env, policy, niters: int = 5):
    """Roll out `policy` for `niters` episodes and summarise the episode returns.

    Note: this definition shadows the `evaluation` imported from
    `ppo.common.utils` at the top of the notebook.

    Args:
        rng: JAX PRNG key; it is split once per environment step.
        env: a single (non-vectorised) environment whose `step` returns the
            4-tuple `(observation, reward, done, info)` and which exposes
            `spec.max_episode_steps` and `action_space`.
        policy: callable `(key, observation) -> action`.
        niters: number of episodes to run.

    Returns:
        dict with the mean return under 'eval/score' and its standard
        deviation under 'eval/score_std'.
    """
    # Discrete spaces need a plain Python int, everything else a NumPy array.
    is_discrete = env.action_space.__class__.__name__ == 'Discrete'
    all_scores = []
    for _ in range(niters):
        observation, score = env.reset(), 0
        # Episodes are capped at the environment's own step limit.
        for _t in range(env.spec.max_episode_steps):
            rng, rng_action = jax.random.split(rng)
            # Sample with the dedicated subkey. `rng` itself is split again on
            # the next step, so handing it to the policy would reuse a key.
            action = policy(rng_action, observation)
            action = int(action) if is_discrete else np.asarray(action)
            observation, reward, done, _ = env.step(action)
            score += reward
            if done:
                break
        all_scores.append(score)
    return {
        'eval/score': np.mean(all_scores),
        'eval/score_std': np.std(all_scores),
    }


def interaction(env, horizon: int = 10, seed: int = 42):
    """Generator that collects fixed-length rollouts from `env`.

    Protocol: prime the generator once (the first `yield` produces nothing and
    receives the first policy). Every later `send(policy)` steps the
    environment `horizon` times with that policy and yields the collected
    data; the value sent in reply becomes the policy for the next rollout.
    The environment is reset only once, when the generator starts.

    Args:
        env: environment whose `step` returns
            `(observation, reward, done, info)` and whose observations
            support `.copy()`.
        horizon: number of environment steps per yielded rollout.
        seed: seed of the generator's private PRNG key.

    Yields:
        The result of `tax.reduce` over the per-step records (presumably the
        records stacked along a leading time axis -- confirm against `tax`),
        with the observation following the last step stored under
        'last_observation'.
    """
    key = jax.random.PRNGKey(seed)
    current_obs = env.reset()
    transitions = []
    # Priming yield: hands nothing out, waits for the first policy.
    policy = yield

    # -- Interaction Loop.

    while True:
        for _t in range(horizon):
            key, action_key = jax.random.split(key)
            action = np.array(policy(action_key, current_obs))
            next_obs, reward, done, _info = env.step(action)
            # Stored under 'terminal', but the value is 1 while the episode
            # is NOT done and 0 once it is.
            transitions.append(dict(
                observation=current_obs,
                reward=reward,
                terminal=1. - done,
                action=action,
            ))
            current_obs = next_obs.copy()

        rollout = jit(tax.reduce)(transitions)
        rollout['last_observation'] = current_obs
        # Start the next rollout from an empty buffer.
        transitions = []
        policy = yield rollout

In [2]:
# Environment id shared by the training and the evaluation environment.
name = 'Pendulum-v0'


def make_env():
    """Create one fresh instance of the environment identified by `name`."""
    return gym.make(name)


# Vectorised training environment made of NENVS copies.
env = AsyncVectorEnv([make_env] * NENVS)
# Separate single environment: used for evaluation and to read the space sizes.
env_test = gym.make(name)
action_size = env_test.action_space.shape[0]
observation_size = env_test.observation_space.shape[0]

# `Neural Network and Optimizers`

In [3]:
# Re-seed here so re-running this cell always yields the same initial weights.
rng = jax.random.PRNGKey(42)
# Unbatched placeholder inputs; `dummy_observation` drives the `init` calls.
dummy_action = jnp.zeros((action_size,))
dummy_observation = jnp.zeros((observation_size,))


def _policy_net(x):
    """Policy head; its output is sampled with `.sample(seed=...)` in `_make_policy`."""
    return tax.mlp_multivariate_normal_diag(
        action_size, logstd_min=-10.0, logstd_max=3.0)(x)


def _value_net(x):
    """Value head: one output unit, with the trailing unit axis squeezed away."""
    return tax.mlp_deterministic(1)(x).squeeze(-1)


# `without_apply_rng` makes `apply` take `(params, x)` with no PRNG key.
policy_def = hk.without_apply_rng(hk.transform(_policy_net))
policy_opt = optax.adabelief(learning_rate=5e-4)
value_def = hk.without_apply_rng(hk.transform(_value_net))
value_opt = optax.adabelief(learning_rate=5e-4)

# One independent key per network. The value network was previously
# initialised with `rng_policy`, which left `rng_value` unused and made both
# networks draw their initial weights from the same key.
rng, rng_policy, rng_value = jax.random.split(rng, 3)
value_params = value_def.init(rng_value, dummy_observation)
value_opt_state = value_opt.init(value_params)
policy_params = policy_def.init(rng_policy, dummy_observation)
policy_opt_state = policy_opt.init(policy_params)

# Parameters and optimiser states are grouped under matching 'policy' and
# 'value' entries; `update_fn` indexes both dicts by those names.
params = {'policy': policy_params, 'value': value_params}
opt_state = {'policy': policy_opt_state, 'value': value_opt_state}
state = State(params=params, opt_state=opt_state, key=rng)

# Bind the network apply functions so the training loop can call
# `process_data2batch(params, data)`.
process_data2batch = partial(process_data,
                             value_apply=value_def.apply,
                             policy_apply=policy_def.apply)

def _make_policy(p):
    fn = lambda rng, x: policy_def.apply(p, x).sample(seed=rng) 
    return jit(fn)

In [4]:
# PPO loss with its hyper-parameters and both network apply functions bound.
# What remains is called as `loss(params, inputs)` and differentiated with
# `has_aux=True` in `update_fn`, so it must return `(scalar_loss, metrics)`.
# NOTE(review): `epsilon_ppo` is presumably the clipping range and
# `entropy_cost` / `value_cost` the weights of the entropy and value terms --
# confirm against `loss_ppo_def`.
loss = partial(loss_ppo_def, 
               epsilon_ppo=0.2,
               entropy_cost=0.1,
               value_cost=1,
               value_apply=value_def.apply, 
               policy_apply=policy_def.apply)


@jit
def update_fn(state, inputs):
    """Apply one gradient step to both the policy and the value network.

    Args:
        state: training state whose `params` and `opt_state` are dicts with
            'policy' and 'value' entries.
        inputs: minibatch, passed unchanged to `loss`.

    Returns:
        `(state, metrics)`: the state with updated parameters and optimiser
        states, and the auxiliary output of `loss`.
    """
    # A single backward pass yields the gradients of both networks.
    grads, metrics = jax.grad(loss, has_aux=True)(state.params, inputs)

    # Each network has its own optimiser. Passing the current parameters to
    # `update` keeps this working for optimisers that need them.
    value_updates, value_opt_state = value_opt.update(
        grads['value'], state.opt_state['value'], state.params['value'])
    # `optax.apply_updates` adds the updates to the parameters leaf by leaf;
    # it replaces `jax.tree_multimap`, which no longer exists in current JAX.
    value_params = optax.apply_updates(state.params['value'], value_updates)

    policy_updates, policy_opt_state = policy_opt.update(
        grads['policy'], state.opt_state['policy'], state.params['policy'])
    policy_params = optax.apply_updates(state.params['policy'], policy_updates)

    state = state.replace(
        params=dict(policy=policy_params, value=value_params),
        opt_state=dict(policy=policy_opt_state, value=value_opt_state),
    )
    return state, metrics

# `Initialization`

In [5]:
# Rollout generator over the vectorised environment, 100 steps per rollout.
interaction_step = interaction(env, horizon=100)
# Prime the generator: run it up to its first `yield`, where it waits for a policy.
next(interaction_step)

# `Training Loop`

In [None]:
# Length of the `xs` range the training loop hands to `_fit`
# (the intended number of passes over each collected batch).
EPOCH     = 10
# Number of samples per minibatch; `_fit` divides the batch size by this to
# obtain the number of minibatches.
MINIBATCH = 32

@jit
def _step(carry, xs):
    """Scan body: run one `update_fn` call on the minibatch selected by `xs`.

    Args:
        carry: `(state, batch)`; `batch` is passed through unchanged.
        xs: indices of the samples forming this minibatch.

    Returns:
        `((state, batch), info)` with the updated state and the metrics
        returned by `update_fn`.
    """
    current_state, full_batch = carry
    # Index every leaf of the batch with the same minibatch indices.
    selected = tree.map_structure(lambda leaf: leaf[xs], full_batch)
    next_state, metrics = update_fn(current_state, selected)
    return (next_state, full_batch), metrics


@jit
def _fit(carry, xs):
    """Run one shuffled pass over `batch`, one `update_fn` call per minibatch.

    Args:
        carry: `(state, batch, key)`; `key` is the PRNG key used for shuffling.
        xs: ignored. It is present so the function has the `(carry, xs)`
            signature of a `jax.lax.scan` body.

    Returns:
        `((state, batch, key), info)`: the updated state, the unchanged batch,
        the advanced key, and the metrics of every minibatch stacked along
        the leading axis.

    Raises:
        ValueError: if the batch holds fewer than `MINIBATCH` samples.
    """
    state, batch, key = carry
    n = batch['observation'].shape[0]
    # Shapes are static under `jit`, so this is ordinary Python arithmetic.
    num_minibatches = n // MINIBATCH
    if num_minibatches == 0:
        raise ValueError(
            f'batch of {n} samples is smaller than MINIBATCH={MINIBATCH}')

    # Advance the key carried in, not the notebook-level `rng`. A global read
    # inside a jitted function is frozen at trace time, so every call would
    # shuffle identically and the key passed by the caller would be discarded.
    key, subkey = jax.random.split(key)
    shuffled = jax.random.permutation(subkey, jnp.arange(n))
    # Rows of `indexes` are the minibatches. Samples that do not fill a whole
    # minibatch are dropped for this pass, which keeps the shape rectangular.
    indexes = shuffled[:num_minibatches * MINIBATCH].reshape(
        num_minibatches, MINIBATCH)

    (state, batch), info = jax.lax.scan(_step, (state, batch), indexes)
    return (state, batch, key), info

# `import tqdm` alone does not load the `notebook` submodule, so
# `tqdm.notebook.trange` could fail with AttributeError on a fresh kernel.
import tqdm.notebook

NUM_ITERATIONS = 100          # evaluation / logging rounds
ROLLOUTS_PER_ITERATION = 10   # rollouts collected and fitted per round

# Built once instead of on every iteration.
process_batch = jit(process_data2batch)
xs = jnp.arange(EPOCH)


@jit
def _fit_epochs(carry, xs):
    """Run `_fit` once per entry of `xs`, threading `(state, batch, key)` through."""
    return jax.lax.scan(_fit, carry, xs)


for _iteration in tqdm.notebook.trange(NUM_ITERATIONS):
    for _rollout in range(ROLLOUTS_PER_ITERATION):
        # Collect a rollout with the current policy and turn it into a batch.
        data = interaction_step.send(_make_policy(state.params['policy']))
        batch = process_batch(state.params, data)

        # Update state.
        rng, subrng = jax.random.split(rng)
        carry = (state, batch, subrng)
        # `_fit` ignores its `xs` argument, so calling it directly performed a
        # single pass and left EPOCH without effect. Scanning it over `xs`
        # performs EPOCH passes over the batch.
        (state, batch, key), info = _fit_epochs(carry, xs)
        info = tree.map_structure(lambda v: jnp.mean(v), info)

    policy_fn = _make_policy(state.params['policy'])
    # Dedicated evaluation key: `rng` is split again on the next iteration and
    # must not also be consumed here.
    rng, rng_eval = jax.random.split(rng)
    info_eval = evaluation(rng_eval, env_test, policy_fn)
    # `info` holds the metrics of the last rollout of this round only.
    print({**info_eval, **info})