In [1]:
import gym
import jax
import tax
import tree
import tqdm
import optax
import brax
import functools
import jax.numpy as jnp
import numpy as np
import haiku as hk

from a2c.a2c import Data
from a2c.a2c import State
from a2c.a2c import Batch
from a2c.a2c import process_data
from a2c.a2c import loss_fn
from a2c.common.utils import evaluation
from a2c.common.nn import mlp_multivariate_normal_diag
from a2c.common.nn import mlp_deterministic
from jax import jit
from jax import vmap
from gym.vector import AsyncVectorEnv
from functools import partial
from brax import envs

# Configure the JAX platform through the project's `tax` helper.
# NOTE(review): presumably forces computation onto the CPU backend; the exact
# semantics of `tax.set_platform` are not visible here — confirm.
tax.set_platform('cpu')


In [2]:
# Notebook-level PRNG key. `setup()` builds its own key from its `seed` argument,
# so this global is only for ad-hoc experimentation in later cells.
SEED = 42
rng = jax.random.PRNGKey(SEED)

In [3]:
def setup(
    seed=42,
    policy_opt = 'rmsprop',
    policy_opt_kwargs = None,
    value_opt = 'rmsprop',
    value_opt_kwargs = None,
    policy_kwargs: "dict | None" = None, value_kwargs: "dict | None" = None,
    num_envs: int = 8,
    horizon: int = 10,
):
    """Build the A2C training state and helper callables for the brax ``Ant`` env.

    Args:
        seed: Seed for the PRNG key used to initialise the networks and to drive
            environment interaction.
        policy_opt, value_opt: Names of ``optax`` optimizer constructors
            (looked up with ``getattr(optax, name)``).
        policy_opt_kwargs, value_opt_kwargs: Keyword arguments for the matching
            optimizer. ``None`` means ``dict(learning_rate=1e-3)``.
        policy_kwargs, value_kwargs: Extra keyword arguments forwarded to
            ``mlp_multivariate_normal_diag`` / ``mlp_deterministic``.
            ``None`` means no extra arguments.
        num_envs: Number of environment copies rolled out in parallel.
        horizon: Number of environment steps per rollout.

    Returns:
        ``(state, info)``. ``state`` is an ``a2c.a2c.State`` holding the
        parameters and optimizer states. ``info`` is a dict with:
          - ``'interaction'``: generator. Prime it with ``send(None)``; each
            subsequent ``send(policy_params)`` returns a dict of rollout arrays.
          - ``'process_data'``: ``process_data`` with ``value_apply`` bound.
          - ``'update'``: jitted ``(state, data) -> (state, metrics)``.
          - ``'make_policy'``: ``params -> jitted (rng, obs) -> sampled action``.
    """
    # Local import so the function does not depend on a later notebook cell.
    from brax.envs.ant import Ant

    # Avoid shared mutable default arguments; the effective defaults are unchanged.
    if policy_opt_kwargs is None:
        policy_opt_kwargs = dict(learning_rate=1e-3)
    if value_opt_kwargs is None:
        value_opt_kwargs = dict(learning_rate=1e-3)
    policy_kwargs = {} if policy_kwargs is None else policy_kwargs
    value_kwargs = {} if value_kwargs is None else value_kwargs

    rng = jax.random.PRNGKey(seed)

    env = Ant()
    action_size = env.action_size
    dummy_observation = jnp.zeros((env.observation_size,))

    policy_def = hk.without_apply_rng(hk.transform(
        lambda x: mlp_multivariate_normal_diag(action_size, **policy_kwargs)(x)
    ))
    policy_opt = getattr(optax, policy_opt)(**policy_opt_kwargs)

    value_def = hk.without_apply_rng(hk.transform(
        lambda x: mlp_deterministic(1, **value_kwargs)(x).squeeze(-1)
    ))
    value_opt = getattr(optax, value_opt)(**value_opt_kwargs)

    # Each network gets its own key; previously both were initialised with rng_policy.
    rng, rng_policy, rng_value = jax.random.split(rng, 3)
    policy_params = policy_def.init(rng_policy, dummy_observation)
    value_params = value_def.init(rng_value, dummy_observation)

    params = {'policy': policy_params, 'value': value_params}
    opt_state = {
        'policy': policy_opt.init(policy_params),
        'value': value_opt.init(value_params),
    }

    process_data_to_batch = partial(process_data, value_apply=value_def.apply)
    loss = partial(loss_fn, value_apply=value_def.apply, policy_apply=policy_def.apply)

    @jit
    def update_fn(state, inputs):
        """Take one optimizer step on both networks from a single loss gradient."""
        grads, metrics = jax.grad(loss, has_aux=True)(state.params, inputs)

        value_updates, value_opt_state = value_opt.update(
            grads['value'], state.opt_state['value'])
        policy_updates, policy_opt_state = policy_opt.update(
            grads['policy'], state.opt_state['policy'])

        # optax.apply_updates replaces the removed jax.tree_multimap(p + u).
        new_params = dict(
            policy=optax.apply_updates(state.params['policy'], policy_updates),
            value=optax.apply_updates(state.params['value'], value_updates),
        )
        new_opt_state = dict(policy=policy_opt_state, value=value_opt_state)
        return state.replace(params=new_params, opt_state=new_opt_state), metrics

    def interaction(env, horizon: int = 10, seed: int = 42, num_envs: int = 8):
        """Generator rolling out the current policy in ``num_envs`` env copies.

        Protocol: prime with ``send(None)``. Every subsequent ``send(params)``
        runs ``horizon`` steps per environment with those policy params and
        returns a dict of arrays whose leading axis is the environment index.
        The environment state is carried across calls, so consecutive rollouts
        continue the same episodes. Environments that are ``done`` at the end of
        a rollout are restarted from the initial state before the next one.
        NOTE(review): there is still no reset *inside* a rollout.
        """
        def step(carry, _):
            rng, param, state = carry

            # Sample an action from the current policy.
            rng, rng_action = jax.random.split(rng)
            action = policy_def.apply(param, state.obs).sample(seed=rng_action)

            # Interact with the environment.
            new_state = env.step(state, action)
            info = {
                'reward': new_state.reward,
                'observation': state.obs,
                'observation_next': new_state.obs,
                # 1.0 while the episode continues, 0.0 once `done` is set.
                'terminal': 1.0 - new_state.done,
                'steps': new_state.steps,
            }
            # Must be a tuple, like the input carry: lax.scan requires the
            # carry's pytree structure to be preserved (a list would not match).
            return (rng, param, new_state), info

        def _interact(rng, param, state):
            (_, _, state_next), info = jax.lax.scan(
                step, (rng, param, state), None, length=horizon)
            # Only the observation after the final step is kept (for bootstrapping).
            info['observation_next'] = info['observation_next'][-1]
            return info, state_next

        # Map over the per-env key AND the per-env state; params are shared.
        b_interact = jit(vmap(_interact, in_axes=(0, None, 0)))

        @jit
        def _reset_done(init_state, state):
            """Replace every env whose ``done`` flag is set with its initial state."""
            done = state.done > 0

            def select(init_leaf, leaf):
                mask = done.reshape(done.shape + (1,) * (leaf.ndim - done.ndim))
                return jnp.where(mask, init_leaf, leaf)

            return jax.tree_util.tree_map(select, init_state, state)

        rng = jax.random.PRNGKey(seed)
        param = yield
        rng, rng_reset = jax.random.split(rng)
        first_state = env.reset(rng_reset)
        # Replicate the reset state along a new leading env axis.
        init_state = jax.tree_util.tree_map(
            lambda x: jnp.stack([x] * num_envs), first_state)
        state = init_state
        while True:
            rng, rng_interaction = jax.random.split(rng)
            b_rng = jax.random.split(rng_interaction, num_envs)
            info, state = b_interact(b_rng, param, state)
            state = _reset_done(init_state, state)
            param = yield info

    def make_policy(params):
        """Return a jitted ``(rng, observation) -> sampled action`` for ``params``."""
        return jit(lambda rng, x: policy_def.apply(params, x).sample(seed=rng))

    interaction_step = interaction(env, horizon=horizon, seed=seed, num_envs=num_envs)

    state = State(key=rng, params=params, opt_state=opt_state)
    info = {
        'interaction': interaction_step,
        'process_data': process_data_to_batch,
        'update': jit(lambda state, data: update_fn(state, process_data_to_batch(state.params, data))),
        'make_policy': make_policy,
    }
    return state, info

In [4]:
from jax import jit 
from jax import vmap
from brax.envs.ant import Ant

In [5]:
state, info = setup()
# Advance the interaction generator to its first `yield` so it can accept params.
next(info['interaction'])

KeyboardInterrupt: 

In [None]:
# Collect one rollout with the current policy parameters.
policy_params = state.params['policy']
data = info['interaction'].send(policy_params)

In [None]:
#info['process_data'](state.params, data)

In [None]:
# Inspect the shape of the collected reward array.
np.shape(data['reward'])