In [None]:
!pip install -U dm-haiku
!pip install optax

!pip install gym==0.24.0
!pip install gym[classic_control]

!pip install coax

In [None]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from coax.value_losses import mse, huber

import numpy as onp

In [None]:
# run name: labels the monitor and picks the tensorboard log sub-directory
name = 'a2c'

# cart-pole MDP, wrapped in coax's TrainMonitor so episode statistics
# (e.g. env.avg_G, read by the training loop) and recorded metrics are logged
env = coax.wrappers.TrainMonitor(
    gym.make('CartPole-v0'),
    name=name,
    tensorboard_dir=f"./data/tensorboard_custom/{name}",
)

In [None]:
def func_pi(S, is_training):
    """Policy network: map a batch of observations to action logits.

    Args:
        S: batch of observations.
        is_training: unused here; kept because the function is called as
            ``func(S, is_training)`` (presumably by ``coax.Policy``).

    Returns:
        dict with key ``'logits'``: one logit per discrete action
        (``env.action_space.n`` outputs).
    """
    logits = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        # zero output weights (with haiku's default zero bias) give equal
        # logits for every action, i.e. a uniform initial policy
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    return {'logits': logits(S)}

In [None]:
# optax.apply_every accumulates gradients over GRAD_ACCUM_STEPS update calls and
# only then lets Adam apply them; the critic uses twice the actor's learning rate
GRAD_ACCUM_STEPS = 32
optimizer_v = optax.chain(optax.apply_every(k=GRAD_ACCUM_STEPS), optax.adam(0.002))
optimizer_pi = optax.chain(optax.apply_every(k=GRAD_ACCUM_STEPS), optax.adam(0.001))

In [None]:
# reward tracer producing one-step (n=1) transitions with discount gamma=0.9
tracer = coax.reward_tracing.NStep(gamma=0.9, n=1)

In [None]:
# actor: a coax policy over the env's discrete actions; func_pi supplies the logits
pi = coax.Policy(func_pi, env)

In [None]:
class RandomStateMixin:
    """Give a class its own reproducible JAX PRNG stream.

    Assigning ``random_seed`` (re)starts the stream; assigning ``None`` draws a
    random seed first. Each read of ``rng`` advances the stream and returns a
    fresh subkey.
    """

    @property
    def random_seed(self):
        """Seed the current PRNG stream was started from."""
        return self._random_seed

    @random_seed.setter
    def random_seed(self, seed):
        # no seed supplied: pick one from numpy's global RNG
        chosen = onp.random.randint(2147483647) if seed is None else seed
        self._random_seed = chosen
        self._random_key = jax.random.PRNGKey(chosen)

    @property
    def rng(self):
        """A new subkey; reading this property mutates the carried key."""
        carried, subkey = jax.random.split(self._random_key)
        self._random_key = carried
        return subkey

In [None]:
class CustomV(RandomStateMixin):
    """Minimal state-value function v(s) built directly on Haiku.

    Holds the Haiku params and function state, exposes a jitted ``function``
    with Haiku's ``apply(params, state, rng, S) -> (V, new_state)`` signature,
    and provides the attributes ``CustomTDLearning`` relies on (``params``,
    ``function_state``, ``function``, ``rng``, ``observation_preprocessor``).

    Args:
        env: environment whose ``observation_space`` defines the network input;
            its ``sample()``, ``dtype``, ``shape``, ``low`` and ``high`` are read
            (presumably a gym ``Box`` -- confirm for other spaces).
        random_seed: seed for the internal PRNG stream; ``None`` draws one.
    """

    def __init__(self, env, random_seed=None):
        self.random_seed = random_seed  # also initializes the key behind self.rng
        self._space = env.observation_space

        def func_v(S):
            value = hk.Sequential((
                hk.Linear(8), jax.nn.relu,
                hk.Linear(8), jax.nn.relu,
                hk.Linear(8), jax.nn.relu,
                # zero-initialized output weights; ravel turns (batch, 1) into (batch,)
                hk.Linear(1, w_init=jnp.zeros), jnp.ravel,
            ))
            return value(S)

        # Haiku-transform the network; apply(params, state, rng, S) -> (V, state)
        transformed = hk.transform_with_state(func_v)
        self._function = jax.jit(transformed.apply)

        # init params and state from one sampled observation (a batch of one)
        dummy = self.observation_preprocessor(self._space.sample())
        self._params, self._function_state = transformed.init(self.rng, dummy)

        def soft_update_func(old, new, tau):
            # leaf-wise interpolation: (1 - tau) * old + tau * new
            return jax.tree_util.tree_map(lambda a, b: (1 - tau) * a + tau * b, old, new)

        self._soft_update_func = jax.jit(soft_update_func)

    def __call__(self, s):
        """Evaluate v(s) for a single observation ``s``; returns a NumPy array."""
        S = self.observation_preprocessor(s)  # adds the batch axis
        V, _ = self.function(self.params, self.function_state, self.rng, S)
        return onp.asarray(V[0])  # first entry of the batch

    def observation_preprocessor(self, X):
        """Turn raw observation(s) into a batched, clipped JAX array.

        Casts to the observation space's dtype, reshapes to
        ``(-1, *space.shape)`` so a single observation gets a leading batch
        axis, and clips into ``[space.low, space.high]``.
        """
        X = jnp.asarray(X, dtype=self._space.dtype)   # ensure ndarray
        X = jnp.reshape(X, (-1, *self._space.shape))  # ensure batch axis
        X = jnp.clip(X, self._space.low, self._space.high)  # clip to be safe
        return X

    @property
    def params(self):
        """Haiku parameters of the value network."""
        return self._params

    @params.setter
    def params(self, new_params):
        # jax.tree_util.* instead of the removed jax.tree_structure alias
        if jax.tree_util.tree_structure(new_params) != jax.tree_util.tree_structure(self._params):
            raise TypeError("new params must have the same structure as old params")
        self._params = new_params

    @property
    def function(self):
        """
        This function may be called directly as:
        output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
        """
        return self._function

    @property
    def function_state(self):
        """Haiku state of the value network (returned by transform_with_state)."""
        return self._function_state

    @function_state.setter
    def function_state(self, new_function_state):
        if (jax.tree_util.tree_structure(new_function_state)
                != jax.tree_util.tree_structure(self._function_state)):
            raise TypeError("new function_state must have the same structure as old function_state")
        self._function_state = new_function_state

    def soft_update(self, other, tau):
        """Move params and function state a fraction ``tau`` toward ``other``'s."""
        self.params = self._soft_update_func(self.params, other.params, tau)
        self.function_state = self._soft_update_func(self.function_state, other.function_state, tau)

In [None]:
# critic: hand-rolled state-value function; with no seed, CustomV draws a random one
v = CustomV(env=env, random_seed=None)

In [None]:
class CustomTDLearning:
    """TD-learning updater for a ``CustomV``-style value function.

    The regression target is ``G = Rn + In * v_targ(S_next)``, with ``Rn`` and
    ``In`` read from the transition batch. The loss is
    ``loss_function(G, v(S), W)``, with importance weights ``W`` clipped to
    [0.1, 10]. Gradients are applied through an optax optimizer.

    Args:
        v: value function to train; must provide ``params``, ``function_state``,
            ``function``, ``rng`` and ``observation_preprocessor``.
        loss_function: callable ``f(y_true, y_pred, w=None)``; ``None`` selects ``huber``.
        optimizer: optax gradient transformation; ``None`` selects ``optax.adam(1e-3)``.
        v_targ: value function used to bootstrap the target. ``None`` (default)
            reuses ``v`` itself, i.e. there is no separate target network.
    """

    def __init__(self, v, loss_function=None, optimizer=None, v_targ=None):
        self._f = v
        self._f_targ = v if v_targ is None else v_targ
        self.loss_function = huber if loss_function is None else loss_function

        # optimizer
        self._optimizer = optax.adam(1e-3) if optimizer is None else optimizer
        self._optimizer_state = self.optimizer.init(self._f.params)

        def loss_func(params, target_params, state, target_state, rng, transition_batch):
            # returns (loss, (td_error, new_function_state, metrics));
            # differentiated w.r.t. `params` only
            rngs = hk.PRNGSequence(rng)
            S = self.v.observation_preprocessor(transition_batch.S)
            W = jnp.clip(transition_batch.W, 0.1, 10.)  # clip importance weights to reduce variance

            V, state_new = self.v.function(params, state, next(rngs), S)
            G = self.target_func(target_params, target_state, next(rngs), transition_batch)
            loss = self.loss_function(G, V, W)

            # only needed for metrics dict
            V_targ, _ = self.v.function(
                target_params['v_targ'], target_state['v_targ'], next(rngs), S)

            # per-sample TD error from the gradient of the loss w.r.t. V; the
            # -V.shape[0] factor assumes the loss averages over the batch
            dLoss_dV = jax.grad(self.loss_function, argnums=1)
            td_error = -V.shape[0] * dLoss_dV(G, V)  # e.g. (G - V) if loss function is MSE
            metrics = {
                f'{self.__class__.__name__}/loss': loss,
                f'{self.__class__.__name__}/td_error': jnp.mean(W * td_error),
                f'{self.__class__.__name__}/td_error_targ': jnp.mean(-dLoss_dV(V, V_targ, W)),
            }
            return loss, (td_error, state_new, metrics)

        def apply_grads_func(opt, opt_state, params, grads):
            updates, new_opt_state = opt.update(grads, opt_state, params)
            new_params = optax.apply_updates(params, updates)
            return new_opt_state, new_params

        # the optimizer (a tuple of pure functions) is hashable, so it can be static
        self._apply_grads_func = jax.jit(apply_grads_func, static_argnums=0)

        def grads_and_metrics_func(
                params, target_params, state, target_state, rng, transition_batch):

            rngs = hk.PRNGSequence(rng)
            grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
                params, target_params, state, target_state, next(rngs), transition_batch)

            def _get_leaf_diagnostics(leaf, key_prefix):
                # update this to add more grads diagnostics
                return {
                    f'{key_prefix}max': jnp.max(jnp.abs(leaf)),
                    f'{key_prefix}norm': jnp.linalg.norm(jnp.ravel(leaf)),
                }

            def tree_ravel(pytree):
                # flatten every leaf of the gradient pytree into one 1-D vector
                return jnp.concatenate(
                    [jnp.ravel(leaf) for leaf in jax.tree_util.tree_leaves(pytree)])

            def get_grads_diagnostics(grads, key_prefix=''):
                return _get_leaf_diagnostics(tree_ravel(grads), key_prefix)

            # add some diagnostics about the gradients
            metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))

            return grads, state_new, metrics, td_error

        self._grads_and_metrics_func = jax.jit(grads_and_metrics_func)

    def target_func(self, target_params, target_state, rng, transition_batch):
        """Bootstrapped target ``Rn + In * v_targ(S_next)`` for a transition batch."""
        rngs = hk.PRNGSequence(rng)
        params, state = target_params['v_targ'], target_state['v_targ']
        S_next = self.v_targ.observation_preprocessor(transition_batch.S_next)

        V_next, _ = self.v_targ.function(params, state, next(rngs), S_next)
        return transition_batch.Rn + transition_batch.In * V_next

    def update(self, transition_batch, return_td_error=False):
        """Take one optimizer step on ``transition_batch``.

        Returns:
            ``metrics``, or ``(metrics, td_error)`` if ``return_td_error`` is true.

        Raises:
            RuntimeError: if any gradient leaf contains NaN (params are left untouched).
        """
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
        if any(jnp.any(jnp.isnan(g)) for g in jax.tree_util.tree_leaves(grads)):
            raise RuntimeError(f"found nan's in grads: {grads}")
        self.apply_grads(grads, function_state)
        return (metrics, td_error) if return_td_error else metrics

    def apply_grads(self, grads, function_state):
        """Store the new function state and apply ``grads`` through the optimizer."""
        self._f.function_state = function_state
        self.optimizer_state, self._f.params = \
            self._apply_grads_func(self.optimizer, self.optimizer_state, self._f.params, grads)

    def grads_and_metrics(self, transition_batch):
        """Compute ``(grads, new_function_state, metrics, td_error)`` without updating."""
        return self._grads_and_metrics_func(
            self._f.params, self.target_params, self._f.function_state, self.target_function_state,
            self._f.rng, transition_batch)

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, new_optimizer):
        # only allow swapping in an optimizer whose state has the same tree structure
        new_optimizer_state_structure = jax.tree_util.tree_structure(
            new_optimizer.init(self._f.params))
        if new_optimizer_state_structure != jax.tree_util.tree_structure(self.optimizer_state):
            raise AttributeError("cannot set optimizer attr: mismatch in optimizer_state structure")
        self._optimizer = new_optimizer

    @property
    def optimizer_state(self):
        return self._optimizer_state

    @optimizer_state.setter
    def optimizer_state(self, new_optimizer_state):
        if (jax.tree_util.tree_structure(new_optimizer_state)
                != jax.tree_util.tree_structure(self.optimizer_state)):
            raise AttributeError("cannot set optimizer_state attr: mismatch in tree structure")
        self._optimizer_state = new_optimizer_state

    @property
    def v(self):
        return self._f

    @property
    def v_targ(self):
        return self._f_targ

    @property
    def target_params(self):
        return hk.data_structures.to_immutable_dict({
            'v': self.v.params,
            'v_targ': self.v_targ.params,
            'reg': None,
            'reg_hparams': None})

    @property
    def target_function_state(self):
        return hk.data_structures.to_immutable_dict({
            'v': self.v.function_state,
            'v_targ': self.v_targ.function_state,
            'reg': None})


In [None]:
# actor objective: vanilla policy gradient on `pi`; the training loop passes the
# critic's TD errors to vanilla_pg.update() as the advantage estimate
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optimizer_pi)

In [None]:
# critic updater: TD learning on `v` with squared-error loss (`mse` is coax.value_losses.mse)
simple_td = CustomTDLearning(v=v, loss_function=mse, optimizer=optimizer_v)

In [None]:
# train: up to 1000 episodes; stop early once the monitor's average return
# exceeds the environment's reward threshold
max_steps = env.spec.max_episode_steps
for ep in range(1000):
    s = env.reset()

    for t in range(max_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # episode ended on the last allowed step (presumably a time-limit cut-off,
        # not a failure): use 1 / (1 - gamma), the discounted value of +1 forever
        if done and t == max_steps - 1:
            r = 1 / (1 - tracer.gamma)

        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            # critic first: its TD errors are the actor's advantage estimates
            metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
            metrics_pi = vanilla_pg.update(transition_batch, td_error)
            for metrics in (metrics_v, metrics_pi):
                env.record_metrics(metrics)

        if done:
            break
        s = s_next

    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break