In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import sys
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence, Tuple, Optional, Callable, Dict

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers import ReplayBuffer
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.common import polyak_average_generator
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate
from jax_learning.learners import LearnerWithTargetNetwork
from jax_learning.models import StochasticPolicy, ActionValue, Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy
from jax_learning.models.q_functions import MLPQ, MultiQ

In [2]:
# Start the Weights & Biases run that this notebook logs to.
wandb.init(project="test_jax_rl", group="reacher-sac_test")
# Ask W&B to keep the maximum of "episodic_return" as that metric's summary value.
wandb.define_metric("episodic_return", summary="max")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchan[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


<wandb.sdk.wandb_metric.Metric at 0x1126af340>

In [3]:
# Every tunable for this SAC run lives in one dict so that it can be both
# turned into attribute-style namespaces and logged to Weights & Biases.
cfg_dict = {
    # Environment setup
    "env": "Reacher-v2",
    "seed": 0,
    "render": False,

    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,

    # Learning hyperparameters
    "max_timesteps": 1000000,
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 128,
    "max_grad_norm": 10.,
    "gamma": 0.99,
    "update_frequency": 1,

    # Actor
    "actor_lr": 3e-4,
    "actor_update_frequency": 1,

    # Critic
    "critic_lr": 3e-4,
    "target_update_frequency": 1,
    "tau": 0.01, # This is for polyak averaging of target network

    # Normalization
    "normalize_obs": False,
    "normalize_value": False,

    # Temperature
    "alpha_lr": 3e-4,
    "init_alpha": 1.0,
    "target_entropy": "auto",

    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,

    # Evaluation
    "eval_cfg": {
        "max_episodes": 100,
        "seed": 1,
        "render": True,
    }
}
# Attribute-style views of the settings: `cfg` for training, `eval_cfg` for
# the nested evaluation settings.
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
# Push the hyperparameters into the active run's config. Plain assignment
# (`wandb.config = cfg_dict`) would only rebind the module attribute and the
# run would never record these values.
wandb.config.update(cfg_dict)

In [4]:
# Seed NumPy's global RNG with the experiment seed.
np.random.seed(cfg.seed)

In [5]:
# Build the Gym environment named in the config.
env = gym.make(cfg.env)

  logger.warn(
  logger.warn(
objc[31153]: Class GLFWWindowDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x1222177b0) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x1222aa700). One of the two will be used. Which one is undefined.
objc[31153]: Class GLFWApplicationDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x122217788) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x1222aa778). One of the two will be used. Which one is undefined.
objc[31153]: Class GLFWContentView is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x122217800) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x1222aa7a0). One of the two will be used. Which one is undefined.
objc[31153]: Class GLFWWindow is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x122217878) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x1222aa818). One of the two will be used. Whic

In [6]:
# Record the environment's observation and action shapes on the config.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape
if cfg.target_entropy == "auto":
    # "auto" resolves to the negative number of action dimensions.
    # np.prod is used because np.product is a deprecated alias that was
    # removed in NumPy 2.0.
    cfg.target_entropy = -float(np.prod(env.action_space.shape))
cfg.action_space = CONTINUOUS

In [7]:
# Shapes of the hidden-state and reward entries; both are passed to the
# replay buffer constructor later in the notebook.
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [8]:
# Display the fully resolved configuration.
cfg

Namespace(env='Reacher-v2', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=128, max_grad_norm=10.0, gamma=0.99, update_frequency=1, actor_lr=0.0003, actor_update_frequency=1, critic_lr=0.0003, target_update_frequency=1, tau=0.01, normalize_obs=False, normalize_value=False, alpha_lr=0.0003, init_alpha=1.0, target_entropy=-2.0, hidden_dim=256, num_hidden=2, eval_cfg={'max_episodes': 100, 'seed': 1, 'render': True}, obs_dim=(11,), act_dim=(2,), action_space='continuous', h_state_dim=(1,), rew_dim=(1,))

In [9]:
# Random sources for the run, all derived from the configured seeds.
# NOTE(review): buffer_rng and env_rng are separate RandomState objects built
# from the same seed, so they produce identical streams — confirm this is intended.
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
# One JAX key for the agent and one for model construction.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
# Evaluation uses its own seed from eval_cfg.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)



In [45]:
# Keys of the per-update info dicts returned by the loss functions.
Q_LOSS = "q_loss"
POLICY_LOSS = "policy_loss"
TEMPERATURE_LOSS = "temperature_loss"
# Keys of the aggregated statistics written into `learn_info` by SAC.learn.
MEAN_Q_LOSS = "mean_q_loss"
MEAN_POLICY_LOSS = "mean_policy_loss"
MEAN_TEMPERATURE_LOSS = "mean_temperature_loss"
MEAN_CURR_Q = "mean_curr_q"
MEAN_NEXT_Q = "mean_next_q"
MAX_CURR_Q = "max_curr_q"
MAX_NEXT_Q = "max_next_q"
MIN_CURR_Q = "min_curr_q"
MIN_NEXT_Q = "min_next_q"
MAX_TD_ERROR = "max_td_error"
MIN_TD_ERROR = "min_td_error"
# Keys used to index the model / target-model / optimizer dictionaries.
POLICY = "policy"
Q = "q"
TEMPERATURE = "temperature"
# Name of the config attribute holding the entropy target.
TARGET_ENTROPY = "target_entropy"

def clipped_min_q_td_error(curr_q_pred: np.ndarray,
                           next_q_pred_min: np.ndarray,
                           next_lprob: np.ndarray,
                           rew: np.ndarray,
                           done: np.ndarray,
                           temp: float,
                           gamma: float) -> np.ndarray:
    """Return the TD error of a Q prediction against the entropy-regularised target.

    The target is ``rew + gamma * (1 - done) * (next_q_pred_min - temp * next_lprob)``
    and the returned value is ``curr_q_pred`` minus that target.
    """
    # Entropy-adjusted value of the next state.
    soft_next_value = next_q_pred_min - temp * next_lprob
    # The bootstrap term vanishes when `done` is 1.
    bootstrap = gamma * (1 - done) * soft_next_value
    return curr_q_pred - (rew + bootstrap)

def sac_policy_loss(curr_q_pred_min: np.ndarray,
                    lprob: np.ndarray,
                    temp: float) -> np.ndarray:
    """Return the policy loss term ``-(curr_q_pred_min - temp * lprob)``."""
    entropy_regularised_q = curr_q_pred_min - temp * lprob
    return -entropy_regularised_q

def sac_temperature_loss(temp: float,
                         lprob: np.ndarray,
                         target_entropy: float) -> np.ndarray:
    """Return the temperature loss term ``temp * -(lprob + target_entropy)``."""
    negated_entropy_gap = -(lprob + target_entropy)
    return temp * negated_entropy_gap


class SAC(LearnerWithTargetNetwork):
    """Soft Actor-Critic learner: critic, policy and (optional) temperature updates.

    ``model`` is indexed with the ``POLICY``, ``Q`` and ``TEMPERATURE`` keys,
    ``target_model`` with ``Q``, and ``opt`` with the same three keys as
    ``model``. The three jitted update functions are built once in
    ``__init__`` and driven by :meth:`learn`.

    NOTE(review): ``self._gamma``, ``self._step``, ``self._update_frequency``,
    ``self.obs_rms``, ``self.buffer``, ``self.model`` / ``self._model``,
    ``self.target_model``, ``self.opt``, ``self.opt_state`` / ``self._opt_state``
    and ``update_target_model`` are not defined here; they are presumably
    provided by ``LearnerWithTargetNetwork`` — confirm against the base class.
    """

    def __init__(self,
                 model: Dict[str, eqx.Module],
                 target_model: Dict[str, eqx.Module],
                 opt: Dict[str, optax.GradientTransformation],
                 buffer: ReplayBuffer,
                 cfg: Namespace):
        """Store hyperparameters from ``cfg`` and build the jitted update functions.

        Args:
            model: Models keyed by ``POLICY``, ``Q`` and ``TEMPERATURE``.
            target_model: Target models; only the ``Q`` entry is used here.
            opt: Optimizers keyed like ``model``.
            buffer: Replay buffer sampled in :meth:`learn`.
            cfg: Must provide ``batch_size``, ``num_gradient_steps``,
                ``buffer_warmup``, ``actor_update_frequency``,
                ``target_update_frequency`` and ``seed``. ``target_entropy``
                and ``omega`` are optional.
        """
        super().__init__(model, target_model, opt, buffer, cfg)

        self._batch_size = cfg.batch_size
        self._num_gradient_steps = cfg.num_gradient_steps

        self._buffer_warmup = cfg.buffer_warmup
        self._actor_update_frequency = cfg.actor_update_frequency
        self._target_update_frequency = cfg.target_update_frequency

        # A target entropy of None disables the temperature update in learn().
        self._target_entropy = getattr(cfg, TARGET_ENTROPY, None)
        self._sample_key = jrandom.PRNGKey(cfg.seed)

        def split_sample_keys(key: Optional[jrandom.PRNGKey]):
            """Split ``key`` into (key to carry forward, one key per batch element).

            Both outputs come from a single split so that the carried key
            never coincides with a key that has already been consumed.
            """
            # Legacy fallback for callers that do not pass a key. The value
            # read here is frozen into the compiled function at trace time,
            # so learn() always passes the key explicitly.
            base_key = self._sample_key if key is None else key
            all_keys = jrandom.split(base_key, num=self._batch_size + 1)
            return all_keys[0], all_keys[1:]

        _clipped_min_q_td_error = jax.vmap(clipped_min_q_td_error, in_axes=[0, 0, 0, 0, 0, None, None])

        @eqx.filter_grad(has_aux=True)
        def q_loss(models: Tuple[ActionValue, ActionValue],
                   policy: StochasticPolicy,
                   temperature: Temperature,
                   obss: np.ndarray,
                   h_states: np.ndarray,
                   acts: np.ndarray,
                   rews: np.ndarray,
                   dones: np.ndarray,
                   next_obss: np.ndarray,
                   next_h_states: np.ndarray,
                   keys: Sequence[jrandom.PRNGKey]) -> Tuple[np.ndarray, dict]:
            """Mean squared TD error; differentiated w.r.t. ``models`` = (q, target_q)."""
            (q, target_q) = models

            curr_xs = jnp.concatenate((obss, acts), axis=-1)
            curr_q_preds, _ = jax.vmap(q.q_values)(curr_xs, h_states)

            # Next actions are sampled from the current policy, one key per sample.
            next_acts, next_lprobs, _ = jax.vmap(policy.act_lprob)(next_obss, next_h_states, keys)
            next_lprobs = jnp.sum(next_lprobs, axis=-1, keepdims=True)

            next_xs = jnp.concatenate((next_obss, next_acts), axis=-1)
            next_q_preds, _ = jax.vmap(target_q.q_values)(next_xs, next_h_states)
            # Minimum over axis 1 (presumably the Q-ensemble axis — confirm against MultiQ).
            next_q_preds_min = jnp.min(next_q_preds, axis=1)

            temp = temperature()

            def batch_td_errors(curr_q_pred):
                return _clipped_min_q_td_error(curr_q_pred,
                                               next_q_preds_min,
                                               next_lprobs,
                                               rews,
                                               dones,
                                               temp,
                                               self._gamma)
            # Apply the same target to every slice along axis 1 of the predictions.
            td_errors = jax.vmap(batch_td_errors, in_axes=[1])(curr_q_preds)
            loss = jnp.mean(td_errors ** 2)
            return loss, {
                Q_LOSS: loss,
                MAX_NEXT_Q: jnp.max(next_q_preds),
                MIN_NEXT_Q: jnp.min(next_q_preds),
                MEAN_NEXT_Q: jnp.mean(next_q_preds),
                MAX_CURR_Q: jnp.max(curr_q_preds),
                MIN_CURR_Q: jnp.min(curr_q_preds),
                MEAN_CURR_Q: jnp.mean(curr_q_preds),
                MAX_TD_ERROR: jnp.max(td_errors),
                MIN_TD_ERROR: jnp.min(td_errors),
            }

        # Combines the gradient w.r.t. q with the gradient w.r.t. target_q;
        # "omega" defaults to 1.0 when the config does not define it.
        apply_residual_gradient = polyak_average_generator(getattr(cfg, "omega", 1.0))

        def update_q(q: ActionValue,
                     target_q: ActionValue,
                     policy: StochasticPolicy,
                     temperature: Temperature,
                     opt: optax.GradientTransformation,
                     opt_state: optax.OptState,
                     obss: np.ndarray,
                     h_states: np.ndarray,
                     acts: np.ndarray,
                     rews: np.ndarray,
                     dones: np.ndarray,
                     next_obss: np.ndarray,
                     next_h_states: np.ndarray,
                     key: Optional[jrandom.PRNGKey] = None) -> Tuple[ActionValue,
                                                                      optax.OptState,
                                                                      Tuple[jax.tree_util.PyTreeDef,
                                                                            jax.tree_util.PyTreeDef,
                                                                            jax.tree_util.PyTreeDef],
                                                                      dict,
                                                                      jrandom.PRNGKey]:
            """Take one optimizer step on the critic.

            Returns the updated q, the new optimizer state, the tuple
            (combined grads, q grads, target-q grads), the loss info dict and
            the PRNG key to use for the next update.
            """
            sample_key, keys = split_sample_keys(key)
            grads, learn_info = q_loss((q, target_q),
                                       policy,
                                       temperature,
                                       obss,
                                       h_states,
                                       acts,
                                       rews,
                                       dones,
                                       next_obss,
                                       next_h_states,
                                       keys)

            (q_grads, target_q_grads) = grads
            # jax.tree_util.tree_map is the stable spelling; the top-level
            # jax.tree_map alias is deprecated and gone from newer JAX.
            grads = jax.tree_util.tree_map(apply_residual_gradient,
                                           q_grads,
                                           target_q_grads)

            updates, opt_state = opt.update(grads, opt_state)
            q = eqx.apply_updates(q, updates)
            return q, opt_state, (grads, q_grads, target_q_grads), learn_info, sample_key

        _sac_policy_loss = jax.vmap(sac_policy_loss, in_axes=[0, 0, None])

        @eqx.filter_grad(has_aux=True)
        def policy_loss(policy: StochasticPolicy,
                        q: ActionValue,
                        temperature: Temperature,
                        obss: np.ndarray,
                        h_states: np.ndarray,
                        keys: Sequence[jrandom.PRNGKey]) -> Tuple[np.ndarray, dict]:
            """Mean SAC policy loss; differentiated w.r.t. ``policy`` only."""
            acts, lprobs, _ = jax.vmap(policy.act_lprob)(obss, h_states, keys)
            lprobs = jnp.sum(lprobs, axis=-1, keepdims=True)
            curr_xs = jnp.concatenate((obss, acts), axis=-1)
            curr_q_preds, _ = jax.vmap(q.q_values)(curr_xs, h_states)
            curr_q_preds_min = jnp.min(curr_q_preds, axis=1)
            temp = temperature()

            loss = jnp.mean(_sac_policy_loss(curr_q_preds_min, lprobs, temp))
            return loss, {
                POLICY_LOSS: loss,
            }

        def update_policy(policy: StochasticPolicy,
                          q: ActionValue,
                          temperature: Temperature,
                          opt: optax.GradientTransformation,
                          opt_state: optax.OptState,
                          obss: np.ndarray,
                          h_states: np.ndarray,
                          acts: np.ndarray,
                          key: Optional[jrandom.PRNGKey] = None) -> Tuple[StochasticPolicy,
                                                                           optax.OptState,
                                                                           jax.tree_util.PyTreeDef,
                                                                           dict,
                                                                           jrandom.PRNGKey]:
            """Take one optimizer step on the policy.

            ``acts`` is accepted for signature compatibility but not used:
            actions are re-sampled from the policy inside the loss.
            """
            sample_key, keys = split_sample_keys(key)

            grads, learn_info = policy_loss(policy,
                                            q,
                                            temperature,
                                            obss,
                                            h_states,
                                            keys)

            updates, opt_state = opt.update(grads, opt_state)
            policy = eqx.apply_updates(policy, updates)
            return policy, opt_state, grads, learn_info, sample_key

        _sac_temperature_loss = jax.vmap(sac_temperature_loss, in_axes=[None, 0, None])

        @eqx.filter_grad(has_aux=True)
        def temperature_loss(temperature: Temperature,
                             policy: StochasticPolicy,
                             obss: np.ndarray,
                             h_states: np.ndarray,
                             keys: Sequence[jrandom.PRNGKey]) -> Tuple[np.ndarray, dict]:
            """Mean temperature loss; differentiated w.r.t. ``temperature`` only."""
            temp = temperature()
            _, lprobs, _ = jax.vmap(policy.act_lprob)(obss, h_states, keys)
            lprobs = jnp.sum(lprobs, axis=-1, keepdims=True)
            loss = jnp.mean(_sac_temperature_loss(temp, lprobs, self._target_entropy))
            return loss, {
                TEMPERATURE_LOSS: loss,
            }

        def update_temperature(policy: StochasticPolicy,
                               temperature: Temperature,
                               opt: optax.GradientTransformation,
                               opt_state: optax.OptState,
                               obss: np.ndarray,
                               h_states: np.ndarray,
                               key: Optional[jrandom.PRNGKey] = None) -> Tuple[Temperature,
                                                                                optax.OptState,
                                                                                jax.tree_util.PyTreeDef,
                                                                                dict,
                                                                                jrandom.PRNGKey]:
            """Take one optimizer step on the temperature."""
            sample_key, keys = split_sample_keys(key)
            grads, learn_info = temperature_loss(temperature,
                                                 policy,
                                                 obss,
                                                 h_states,
                                                 keys)

            updates, opt_state = opt.update(grads, opt_state)
            temperature = eqx.apply_updates(temperature, updates)
            return temperature, opt_state, grads, learn_info, sample_key

        self.update_q = eqx.filter_jit(update_q)
        self.update_policy = eqx.filter_jit(update_policy)
        self.update_temperature = eqx.filter_jit(update_temperature)

    def learn(self, next_obs: np.ndarray, next_h_state: np.ndarray, learn_info: dict):
        """Run the configured number of gradient steps and record statistics.

        Does nothing beyond incrementing the step counter until the step
        count exceeds the buffer warm-up, and afterwards only on steps that
        match the update frequency.

        Args:
            next_obs: Passed to the buffer's ``sample_with_next_obs``.
            next_h_state: Passed to the buffer's ``sample_with_next_obs``.
            learn_info: Dict updated in place with mean losses and running
                Q-value statistics over this call's gradient steps.
        """
        self._step += 1

        if (
            self._step <= self._buffer_warmup
            or (self._step - 1 - self._buffer_warmup) % self._update_frequency != 0
        ):
            return

        learn_info[MEAN_Q_LOSS] = 0.0
        learn_info[MEAN_POLICY_LOSS] = 0.0
        learn_info[MEAN_TEMPERATURE_LOSS] = 0.0
        learn_info[MEAN_CURR_Q] = 0.0
        learn_info[MEAN_NEXT_Q] = 0.0
        learn_info[MAX_CURR_Q] = -np.inf
        learn_info[MAX_NEXT_Q] = -np.inf
        learn_info[MIN_CURR_Q] = np.inf
        learn_info[MIN_NEXT_Q] = np.inf
        for _ in range(self._num_gradient_steps):
            (
                obss,
                h_states,
                acts,
                rews,
                dones,
                next_obss,
                next_h_states,
                _,
                _,
                _,
            ) = self.buffer.sample_with_next_obs(
                batch_size=self._batch_size,
                next_obs=next_obs,
                next_h_state=next_h_state,
            )

            # NOTE(review): only `obss` is normalized; `next_obss` is used as
            # sampled. Confirm whether the buffer already returns normalized
            # next observations.
            if self.obs_rms:
                obss = self.obs_rms.normalize(obss)

            (obss, h_states, acts, rews, dones, next_obss, next_h_states) = to_jnp(
                *batch_flatten(
                    obss, h_states, acts, rews, dones, next_obss, next_h_states
                )
            )

            # The PRNG key is passed in explicitly. If the jitted functions
            # read it from `self`, it would be baked in as a constant at trace
            # time and every update would draw the same noise.
            q, opt_state, _, q_learn_info, self._sample_key = self.update_q(
                q=self.model[Q],
                target_q=self.target_model[Q],
                policy=self.model[POLICY],
                temperature=self.model[TEMPERATURE],
                opt=self.opt[Q],
                opt_state=self.opt_state[Q],
                obss=obss,
                h_states=h_states,
                acts=acts,
                rews=rews,
                dones=dones,
                next_obss=next_obss,
                next_h_states=next_h_states,
                key=self._sample_key,
            )

            self._model[Q] = q
            self._opt_state[Q] = opt_state

            # These stay None when the corresponding update is skipped on
            # this step, so the logging below can tell the cases apart.
            policy_learn_info = None
            temperature_learn_info = None

            if self._step % self._actor_update_frequency == 0:
                policy, opt_state, _, policy_learn_info, self._sample_key = self.update_policy(
                    policy=self.model[POLICY],
                    q=self.model[Q],
                    temperature=self.model[TEMPERATURE],
                    opt=self.opt[POLICY],
                    opt_state=self.opt_state[POLICY],
                    obss=obss,
                    h_states=h_states,
                    acts=acts,
                    key=self._sample_key,
                )
                self._model[POLICY] = policy
                self._opt_state[POLICY] = opt_state

                if self._target_entropy is not None:
                    temperature, opt_state, _, temperature_learn_info, self._sample_key = self.update_temperature(
                        policy=self.model[POLICY],
                        temperature=self.model[TEMPERATURE],
                        opt=self.opt[TEMPERATURE],
                        opt_state=self.opt_state[TEMPERATURE],
                        obss=obss,
                        h_states=h_states,
                        key=self._sample_key,
                    )
                    self._model[TEMPERATURE] = temperature
                    self._opt_state[TEMPERATURE] = opt_state

            if self._step % self._target_update_frequency == 0:
                self.update_target_model(model_key=Q)

            learn_info[MEAN_Q_LOSS] += (
                q_learn_info[Q_LOSS].item() / self._num_gradient_steps
            )
            if policy_learn_info is not None:
                learn_info[MEAN_POLICY_LOSS] += (
                    policy_learn_info[POLICY_LOSS].item() / self._num_gradient_steps
                )
            if temperature_learn_info is not None:
                learn_info[MEAN_TEMPERATURE_LOSS] += (
                    temperature_learn_info[TEMPERATURE_LOSS].item() / self._num_gradient_steps
                )
            learn_info[MEAN_CURR_Q] += (
                q_learn_info[MEAN_CURR_Q].item() / self._num_gradient_steps
            )
            learn_info[MEAN_NEXT_Q] += (
                q_learn_info[MEAN_NEXT_Q].item() / self._num_gradient_steps
            )
            # Running extrema across gradient steps: compare against the
            # value accumulated so far, not against this step's own value.
            learn_info[MAX_CURR_Q] = max(
                learn_info[MAX_CURR_Q], q_learn_info[MAX_CURR_Q].item()
            )
            learn_info[MAX_NEXT_Q] = max(
                learn_info[MAX_NEXT_Q], q_learn_info[MAX_NEXT_Q].item()
            )
            learn_info[MIN_CURR_Q] = min(
                learn_info[MIN_CURR_Q], q_learn_info[MIN_CURR_Q].item()
            )
            learn_info[MIN_NEXT_Q] = min(
                learn_info[MIN_NEXT_Q], q_learn_info[MIN_NEXT_Q].item()
            )


In [46]:
# Replay buffer; discrete action spaces store a single action index per step.
buffer_act_dim = cfg.act_dim if cfg.action_space != DISCRETE else (1,)
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=buffer_act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# Separate keys for the policy and the critics.
policy_key, q_key = jrandom.split(cfg.model_key)

policy = MLPSquashedGaussianPolicy(
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
    key=policy_key,
)
temperature = Temperature(init_alpha=cfg.init_alpha)

# Each critic takes the concatenated (observation, action) vector.
q_constructor = partial(
    MLPQ,
    in_dim=(cfg.obs_dim[0] + cfg.act_dim[0],),
    out_dim=(1,),
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
)

# The online and target critics are built from the same key.
q = MultiQ(q_constructor, num_qs=2, key=q_key)
target_q = MultiQ(q_constructor, num_qs=2, key=q_key)

model = {POLICY: policy, TEMPERATURE: temperature, Q: q}
target_model = {Q: target_q}


def _adam_transforms(learning_rate):
    """Return [optional global-norm clip, Adam scaling, -learning_rate scaling]."""
    transforms = []
    if cfg.max_grad_norm:
        transforms.append(optax.clip_by_global_norm(cfg.max_grad_norm))
    transforms.append(optax.scale_by_adam())
    transforms.append(optax.scale(-learning_rate))
    return transforms


q_opt_transforms = _adam_transforms(cfg.critic_lr)
policy_opt_transforms = _adam_transforms(cfg.actor_lr)
temperature_opt_transforms = _adam_transforms(cfg.alpha_lr)

opt = {
    Q: optax.chain(*q_opt_transforms),
    POLICY: optax.chain(*policy_opt_transforms),
    TEMPERATURE: optax.chain(*temperature_opt_transforms),
}

learner = SAC(
    model=model,
    target_model=target_model,
    opt=opt,
    buffer=buffer,
    cfg=cfg,
)

agent = RLAgent(
    model=model,
    model_key=POLICY,
    buffer=buffer,
    learner=learner,
    key=cfg.agent_key,
)

In [47]:
# W&B notebook magic; presumably embeds the live run dashboard in the cell output — confirm.
%wandb

In [None]:
# Run the agent in the environment with the training config.
# NOTE(review): presumably this is the training loop that calls learner.learn — confirm in jax_learning.rl_utils.
interact(env, agent, cfg)

In [None]:
# Evaluate the agent using the evaluation settings (eval_cfg has render=True).
evaluate(env, agent, eval_cfg)

In [None]:
# Close the Weights & Biases run.
wandb.finish()