In [None]:
import gym
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate
from jax_learning.learners.soft_actor_critic import SAC
from jax_learning.models import Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy
from jax_learning.models.q_functions import MLPQ, MultiQ

In [None]:
# Start a Weights & Biases run for this experiment under the given project
# and run group.
wandb.init(project="test_jax_rl", group="hopper-sac_test")
# Ask W&B to report the maximum logged "episodic_return" in the run summary.
wandb.define_metric("episodic_return", summary="max")

In [None]:
# Experiment configuration: every tunable value for this run lives in this
# single dictionary so it can be logged and inspected in one place.
cfg_dict = {
    # Environment setup
    "env": "Hopper-v2",
    "seed": 0,
    "render": False,

    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,

    # Learning hyperparameters
    "max_timesteps": 1000000,
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 128,
    "max_grad_norm": 10.,
    "gamma": 0.99,
    "update_frequency": 1,

    # Actor
    "actor_lr": 3e-4,
    "actor_update_frequency": 1,

    # Critic
    "critic_lr": 3e-4,
    "target_update_frequency": 1,
    "tau": 0.01,  # This is for polyak averaging of target network

    # Normalization
    "normalize_obs": False,
    "normalize_value": False,

    # Temperature
    "alpha_lr": 3e-4,
    "init_alpha": 1.0,
    "target_entropy": "auto",

    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,

    # Evaluation
    "eval_cfg": {
        "max_episodes": 100,
        "seed": 1,
        "render": True,
    },
}

# Attribute-style views of the configuration. Later cells attach derived
# fields (dimensions, RNGs, PRNG keys) to these namespaces; cfg_dict itself
# keeps only the plain hyperparameters defined above.
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)

# Record the hyperparameters on the active W&B run. Assigning
# `wandb.config = cfg_dict` would only rebind the module attribute and leave
# the run's config object untouched, so update the existing config instead.
# allow_val_change=True keeps this cell re-runnable after editing a value.
wandb.config.update(cfg_dict, allow_val_change=True)

In [None]:
# Seed NumPy's global random number generator from the configured seed.
np.random.seed(cfg.seed)

In [None]:
# Create the Gym environment named in the config.
# NOTE(review): the environment is not seeded in this cell — confirm whether
# interact()/evaluate() seed it from cfg.env_rng.
env = gym.make(cfg.env)

In [None]:
# Derive environment-dependent settings from the Gym observation/action spaces.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape
if cfg.target_entropy == "auto":
    # "auto" resolves to minus the number of action dimensions.
    # np.prod is used instead of the np.product alias, which was deprecated
    # and then removed in NumPy 2.0.
    cfg.target_entropy = -float(np.prod(env.action_space.shape))
cfg.action_space = CONTINUOUS

In [None]:
# Shapes of the hidden-state and reward entries; both are passed to the
# replay buffer constructor below.
# NOTE(review): h_state_dim is presumably a placeholder hidden-state shape
# for a non-recurrent policy — confirm against NextStateNumPyBuffer.
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [None]:
# Display the populated configuration namespace as this cell's output.
cfg

In [None]:
# JAX PRNG keys: one root key derived from the training seed, split into a
# key for the agent and a key for model initialisation.
root_key = jrandom.PRNGKey(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(root_key, num=2)

# NumPy random streams: buffer and training environment use the training
# seed, while the evaluation environment uses its own seed.
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)

In [None]:
# Keys used to address each component in the model and optimizer dictionaries.
POLICY, Q, TEMPERATURE = "policy", "q", "temperature"

# A discrete action is stored with shape (1,); otherwise the environment's
# action shape is used as-is.
if cfg.action_space == DISCRETE:
    buffer_act_dim = (1,)
else:
    buffer_act_dim = cfg.act_dim

# Replay buffer holding transitions, sampled with its own NumPy RNG.
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=buffer_act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# Separate PRNG keys for initialising the policy and the critics.
policy_key, q_key = jrandom.split(cfg.model_key)

# Actor network.
policy = MLPSquashedGaussianPolicy(
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
    key=policy_key,
)

# Entropy temperature, starting from the configured alpha.
temperature = Temperature(init_alpha=cfg.init_alpha)

# Each critic's input width is the sum of the observation and action sizes,
# and it outputs a single value.
q_in_dim = (cfg.obs_dim[0] + cfg.act_dim[0],)
q_constructor = partial(
    MLPQ,
    in_dim=q_in_dim,
    out_dim=(1,),
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
)

# The online and target critics are built with identical arguments, including
# the same key.
# NOTE(review): presumably this makes the target start from the same
# parameters as the online critics — confirm against MultiQ.
multi_q_kwargs = dict(num_qs=2, key=q_key)
q = MultiQ(q_constructor, **multi_q_kwargs)
target_q = MultiQ(q_constructor, **multi_q_kwargs)

# Trainable components and their target counterparts, keyed by component name.
model = {POLICY: policy, TEMPERATURE: temperature, Q: q}
target_model = {Q: target_q}

def _adam_transforms(learning_rate):
    """Build the gradient-transformation list for one component's optimizer.

    The list is Adam scaling followed by scaling with ``-learning_rate``.
    When ``cfg.max_grad_norm`` is truthy, global-norm clipping is placed in
    front of both.
    """
    transforms = [optax.scale_by_adam(), optax.scale(-learning_rate)]
    if cfg.max_grad_norm:
        transforms = [optax.clip_by_global_norm(cfg.max_grad_norm), *transforms]
    return transforms


# One transformation list per trainable component, each with its own
# learning rate.
q_opt_transforms = _adam_transforms(cfg.critic_lr)
policy_opt_transforms = _adam_transforms(cfg.actor_lr)
temperature_opt_transforms = _adam_transforms(cfg.alpha_lr)

# Chain each list into a single optimizer, keyed like the model dictionary.
opt = {
    Q: optax.chain(*q_opt_transforms),
    POLICY: optax.chain(*policy_opt_transforms),
    TEMPERATURE: optax.chain(*temperature_opt_transforms),
}

# SAC learner: receives the models, target models, optimizers, replay buffer
# and the full configuration.
learner = SAC(
    model=model,
    target_model=target_model,
    opt=opt,
    buffer=buffer,
    cfg=cfg,
)

# Agent wrapper: uses the POLICY entry of the model dictionary, shares the
# buffer with the learner, and gets its own PRNG key.
agent = RLAgent(
    model=model,
    model_key=POLICY,
    buffer=buffer,
    learner=learner,
    key=cfg.agent_key,
)

In [None]:
# NOTE(review): the %wandb line magic presumably embeds the active W&B run's
# dashboard in the notebook output — confirm against the wandb Jupyter docs.
%wandb

In [None]:
# Run the agent in the environment using the training configuration.
# NOTE(review): presumably this is the full training loop (up to
# cfg.max_timesteps steps) — confirm in jax_learning.rl_utils.interact.
interact(env, agent, cfg)

In [None]:
# Evaluate the agent with the evaluation configuration (eval_cfg), reusing
# the same environment instance that was passed to interact().
evaluate(env, agent, eval_cfg)

In [None]:
# Mark the W&B run as finished.
wandb.finish()