<a href="https://colab.research.google.com/github/azzeddineCH/Imapal-Cartpole-Agent-with-Jax/blob/main/CartPole_with_impala_in_Jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import functools
from typing import NamedTuple, Any, Tuple, List
import dm_env
import haiku as hk
import jax.nn
from jax import numpy as jnp
import rlax
import optax
import distrax
from typing import Callable
from bsuite.environments import cartpole
import threading
import queue
import numpy as np
from rl.common import Transition
from typing import NamedTuple, Any, Tuple
import dm_env
import haiku as hk
import jax.nn
from jax import numpy as jnp

In [None]:
class Transition(NamedTuple):
    """One actor step: the timestep the agent observed and what it did in response."""
    # Environment output the agent acted on; a dm_env.TimeStep carries
    # step_type (FIRST, MID, LAST), reward, discount and observation.
    timestep: dm_env.TimeStep
    # Action selected for `timestep` (run_actor stores the value returned by Agent.step).
    action: int
    # Extra agent output kept with the action; run_actor stores the behaviour policy logits here.
    agent_out: Any


class Network(hk.Module):
    """Actor-critic MLP: a shared torso feeding a policy-logits head and a baseline head."""

    def __init__(self, num_actions: int):
        """Store the width of the policy head.

        Args:
            num_actions: number of discrete actions, i.e. the size of the logits output.
        """
        super().__init__()
        self._num_actions = num_actions

    def __call__(self, timestep: dm_env.TimeStep) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute policy logits and the baseline for a batch of timesteps.

        Args:
            timestep: batched timestep; only `timestep.observation` is read, and it is
                flattened per batch element before entering the MLP.

        Returns:
            A `(policy_logits, baseline)` pair shaped `[batch_size, num_actions]` and
            `[batch_size]` respectively.
        """
        torso = hk.Sequential([
            hk.Flatten(),
            hk.Linear(128), jax.nn.relu,
            hk.Linear(64), jax.nn.relu,
        ])

        hidden = torso(timestep.observation)  # [batch_size, 64]

        policy_logits = hk.Linear(self._num_actions)(hidden)  # [batch_size, num_actions]
        baseline = hk.Linear(1)(hidden)  # [batch_size, 1]
        # Drop only the trailing unit axis: an axis-less squeeze would also remove the
        # batch axis when batch_size == 1 and return a scalar instead of shape [1].
        baseline = jnp.squeeze(baseline, axis=-1)  # [batch_size]

        return policy_logits, baseline


def preprocess_step(ts: dm_env.TimeStep) -> dm_env.TimeStep:
    """Make a timestep array-only so it can be fed to JAX code.

    A missing reward becomes 0 and a missing discount becomes 1; afterwards every
    leaf of the timestep (including the step type) is converted with `jnp.asarray`.

    Args:
        ts: timestep as returned by the environment.

    Returns:
        A timestep of the same type whose fields are all JAX arrays.
    """
    # Collect the defaults for the fields that are absent, then apply them in one go.
    fill_ins = {}
    if ts.reward is None:
        fill_ins["reward"] = 0.
    if ts.discount is None:
        fill_ins["discount"] = 1.
    if fill_ins:
        ts = ts._replace(**fill_ins)

    return jax.tree_util.tree_map(jnp.asarray, ts)

In [None]:
class Agent:
    """Policy-gradient agent that samples actions and computes a V-trace based loss."""

    def __init__(
            self,
            net_apply_fn,
            discount: float = 0.99,
            baseline_cost: float = 0.5,
            entropy_cost: float = 0.01,
    ):
        """Store the network apply function and the loss hyper-parameters.

        Args:
            net_apply_fn: callable `(params, timestep) -> (policy_logits, baseline)`.
            discount: factor multiplied with the environment discount of every step.
            baseline_cost: weight of the baseline (value) loss in the total loss.
            entropy_cost: weight of the entropy loss in the total loss.

        Note:
            `self` is a static argument of the jitted `step`, so these values are read
            while tracing; changing them after the first call presumably has no effect
            on already-compiled functions.
        """
        self._net = net_apply_fn
        self._discount = discount
        self._baseline_cost = baseline_cost
        self._entropy_cost = entropy_cost

    @functools.partial(jax.jit, static_argnums=0)
    def step(
            self,
            params: hk.Params,
            rng: jnp.ndarray,
            timestep: dm_env.TimeStep
    ):
        """Sample one action for a single, un-batched timestep.

        Args:
            params: network parameters.
            rng: PRNG key used for sampling the action.
            timestep: un-batched timestep whose fields are arrays.

        Returns:
            `(action, logits)`: the sampled action and the policy logits it was drawn from.
        """
        # The network works on batches, so add a leading batch axis of size 1.
        timestep = jax.tree_util.tree_map(lambda t: jnp.expand_dims(t, 0), timestep)
        logits, _ = self._net(params, timestep)  # [1, num_actions]
        # Remove only the batch axis; an axis-less squeeze would also drop the action
        # axis if there were a single action.
        logits = jnp.squeeze(logits, axis=0)  # [num_actions]
        action = hk.multinomial(rng, logits, num_samples=1)  # expected [num_samples] for rank-1 logits
        action = jnp.squeeze(action, axis=-1)  # drop the num_samples axis

        return action, logits

    def loss(self, params: hk.Params, trajs: Transition):
        """Compute the total loss (policy gradient + baseline + entropy) on a batch.

        Args:
            params: learner network parameters.
            trajs: batch of trajectories whose leaves are laid out as
                `[num_transitions, batch_size, ...]`; the last transition of each
                trajectory is only used for bootstrapping.

        Returns:
            Scalar total loss.
        """
        # 1 - Run the learner network on every timestep, bootstrap step included.
        net_curried = hk.BatchApply(functools.partial(self._net, params))
        learner_logits, baseline_with_bootstrap = net_curried(trajs.timestep)

        # 2 - Gather the learner V_t and V_t+1 for the TD errors; the logits of the
        #     bootstrap step are not needed.
        baseline = baseline_with_bootstrap[:-1]  # V_t
        learner_logits = learner_logits[:-1]  # Pi(a_t | s_t)
        baseline_tp1 = baseline_with_bootstrap[1:]  # V_t+1

        # 3 - Remove the bootstrap step from the behaviour actions and logits.
        _, behavior_actions, behavior_logits = jax.tree_util.tree_map(lambda t: t[:-1], trajs)

        # 4 - Shift the timesteps by one so that each action is aligned with the
        #     timestep (reward, discount, step type) it produced.
        behavior_timestep = jax.tree_util.tree_map(lambda t: t[1:], trajs.timestep)
        discount = behavior_timestep.discount * self._discount

        # 5 - Ignore the LAST -> FIRST transition: it is an artefact of the dm_env
        #     API (reset), not a real environment transition.
        mask = jnp.not_equal(behavior_timestep.step_type, int(dm_env.StepType.FIRST))
        mask = mask.astype(jnp.float32)

        # 6 - Importance sampling ratios Pi(a_t | s_t) / Mu(a_t | s_t).
        rhos = distrax.importance_sampling_ratios(
            target_dist=distrax.Categorical(learner_logits),
            sampling_dist=distrax.Categorical(behavior_logits),
            event=behavior_actions
        )

        # 7 - V-trace errors and policy-gradient advantages, computed per batch
        #     element (vmap over the batch axis, which is axis 1).
        vtrace_returns = jax.vmap(rlax.vtrace_td_error_and_advantage, in_axes=1, out_axes=1)(
            baseline,  # v_tm1
            baseline_tp1,  # v_t
            behavior_timestep.reward,  # r_t
            discount,  # discount_t
            rhos  # rho_tm1
        )

        # 8 - Policy gradient loss; the advantage is treated as a constant.
        pg_advantage = jax.lax.stop_gradient(vtrace_returns.pg_advantage)
        pg_loss = jnp.mean(
            jax.vmap(rlax.policy_gradient_loss, in_axes=1, out_axes=0)(
                learner_logits,  # logits_t
                behavior_actions,  # a_t
                pg_advantage,  # adv_t
                mask,  # w_t
            )
        )

        # 9 - Baseline loss on the masked V-trace errors.
        bl_loss = 0.5 * jnp.mean(jnp.square(vtrace_returns.errors) * mask)

        # 10 - Entropy loss.
        ent_loss = jnp.mean(jax.vmap(rlax.entropy_loss, in_axes=1, out_axes=0)(learner_logits, mask))

        # 11 - Total weighted loss.
        total_loss = pg_loss + self._baseline_cost * bl_loss + self._entropy_cost * ent_loss
        return total_loss

In [None]:
from rl.common import preprocess_step
from typing import List


def run_actor(
        agent: Agent,
        rng_key: jnp.ndarray,
        get_params: Callable[[], hk.Params],
        enqueue_traj: Callable[[Tuple[Transition, jnp.ndarray]], None],
        horizon: int,
        num_trajectories: int,
):
    """Collect fixed-length rollouts from CartPole and hand them to the learner.

    Every enqueued trajectory holds `horizon + 1` stacked transitions: each rollout
    starts with the last transition of the previous one, so consecutive trajectories
    overlap by one step.

    Args:
        agent: agent used to select actions.
        rng_key: PRNG key, split once per environment step.
        get_params: returns the parameters to act with; called once per rollout.
        enqueue_traj: receives `(stacked_trajectory, avg_return)` for every rollout.
        horizon: number of new environment steps per rollout.
        num_trajectories: number of rollouts to produce before returning.

    The reported `avg_return` is the mean return of the episodes that finished
    during the rollout. If none finished, the most recent average is reported
    again (NaN until the first episode completes) rather than the NaN that a
    mean over an empty list would give.
    """
    env = cartpole.Cartpole()
    state = env.reset()
    traj = []
    eps_return = 0
    avg_return = jnp.asarray(jnp.nan)  # nothing to report until an episode finishes
    for i in range(num_trajectories):
        params = get_params()  # get the latest params from the learner

        eps_returns = []
        # The very first rollout has no carried-over transition, so it takes one
        # extra step to reach the same length as the following ones.
        for _ in range(horizon + int(i == 0)):
            rng_key, step_key = jax.random.split(rng_key)
            state = preprocess_step(state)
            action, logits = agent.step(params, step_key, state)

            traj.append(Transition(state, action, logits))

            state = env.step(action)

            if state.reward is not None:  # the reward is None on a reset timestep
                eps_return += state.reward
            if state.step_type == dm_env.StepType.LAST:
                eps_returns.append(eps_return)
                eps_return = 0

        if eps_returns:
            avg_return = jnp.mean(np.asarray(eps_returns))

        # list of transition trees -> one tree with a leading time axis
        stacked_traj = jax.tree_util.tree_map(lambda *ts: jnp.stack(ts), *traj)
        enqueue_traj((stacked_traj, avg_return))  # push the trajectory to the learner queue

        traj = traj[-1:]  # resume the next trajectory from the last transition

In [None]:
class Learner:
    """Applies gradient updates of the agent's loss to the network parameters."""

    def __init__(self, agent: Agent, opt_update):
        """Store the agent and the optimizer update function.

        Args:
            agent: object providing `loss(params, trajs)`.
            opt_update: optax-style update function called as
                `opt_update(gradients, opt_state, params)`.
        """
        self._agent = agent
        self._opt_update = opt_update

    @functools.partial(jax.jit, static_argnums=0)
    def update(
            self,
            params: hk.Params,
            opt_state: optax.OptState,
            trajs: Transition
    ):
        """Run one optimization step on a batch of trajectories.

        Args:
            params: current network parameters.
            opt_state: current optimizer state.
            trajs: batch of trajectories passed to the agent's loss.

        Returns:
            `(new_params, new_opt_state, loss_value)`.
        """
        loss_value, gradient = jax.value_and_grad(self._agent.loss)(params, trajs)
        # Forward the parameters as well: optax update functions take them as an
        # optional third argument and some transformations (e.g. weight decay)
        # require them.
        updates, new_opt_state = self._opt_update(gradient, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_value

In [None]:
def run(traj_per_actor, num_actors, horizon, batch_size=2):
    """Train the agent on CartPole with several actor threads and one learner loop.

    Args:
        traj_per_actor: number of trajectories each actor produces.
        num_actors: number of actor threads to start.
        horizon: number of new environment steps per trajectory.
        batch_size: number of trajectories stacked into one learner batch; also
            the capacity of the actor -> learner queue.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    env = cartpole.Cartpole()
    num_actions = env.action_spec().num_values

    # create: Network, Impala Agent, Optimizer
    net = hk.without_apply_rng(hk.transform(lambda ts: Network(num_actions)(ts)))
    agent = Agent(net.apply)
    opt = optax.rmsprop(learning_rate=5e-3, decay=0.99, eps=1e-7)

    learner = Learner(agent, opt.update)

    # Init the learner parameters from a sample timestep with a batch axis of 1
    sample_timestep = preprocess_step(env.reset())
    ts_with_batch = jax.tree_util.tree_map(lambda t: jnp.expand_dims(t, 0), sample_timestep)
    params = jax.jit(net.init)(jax.random.PRNGKey(69), ts_with_batch)

    # Init Optimizer state
    opt_state = opt.init(params)

    def current_params():
        # Reads `params` from the enclosing scope at call time, so actors see the
        # value most recently assigned by the learner loop below.
        return params

    # Bounded queue between the actors and the learner: actors block in `put`
    # once `batch_size` trajectories are waiting.
    q = queue.Queue(maxsize=batch_size)

    def dequeue():
        """Block until `batch_size` trajectories are available and stack them."""
        batch = []
        batch_episode_returns = []
        for _ in range(batch_size):
            actor_trajectory, actor_trajectory_avg_return = q.get()  # leaves: [num_transitions, ...]
            batch.append(actor_trajectory)
            batch_episode_returns.append(actor_trajectory_avg_return)

        # leaves: [num_transitions, batch_size, ...]
        batch = jax.tree_util.tree_map(lambda *ts: jnp.stack(ts, axis=1), *batch)

        batch_avg_return = jnp.mean(np.asarray(batch_episode_returns))
        return jax.device_put(batch), batch_avg_return

    # Start the actors. Daemon threads cannot keep the process alive if the
    # learner loop fails while they are blocked on the full queue.
    actors = []
    for i in range(num_actors):
        key = jax.random.PRNGKey(i)
        args = (agent, key, current_params, q.put, horizon, traj_per_actor)
        actor = threading.Thread(target=run_actor, args=args, daemon=True)
        actor.start()
        actors.append(actor)

    # Start the learner
    num_steps = num_actors * traj_per_actor // batch_size

    for i in range(num_steps):
        traj, avg_return = dequeue()
        params, opt_state, loss_value = learner.update(params, opt_state, traj)
        print(f"step {i + 1}, loss: {loss_value}, avg return: {avg_return}")

    # Fewer than `batch_size` trajectories can be left over, which always fit in
    # the queue, so the actors are able to finish and the joins return.
    for actor in actors:
        actor.join()


In [None]:
# Launch training: 2 actor threads, 500 trajectories per actor, 20 new steps per trajectory.
run(traj_per_actor=500, num_actors=2, horizon=20)