In [None]:
import sys
import timeit
from argparse import Namespace
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Tuple

import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import wandb
from jax import grad, jit, vmap

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate
from jax_learning.learners import LearnerWithTargetNetwork
from jax_learning.models import Policy, ActionValue, MLP, StochasticPolicy, Model
from jax_learning.models.q_functions import MLPQ, MultiQ

In [None]:
# Small MultiQ ensemble used to smoke-test `q_values` in the next cells.
key = jrandom.PRNGKey(0)
keys = jrandom.split(key)  # split of the same PRNGKey(0); not used in the cells shown

obs_dim = (6,)
act_dim = (1,)
hidden_dim = 2
num_hidden = 2

# Each ensemble member is an MLPQ built from these shared hyperparameters.
q_constructor = partial(
    MLPQ,
    in_dim=obs_dim,
    out_dim=act_dim,
    hidden_dim=hidden_dim,
    num_hidden=num_hidden,
)
model = MultiQ(q_constructor, num_qs=7, key=key)

In [None]:
# Batched forward pass over two samples: vmap maps model.q_values over the leading axis of every input.
q_vals, h_states = jax.vmap(model.q_values)(
    np.array([[0., 1., 2.], [1., 1., 2.]]),  # observations (presumably)
    np.array([[0., 1., 2.], [0., 1., 2.]]),  # hidden states (presumably)
    np.array([[0., 1., 2.], [0., 1., 2.]]),  # actions (presumably) — confirm against MultiQ.q_values
)

In [None]:
# Display the Q-value predictions returned by the vmapped `model.q_values` call.
q_vals

In [None]:
class SAC(LearnerWithTargetNetwork):
    """Soft Actor-Critic learner (work in progress).

    The constructor builds three jitted update functions: ``update_q``,
    ``update_policy`` and ``update_temperature``. They share one pattern:
      1. differentiate a loss with respect to both the online and the target model;
      2. blend the two gradients as ``omega * g_online + (1 - omega) * g_target``;
      3. apply the blended gradient to the online model through ``opt``.

    NOTE(review): all three losses are still the Q-learning TD loss. The SAC
    critic, actor and entropy-temperature objectives are not implemented yet,
    and ``learn`` only runs the critic (``Q``) update.
    NOTE(review): ``q_learning_td_error``, the model key ``Q`` and the
    learn-info key constants (``LOSS``, ``MEAN_LOSS``, ``MAX_Q_CURR``, ...)
    are not imported in this notebook. Confirm where they are defined.
    """

    def __init__(self,
                 model: Dict[str, eqx.Module],
                 target_model: Dict[str, eqx.Module],
                 opt: Dict[str, optax.GradientTransformation],
                 # NOTE(review): `ReplayBuffer` is not imported in this notebook; the
                 # annotation is quoted so defining the class does not raise NameError.
                 buffer: "ReplayBuffer",
                 cfg: Namespace):
        """
        Args:
            model: online models keyed by component (``learn`` reads ``model[Q]``).
            target_model: target copies of ``model`` with the same keys.
            opt: optimizers keyed like ``model`` (``learn`` reads ``opt[Q]``).
            buffer: replay buffer providing ``sample_with_next_obs``.
            cfg: supplies load_step, batch_size, num_gradient_steps, gamma,
                buffer_warmup, update_frequency, actor_update_frequency,
                target_update_frequency, tau and omega.
        """
        super().__init__(model, target_model, opt, buffer, cfg)

        self._step = cfg.load_step
        self._batch_size = cfg.batch_size
        self._num_gradient_steps = cfg.num_gradient_steps
        self._gamma = cfg.gamma

        self._buffer_warmup = cfg.buffer_warmup
        self._update_frequency = cfg.update_frequency
        self._actor_update_frequency = cfg.actor_update_frequency
        self._target_update_frequency = cfg.target_update_frequency
        self._tau = cfg.tau
        self._omega = cfg.omega

        def td_loss(models: Tuple[eqx.Module, eqx.Module],
                    obss: np.ndarray,
                    h_states: np.ndarray,
                    acts: np.ndarray,
                    rews: np.ndarray,
                    dones: np.ndarray,
                    next_obss: np.ndarray,
                    next_h_states: np.ndarray,
                    gammas: np.ndarray) -> Tuple[np.ndarray, dict]:
            """Compute the mean squared TD error of the online model against the target model.

            Returns the scalar loss plus a dict of Q-value / TD-error statistics
            (returned as the auxiliary output of the gradient functions below).
            """
            (model, target_model) = models
            q_curr_preds, _ = jax.vmap(model.q_values)(obss, h_states)
            q_next_preds, _ = jax.vmap(target_model.q_values)(next_obss, next_h_states)

            td_errors = jax.vmap(q_learning_td_error)(q_curr_preds, acts, q_next_preds, rews, dones, gammas)
            loss = jnp.mean(td_errors ** 2)
            return loss, {
                LOSS: loss,
                MAX_Q_NEXT: jnp.max(q_next_preds),
                MIN_Q_NEXT: jnp.min(q_next_preds),
                MEAN_Q_NEXT: jnp.mean(q_next_preds),
                MAX_Q_CURR: jnp.max(q_curr_preds),
                MIN_Q_CURR: jnp.min(q_curr_preds),
                MEAN_Q_CURR: jnp.mean(q_curr_preds),
                MAX_TD_ERROR: jnp.max(td_errors),
                MIN_TD_ERROR: jnp.min(td_errors),
            }

        # Gradient functions (w.r.t. the `(model, target_model)` tuple) for each component.
        # TODO(review): policy_loss and temperature_loss still reuse the TD loss;
        # SAC needs its own actor and temperature objectives here.
        value_loss = eqx.filter_grad(has_aux=True)(td_loss)
        policy_loss = eqx.filter_grad(has_aux=True)(td_loss)
        temperature_loss = eqx.filter_grad(has_aux=True)(td_loss)

        def make_update(loss_grad_fn):
            """Wrap a gradient-of-loss function into a single optimizer step on the online model."""
            def update(model: eqx.Module,
                       target_model: eqx.Module,
                       opt: optax.GradientTransformation,
                       opt_state: optax.OptState,
                       obss: np.ndarray,
                       h_states: np.ndarray,
                       acts: np.ndarray,
                       rews: np.ndarray,
                       dones: np.ndarray,
                       next_obss: np.ndarray,
                       next_h_states: np.ndarray,
                       gammas: np.ndarray,
                       omega: float) -> Tuple[eqx.Module, optax.OptState, Tuple[jax.tree_util.PyTreeDef, jax.tree_util.PyTreeDef, jax.tree_util.PyTreeDef], dict]:
                grads, learn_info = loss_grad_fn((model, target_model),
                                                 obss,
                                                 h_states,
                                                 acts,
                                                 rews,
                                                 dones,
                                                 next_obss,
                                                 next_h_states,
                                                 gammas)

                (model_grads, target_model_grads) = grads
                # omega weights the online-model gradients; (1 - omega) weights the
                # target-model gradients. Both are applied to the online model.
                grads = jax.tree_util.tree_map(lambda g, tg: g * omega + tg * (1 - omega),
                                               model_grads,
                                               target_model_grads)

                updates, opt_state = opt.update(grads, opt_state)
                model = eqx.apply_updates(model, updates)
                return model, opt_state, (grads, model_grads, target_model_grads), learn_info
            return update

        self.update_q = eqx.filter_jit(make_update(value_loss))
        self.update_policy = eqx.filter_jit(make_update(policy_loss))
        self.update_temperature = eqx.filter_jit(make_update(temperature_loss))

    def learn(self,
              next_obs: np.ndarray,
              next_h_state: np.ndarray,
              learn_info: dict):
        """Run critic updates on replay samples and accumulate statistics into `learn_info`.

        Updates start after ``buffer_warmup`` steps and then run every
        ``update_frequency`` steps. Each such call performs
        ``num_gradient_steps`` gradient steps on the ``Q`` model.

        Args:
            next_obs: latest observation, forwarded to ``sample_with_next_obs``.
            next_h_state: latest hidden state, forwarded to ``sample_with_next_obs``.
            learn_info: dict updated in place with the mean loss and Q statistics.
        """
        self._step += 1

        if self._step <= self._buffer_warmup or (self._step - 1 - self._buffer_warmup) % self._update_frequency != 0:
            return

        # TODO(review): policy / temperature updates (and `_actor_update_frequency`)
        # are not wired into `learn` yet.
        learn_info[MEAN_LOSS] = 0.
        learn_info[MEAN_Q_CURR] = 0.
        learn_info[MEAN_Q_NEXT] = 0.
        learn_info[MAX_Q_CURR] = -np.inf
        learn_info[MAX_Q_NEXT] = -np.inf
        learn_info[MIN_Q_CURR] = np.inf
        learn_info[MIN_Q_NEXT] = np.inf
        for _ in range(self._num_gradient_steps):
            obss, h_states, acts, rews, dones, next_obss, next_h_states, _, _, _ = self.buffer.sample_with_next_obs(
                batch_size=self._batch_size,
                next_obs=next_obs,
                next_h_state=next_h_state)

            if self.obs_rms:
                # Normalize both ends of the transition so the online and target
                # Q-values are evaluated on inputs with the same scaling.
                obss = self.obs_rms.normalize(obss)
                next_obss = self.obs_rms.normalize(next_obss)
            # Actions are kept as floats: casting continuous actions to integers
            # (as a discrete Q-learner would) truncates them.
            gammas = np.ones(self._batch_size) * self._gamma

            (obss, h_states, acts, rews, dones, next_obss, next_h_states, gammas) = to_jnp(*batch_flatten(obss,
                                                                                                          h_states,
                                                                                                          acts,
                                                                                                          rews,
                                                                                                          dones,
                                                                                                          next_obss,
                                                                                                          next_h_states,
                                                                                                          gammas))
            model, opt_state, _, curr_learn_info = self.update_q(model=self.model[Q],
                                                                 target_model=self.target_model[Q],
                                                                 opt=self.opt[Q],
                                                                 opt_state=self.opt_state[Q],
                                                                 obss=obss,
                                                                 h_states=h_states,
                                                                 acts=acts,
                                                                 rews=rews,
                                                                 dones=dones,
                                                                 next_obss=next_obss,
                                                                 next_h_states=next_h_states,
                                                                 gammas=gammas,
                                                                 omega=self._omega)

            self._model[Q] = model
            self._opt_state[Q] = opt_state

            if self._step % self._target_update_frequency == 0:
                self.polyak_average(model_key=Q)

            learn_info[MEAN_LOSS] += curr_learn_info[LOSS].item() / self._num_gradient_steps
            learn_info[MEAN_Q_CURR] += curr_learn_info[MEAN_Q_CURR].item() / self._num_gradient_steps
            learn_info[MEAN_Q_NEXT] += curr_learn_info[MEAN_Q_NEXT].item() / self._num_gradient_steps
            learn_info[MAX_Q_CURR] = max(learn_info[MAX_Q_CURR], curr_learn_info[MAX_Q_CURR].item())
            learn_info[MAX_Q_NEXT] = max(learn_info[MAX_Q_NEXT], curr_learn_info[MAX_Q_NEXT].item())
            learn_info[MIN_Q_CURR] = min(learn_info[MIN_Q_CURR], curr_learn_info[MIN_Q_CURR].item())
            learn_info[MIN_Q_NEXT] = min(learn_info[MIN_Q_NEXT], curr_learn_info[MIN_Q_NEXT].item())


In [None]:
# Split `model` into its MLP sub-modules (`params`) and everything else (`static`).
# `is_leaf` makes whole MLP modules reach `filter_spec`. Without it, the predicate
# is only applied to the leaves (arrays etc.), so it can never match an MLP,
# assuming MLP is an eqx.Module (a pytree node) — confirm.
params, static = eqx.partition(
    model,
    filter_spec=lambda x: isinstance(x, MLP),
    is_leaf=lambda x: isinstance(x, MLP),
)
print(type(params))
print(params)

In [None]:
# Deliberate stop so "Run All" does not continue into the W&B / training cells below.
# An explicit raise is used because `assert` statements are removed under `python -O`.
raise RuntimeError("Execution stopped intentionally; run the training cells below manually.")

In [None]:
# Start the W&B run for this experiment.
wandb.init(
    project="test_jax_rl",
    group="reacher-sac_test",
)
# Summarise `episodic_return` by its maximum over the run.
wandb.define_metric("episodic_return", summary="max")

In [None]:
cfg_dict = {
    # Environment setup
    "env": "Reacher-v2",
    "seed": 0,
    "render": False,
    
    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,
    
    # Learning hyperparameters
    "max_timesteps": 1000000,
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 64,
    "max_grad_norm": 10.,
    "gamma": 0.99,
    "update_frequency": 4,
    # Weight on the online model's gradients when SAC blends them with the
    # target model's gradients (1.0 = online-model gradients only).
    # SAC.__init__ reads cfg.omega, so the key must be present.
    "omega": 1.0,
    
    # Actor
    "actor_lr": 3e-4,
    "actor_update_frequency": 1,
    
    # Critic
    "critic_lr": 3e-4,
    "target_update_frequency": 1,
    "tau": 0.005, # This is for polyak averaging of target network
    
    # Temperature
    "alpha_lr": 3e-4,
    "init_alpha": 1.0,
    "target_entropy": "auto",
    
    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,
    
    # Evaluation
    "eval_cfg": {
        "max_episodes": 100,
        "seed": 1,
        "render": True,
    }
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
# Record the config on the active run (started by wandb.init earlier in the notebook).
# Assigning `wandb.config = cfg_dict` would only rebind the module attribute to a
# plain dict instead of updating the run's config.
wandb.config.update(cfg_dict)

In [None]:
# Seed NumPy's legacy global RNG; the buffer and env get their own RandomState instances later.
np.random.seed(seed=cfg.seed)

In [None]:
# Build the training environment named in the config.
env = gym.make(id=cfg.env)

In [None]:
# Copy the observation and action shapes from the environment's spaces.
cfg.obs_dim, cfg.act_dim = env.observation_space.shape, env.action_space.shape
# Hard-coded: assumes the configured env has a continuous action space.
cfg.action_space = CONTINUOUS

In [None]:
# Fixed per-step buffer shapes: a one-element hidden state and a one-element reward.
cfg.h_state_dim, cfg.rew_dim = (1,), (1,)

In [None]:
# Inspect the config, now including obs_dim, act_dim, action_space, h_state_dim and rew_dim.
cfg

In [None]:
# NumPy RNGs for the replay buffer and the training env. Both are seeded with
# cfg.seed, so they start out as identical streams.
cfg.buffer_rng, cfg.env_rng = (np.random.RandomState(cfg.seed) for _ in range(2))
# Two distinct JAX keys split from the seed: one for the agent, one for model init.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), 2)
# The evaluation env uses its own seed.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)

In [None]:
# Replay buffer; discrete action spaces store a single action index per step.
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == DISCRETE else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# NOTE(review): `MLPSoftmaxQ`, `QLearning` and `EpsilonGreedyAgent` are not imported in
# this notebook, and `cfg` defines no `init_eps` / `min_eps` / `eps_decay` / `eps_warmup`.
# This cell looks carried over from a discrete Q-learning setup and still needs to be
# switched to the SAC models / learner.
model = MLPSoftmaxQ(obs_dim=cfg.obs_dim,
                    act_dim=cfg.act_dim,
                    hidden_dim=cfg.hidden_dim,
                    num_hidden=cfg.num_hidden,
                    key=cfg.model_key)

# Built with the same key as `model` (presumably so both start with identical weights).
target_model = MLPSoftmaxQ(obs_dim=cfg.obs_dim,
                           act_dim=cfg.act_dim,
                           hidden_dim=cfg.hidden_dim,
                           num_hidden=cfg.num_hidden,
                           key=cfg.model_key)

opt = optax.chain(
    optax.clip_by_global_norm(cfg.max_grad_norm),  # Clip the gradient by its global norm
    optax.scale_by_adam(),  # Adam-normalized update direction
    # Descent step with the configured learning rate; a bare scale(-1.0) would use step size 1.0.
    optax.scale(-cfg.critic_lr),
)

learner = QLearning(model=model,
                    target_model=target_model,
                    opt=opt,
                    buffer=buffer,
                    cfg=cfg)

agent = EpsilonGreedyAgent(model=model,
                           buffer=buffer,
                           learner=learner,
                           init_eps=cfg.init_eps,
                           min_eps=cfg.min_eps,
                           eps_decay=cfg.eps_decay,
                           eps_warmup=cfg.eps_warmup,
                           key=cfg.agent_key)

In [None]:
# Presumably the W&B IPython magic that shows the active run inline — confirm.
%wandb

In [None]:
# Run `interact` from jax_learning.rl_utils with the agent and `cfg`
# (presumably the training loop bounded by cfg.max_timesteps — confirm).
interact(env, agent, cfg)

In [None]:
# Evaluate the agent with eval_cfg (max_episodes=100, seed=1, render=True).
evaluate(env, agent, eval_cfg)

In [None]:
# Deliberate stop so "Run All" does not continue into the buffer-inspection cells below.
# An explicit raise is used because `assert` statements are removed under `python -O`.
raise RuntimeError("Execution stopped intentionally; run the buffer-inspection cells below manually.")

In [None]:
# Show the replay buffer object's repr.
buffer

In [None]:
# Show the buffer's stored observations.
buffer.observations

In [None]:
# Shift next_observations down one row (np.roll wraps the last row to the top),
# so row i holds the next observation stored at row i - 1.
np.roll(buffer.next_observations, shift=1, axis=0)

In [None]:
# Each observation minus the previous row's next observation, with the done flags
# appended as an extra column. Presumably the difference should be zero except
# across episode boundaries — confirm.
np.concatenate(
    (
        buffer.observations - np.roll(buffer.next_observations, 1, axis=0),
        buffer.dones,
    ),
    axis=1,
)

In [None]:
# Draw 3 samples, supplying row 19 of next_observations as the latest next observation
# and the first stored hidden state as the latest hidden state.
buffer.sample_with_next_obs(
    batch_size=3,
    next_obs=buffer.next_observations[19],
    next_h_state=buffer.hidden_states[0],
)