In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import sys
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence, Tuple

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate

from jax_learning.models import Policy, ActionValue, MLP, StochasticPolicy

In [2]:
wandb.init(project="test_jax_rl", group="reacher-reinforce_test")
wandb.define_metric("episodic_return", summary="max")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchan[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


<wandb.sdk.wandb_metric.Metric at 0x10f16d1e0>

In [3]:
cfg_dict = {
    # Environment setup
    "env": "Reacher-v2",
    "seed": 0,
    "render": False,
    
    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,
    
    # Learning hyperparameters
    "gamma": 0.99,
    "max_timesteps": 1000000,
    "update_frequency": 1024,
    "lr": 3e-4,
    "max_grad_norm": 10.,
    
    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,
    
    # Evaluation
    "eval_cfg": {
        "max_episodes": 100,
        "seed": 1,
        "render": True,
    }
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
wandb.config = cfg_dict

In [4]:
np.random.seed(cfg.seed)

In [5]:
env = gym.make(cfg.env)

  logger.warn(
  logger.warn(
objc[24789]: Class GLFWWindowDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x11f84b7b0) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x11dee2700). One of the two will be used. Which one is undefined.
objc[24789]: Class GLFWApplicationDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x11f84b788) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x11dee2778). One of the two will be used. Which one is undefined.
objc[24789]: Class GLFWContentView is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x11f84b800) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x11dee27a0). One of the two will be used. Which one is undefined.
objc[24789]: Class GLFWWindow is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x11f84b878) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x11dee2818). One of the two will be used. Whic

In [6]:
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape
cfg.action_space = CONTINUOUS

In [7]:
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [8]:
cfg

Namespace(env='Reacher-v2', seed=0, render=False, load_step=0, log_interval=5000, gamma=0.99, max_timesteps=1000000, update_frequency=1024, lr=0.0003, max_grad_norm=10.0, hidden_dim=256, num_hidden=2, eval_cfg={'max_episodes': 100, 'seed': 1, 'render': True}, obs_dim=(11,), act_dim=(2,), action_space='continuous', h_state_dim=(1,), rew_dim=(1,))

In [9]:
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)



In [15]:
from jax_learning.distributions import Distribution, Normal


class MLPGaussianPolicy(StochasticPolicy):
    obs_dim: int
    act_dim: int
    eps: float
    policy: eqx.Module

    def __init__(self,
                 obs_dim: Sequence[int],
                 act_dim: Sequence[int],
                 hidden_dim: int,
                 num_hidden: int,
                 key: jrandom.PRNGKey,
                 eps: float=1e-7):
        self.obs_dim = int(np.product(obs_dim))
        self.act_dim = int(np.product(act_dim))
        self.eps = eps
        self.policy = MLP(self.obs_dim, self.act_dim * 2, hidden_dim, num_hidden, key)

    def deterministic_action(self,
                             obs: np.ndarray,
                             h_state: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        act_mean, _ = jnp.split(self.policy(obs), 2, axis=-1)
        return act_mean, h_state
    
    def random_action(self,
                      obs: np.ndarray,
                      h_state: np.ndarray,
                      key: jrandom.PRNGKey) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        dist = self.dist(obs, h_state)
        act = dist.sample(key)
        return act, h_state
    
    def act_lprob(self,
                  obs: np.ndarray,
                  h_state: np.ndarray,
                  key: jrandom.PRNGKey) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        dist = self.dist(obs, h_state)
        act = dist.sample(key)
        lprob = dist.lprob(act)
        return act, lprob, h_state
    
    def dist(self,
             obs: np.ndarray,
             h_state: np.ndarray) -> Distribution:
        act_mean, act_raw_std = jnp.split(self.policy(obs), 2, axis=-1)
        act_std = jax.nn.softplus(act_raw_std) + self.eps
        return Normal(act_mean, act_std)


In [46]:
from jax_learning.buffers import ReplayBuffer
from jax_learning.learners import Learner

POLICY = "policy"
LOSS = "loss"
MAX_RETURN = "max_return"
MIN_RETURN = "min_return"
class REINFORCE(Learner):
    def __init__(self,
                 model: eqx.Module,
                 opt: optax.GradientTransformation,
                 buffer: ReplayBuffer,
                 cfg: Namespace):
        super().__init__(model, opt, buffer, cfg)
        
        self._step = cfg.load_step
        self._update_frequency = cfg.update_frequency
        self._sample_idxes = np.arange(cfg.update_frequency)
        self._gamma = cfg.gamma
        
        def get_lprob(dist, act):
            return jnp.sum(dist.lprob(act))
        
        def score_function(lprob, ret):
            return lprob * ret
        
        @eqx.filter_grad(has_aux=True)
        def reinforce_loss(model: eqx.Module,
                           obss: np.ndarray,
                           h_states: np.ndarray,
                           acts: np.ndarray,
                           rets: np.ndarray) -> Tuple[np.ndarray, dict]:
            dists = jax.vmap(model.dist)(obss, h_states)
            lprobs = jax.vmap(get_lprob)(dists, acts)
            
            score = jax.vmap(score_function)(lprobs, rets)
            loss = -jnp.mean(score)
            return loss, {
                LOSS: loss,
                MAX_RETURN: jnp.max(rets),
                MIN_RETURN: jnp.min(rets),
            }
        
        def step(model: eqx.Module,
                 opt: optax.GradientTransformation,
                 opt_state: optax.OptState,
                 obss: np.ndarray,
                 h_states: np.ndarray,
                 acts: np.ndarray,
                 rets: np.ndarray) -> Tuple[eqx.Module, optax.OptState, jax.tree_util.PyTreeDef, dict]:
            grads, learn_info = reinforce_loss(model,
                                               obss,
                                               h_states,
                                               acts,
                                               rets)

            updates, opt_state = opt.update(grads, opt_state)
            model = eqx.apply_updates(model, updates)
            return model, opt_state, grads, learn_info
        self.step = eqx.filter_jit(step)
        
    def compute_returns(self, rews, dones):
        rets = np.zeros(rews.shape[0] + 1)
        for step in reversed(range(len(rews))):
            rets[step] = rets[step + 1] * self._gamma * (1 - dones[step]) + rews[step]
        return rets[:-1]

    def learn(self,
              next_obs: np.ndarray,
              next_h_state: np.ndarray,
              learn_info: dict):
        self._step += 1
        
        if self._step % self._update_frequency != 0:
            return

        obss, h_states, acts, rews, dones, _, _, _ = self.buffer.sample(batch_size=self._update_frequency,
                                                                        idxes=self._sample_idxes)

        rets = self.compute_returns(rews, dones)
        (obss, h_states, acts, rets) = to_jnp(*batch_flatten(obss,
                                                             h_states,
                                                             acts,
                                                             rets))
        model, opt_state, grads, curr_learn_info = self.step(model=self.model[POLICY],
                                                             opt=self.opt[POLICY],
                                                             opt_state=self.opt_state[POLICY],
                                                             obss=obss,
                                                             h_states=h_states,
                                                             acts=acts,
                                                             rets=rets)

        self._model[POLICY] = model
        self._opt_state[POLICY] = opt_state

        learn_info[MEAN_LOSS] = curr_learn_info[LOSS].item()
        self.buffer.clear()

In [47]:
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.update_frequency,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == DISCRETE else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

model = {
    POLICY: MLPGaussianPolicy(obs_dim=cfg.obs_dim,
                              act_dim=cfg.act_dim,
                              hidden_dim=cfg.hidden_dim,
                              num_hidden=cfg.num_hidden,
                              key=cfg.model_key)
}

opt = {
    POLICY: optax.chain(optax.clip_by_global_norm(cfg.max_grad_norm),  # Clip by the gradient by the global norm
                        optax.scale_by_adam(),  # Use the updates from adam
                        optax.scale(-1.0)) # Gradient descent
}


learner = REINFORCE(model=model,
                    opt=opt,
                    buffer=buffer,
                    cfg=cfg)

agent = RLAgent(model=model[POLICY],
                buffer=buffer,
                learner=learner,
                key=cfg.agent_key)

In [25]:
%wandb

In [48]:
interact(env, agent, cfg)

NameError: name 'MEAN_LOSS' is not defined

In [None]:
evaluate(env, agent, eval_cfg)

In [None]:
assert 0

In [None]:
buffer

In [None]:
buffer.observations

In [None]:
np.roll(buffer.next_observations, 1, axis=0)

In [None]:
np.concatenate((buffer.observations - np.roll(buffer.next_observations, 1, axis=0), buffer.dones), axis=1)

In [None]:
buffer.sample_with_next_obs(3, buffer.next_observations[19], buffer.hidden_states[0])