In [1]:
import gym
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate, random_exploration_generator
from jax_learning.learners.soft_actor_critic import SAC
from jax_learning.models import Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy
from jax_learning.models.q_functions import MLPQ, MultiQ

  import imp


In [2]:
init_wandb(project="test_jax_rl", group="hopper-sac_test", mode="online")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchan[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


In [3]:
cfg_dict = {
    # Environment setup
    "env": "Hopper-v2",
    "seed": 0,
    "render": False,
    
    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,
    
    # Learning hyperparameters
    "max_timesteps": 1000000,
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 256,
    "max_grad_norm": False,
    "gamma": 0.99,
    "update_frequency": 1,
    
    "exploration_steps": 1000,
    "exploration_strategy": "standard_gaussian",
    
    # Actor
    "actor_lr": 3e-4,
    "actor_update_frequency": 1,
    
    # Critic
    "critic_lr": 3e-4,
    "target_update_frequency": 1,
    "tau": 0.005, # This is for polyak averaging of target network
    
    # Normalization
    "normalize_obs": False,
    "normalize_value": False,
    
    # Temperature
    "alpha_lr": 3e-4,
    "init_alpha": 0.2,
    "target_entropy": None,
    
    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,
    
    # Evaluation
    "evaluation_frequency": 5000,
    "eval_cfg": {
        "num_episodes": 10,
        "seed": 1,
        "render": True,
        "clip_action": True,
        "max_action": 1.,
        "min_action": -1.,
    }
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
wandb.config = cfg_dict

In [4]:
np.random.seed(cfg.seed)

In [5]:
env = gym.make(cfg.env)

  logger.warn(
  logger.warn(
objc[15346]: Class GLFWWindowDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x1141ec7b0) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x118f7d700). One of the two will be used. Which one is undefined.
objc[15346]: Class GLFWApplicationDelegate is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x1141ec788) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x118f7d778). One of the two will be used. Which one is undefined.
objc[15346]: Class GLFWContentView is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x1141ec800) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x118f7d7a0). One of the two will be used. Which one is undefined.
objc[15346]: Class GLFWWindow is implemented in both /usr/local/Cellar/glfw/3.3.7/lib/libglfw.3.3.dylib (0x1141ec878) and /Users/chanb/.mujoco/mujoco210/bin/libglfw.3.dylib (0x118f7d818). One of the two will be used. Whic

In [6]:
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape
if cfg.target_entropy == "auto":
    cfg.target_entropy = -float(np.product(env.action_space.shape))
cfg.action_space = CONTINUOUS

cfg.random_exploration = None
if getattr(cfg ,"exploration_steps", False):
    cfg.random_exploration = random_exploration_generator(cfg.exploration_strategy,
                                                          cfg.act_dim,
                                                          getattr(cfg, "min_action", -1.),
                                                          getattr(cfg, "max_action", 1.))

In [7]:
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [None]:
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

In [8]:
cfg

Namespace(env='Hopper-v2', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=256, max_grad_norm=False, gamma=0.99, update_frequency=1, exploration_steps=1000, exploration_strategy='standard_gaussian', actor_lr=0.0003, actor_update_frequency=1, critic_lr=0.0003, target_update_frequency=1, tau=0.005, normalize_obs=False, normalize_value=False, alpha_lr=0.0003, init_alpha=0.2, target_entropy=None, hidden_dim=256, num_hidden=2, eval_cfg={'max_episodes': 100, 'seed': 1, 'render': True}, obs_dim=(11,), act_dim=(3,), action_space='continuous', random_exploration=<function random_exploration_generator.<locals>.sample_standard_gaussian at 0x11a919750>, h_state_dim=(1,), rew_dim=(1,))

In [10]:
POLICY = "policy"
Q = "q"
TEMPERATURE = "temperature"

buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == DISCRETE else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

policy_key, q_key = jrandom.split(cfg.model_key)
policy = MLPSquashedGaussianPolicy(
        obs_dim=cfg.obs_dim,
        act_dim=cfg.act_dim,
        hidden_dim=cfg.hidden_dim,
        num_hidden=cfg.num_hidden,
        key=policy_key,
)

temperature = Temperature(init_alpha=cfg.init_alpha)

q_constructor = partial(MLPQ,
                        in_dim=(cfg.obs_dim[0] + cfg.act_dim[0],),
                        out_dim=(1,),
                        hidden_dim=cfg.hidden_dim,
                        num_hidden=cfg.num_hidden)

q = MultiQ(q_constructor,
           num_qs=2,
           key=q_key)

target_q = MultiQ(q_constructor,
                  num_qs=2,
                  key=q_key)

model = {
    POLICY: policy,
    TEMPERATURE: temperature,
    Q: q,
}

target_model = {
    Q: target_q
}

q_opt_transforms = [
    optax.scale_by_adam(),
    optax.scale(-cfg.critic_lr)
]

policy_opt_transforms = [
    optax.scale_by_adam(),
    optax.scale(-cfg.actor_lr)
]

temperature_opt_transforms = [
    optax.scale_by_adam(),
    optax.scale(-cfg.alpha_lr)
]

if cfg.max_grad_norm:
    q_opt_transforms.insert(0, optax.clip_by_global_norm(cfg.max_grad_norm))
    policy_opt_transforms.insert(0, optax.clip_by_global_norm(cfg.max_grad_norm))
    temperature_opt_transforms.insert(0, optax.clip_by_global_norm(cfg.max_grad_norm))
opt = {
    Q: optax.chain(*q_opt_transforms),
    POLICY: optax.chain(*policy_opt_transforms),
    TEMPERATURE: optax.chain(*temperature_opt_transforms)
}

learner = SAC(model=model,
              target_model=target_model,
              opt=opt,
              buffer=buffer,
              cfg=cfg)

agent = RLAgent(model=model,
                model_key=POLICY,
                buffer=buffer,
                learner=learner,
                key=cfg.agent_key)

In [11]:
%wandb

In [12]:
interact(env, agent, cfg)

KeyboardInterrupt: 

In [None]:
wandb.finish()