In [1]:
import gym
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers.ram_buffers import TrajectoryNumPyBuffer
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate, random_exploration_generator
from jax_learning.learners.path_consistency import PCL
from jax_learning.models import Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy, MLPGaussianPolicy
from jax_learning.models.value_functions import MLPValue

  import imp


In [2]:
# Initialise Weights & Biases logging through the project helper
# (presumably wraps wandb.init — confirm against jax_learning.common).
# `mode` can be "online", "offline" or "disabled".
init_wandb(project="test_jax_rl", group="hopper-pcl_test", mode="disabled")

  return LooseVersion(v) >= LooseVersion(check)


In [3]:
# Every tunable of the experiment lives in this one mapping. Key order is kept
# stable because it fixes the field order of the `cfg` namespace built from it.
cfg_dict = dict(
    # -- environment --
    env="Hopper-v2",
    seed=0,
    render=False,
    # -- progress / logging --
    load_step=0,
    log_interval=5000,
    # -- learning --
    max_timesteps=1_000_000,
    buffer_size=1_000_000,
    buffer_warmup=1000,
    num_gradient_steps=1,
    batch_size=50,
    max_grad_norm=False,  # falsy value => no gradient clipping is added to the optimizers
    gamma=0.99,
    update_frequency=1,
    # -- initial exploration --
    exploration_steps=1000,
    exploration_strategy="standard_gaussian",
    # -- actor / critic step sizes --
    actor_lr=3e-4,
    critic_lr=3e-4,
    horizon_length=20,
    # -- normalization switches --
    normalize_obs=False,
    normalize_value=False,
    # -- temperature --
    alpha_lr=3e-4,
    init_alpha=0.2,
    target_entropy=None,  # the string "auto" is resolved from the action shape in a later cell
    # -- network architecture --
    hidden_dim=256,
    num_hidden=2,
    # -- evaluation --
    evaluation_frequency=5000,
    eval_cfg=dict(
        num_episodes=10,
        seed=1,
        render=False,
        clip_action=True,
        max_action=1.0,
        min_action=-1.0,
    ),
)

# Attribute-style views of the settings: one for training, one for evaluation.
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
# Record the hyperparameters on the active W&B run. Rebinding `wandb.config`
# to a plain dict only replaces the module attribute and leaves the run's own
# config untouched; `update` writes the values into it instead.
# `allow_val_change=True` lets this cell be re-run after editing cfg_dict.
# NOTE(review): assumes init_wandb() has already called wandb.init() — confirm.
wandb.config.update(cfg_dict, allow_val_change=True)

In [4]:
# Seed NumPy's global RNG with the experiment seed so that any code drawing
# from np.random.* is repeatable across runs.
np.random.seed(cfg.seed)

In [5]:
# Instantiate the Gym environment named by cfg.env.
env = gym.make(cfg.env)

  logger.warn(
  logger.warn(
  deprecation(
  deprecation(


In [6]:
# Copy the environment's observation/action shapes onto the config so that the
# buffer and the networks can be sized from `cfg` alone.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape

# "auto" resolves the target entropy to minus the number of action entries.
# np.prod is used because np.product was deprecated in NumPy 1.25 and removed
# in NumPy 2.0.
if cfg.target_entropy == "auto":
    cfg.target_entropy = -float(np.prod(env.action_space.shape))
cfg.action_space = CONTINUOUS

# Optional random-action generator; only built when cfg.exploration_steps is
# set and truthy. The action bounds fall back to [-1, 1] when `cfg` does not
# define min_action / max_action.
cfg.random_exploration = None
if getattr(cfg, "exploration_steps", False):
    cfg.random_exploration = random_exploration_generator(
        cfg.exploration_strategy,
        cfg.act_dim,
        getattr(cfg, "min_action", -1.),
        getattr(cfg, "max_action", 1.),
    )

In [7]:
# Both the h_state entry (presumably the hidden state — confirm) and the reward
# entry use shape (1,); these shapes are passed to the trajectory buffer.
cfg.h_state_dim = cfg.rew_dim = (1,)

In [8]:
# Two separate NumPy generators, each seeded with the training seed: one is
# handed to the buffer, the other is named for the environment.
cfg.buffer_rng, cfg.env_rng = (np.random.RandomState(cfg.seed) for _ in range(2))

# Split one JAX PRNG key into a key for the agent and a key for model init.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), 2)

# Evaluation gets its own generator, seeded with the evaluation seed, and its
# settings are attached to the main config.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

In [9]:
# Show the fully populated config as the cell output for a quick sanity check.
cfg

Namespace(env='Hopper-v2', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=50, max_grad_norm=False, gamma=0.99, update_frequency=1, exploration_steps=1000, exploration_strategy='standard_gaussian', actor_lr=0.0003, critic_lr=0.0003, horizon_length=20, normalize_obs=False, normalize_value=False, alpha_lr=0.0003, init_alpha=0.2, target_entropy=None, hidden_dim=256, num_hidden=2, evaluation_frequency=5000, eval_cfg={'num_episodes': 10, 'seed': 1, 'render': False, 'clip_action': True, 'max_action': 1.0, 'min_action': -1.0}, obs_dim=(11,), act_dim=(3,), action_space='continuous', random_exploration=<function random_exploration_generator.<locals>.sample_standard_gaussian at 0x175559ea0>, h_state_dim=(1,), rew_dim=(1,), buffer_rng=RandomState(MT19937) at 0x14FD0CA40, env_rng=RandomState(MT19937) at 0x14FD0CD40, agent_key=DeviceArray([4146024105,  967050713], dtype=uint32), model_key=DeviceAr

In [10]:
# String keys used to index the `model` and `opt` dictionaries in this cell.
POLICY = "policy"
Q = "q"  # NOTE(review): not referenced anywhere in the cells visible here
V = "v"
TEMPERATURE = "temperature"

# Trajectory storage shared by the learner and the agent. A discrete action
# space is stored with shape (1,); otherwise the environment's action shape
# is used.
buffer = TrajectoryNumPyBuffer(
    rng=cfg.buffer_rng,
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim if cfg.action_space != DISCRETE else (1,),
    rew_dim=cfg.rew_dim,
    h_state_dim=cfg.h_state_dim,
)

# Derive separate initialisation keys for the policy and the value network.
policy_key, v_key = jrandom.split(cfg.model_key)

# Gaussian MLP policy. MLPSquashedGaussianPolicy (imported at the top of the
# notebook) takes the same constructor arguments and was the alternative
# previously kept here as commented-out code.
policy = MLPGaussianPolicy(
    key=policy_key,
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
)

# Temperature parameter, starting from the configured initial alpha.
temperature = Temperature(init_alpha=cfg.init_alpha)

# Value network: takes an observation and has an output of shape (1,).
v = MLPValue(
    key=v_key,
    in_dim=cfg.obs_dim,
    out_dim=(1,),
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
)

# Bundle of everything trainable, keyed by the name constants above.
model = {POLICY: policy, TEMPERATURE: temperature, V: v}

def _adam_transforms(learning_rate, max_grad_norm):
    """Build the list of optax transforms for one Adam-style optimizer.

    Args:
        learning_rate: step size; applied as a final scale by ``-learning_rate``.
        max_grad_norm: when truthy, gradients are first clipped to this global
            norm; when falsy, no clipping transform is added.

    Returns:
        A list of optax gradient transformations, in application order.
    """
    transforms = []
    if max_grad_norm:
        transforms.append(optax.clip_by_global_norm(max_grad_norm))
    transforms.extend([optax.scale_by_adam(), optax.scale(-learning_rate)])
    return transforms


# One transform list per trainable component, each with its own learning rate.
v_opt_transforms = _adam_transforms(cfg.critic_lr, cfg.max_grad_norm)
policy_opt_transforms = _adam_transforms(cfg.actor_lr, cfg.max_grad_norm)
temperature_opt_transforms = _adam_transforms(cfg.alpha_lr, cfg.max_grad_norm)

# Optimizers keyed by the same names as the entries of `model`.
opt = {
    V: optax.chain(*v_opt_transforms),
    POLICY: optax.chain(*policy_opt_transforms),
    TEMPERATURE: optax.chain(*temperature_opt_transforms),
}

# The PCL learner receives the models, their optimizers, the shared buffer and
# the full config.
learner = PCL(
    cfg=cfg,
    model=model,
    opt=opt,
    buffer=buffer,
)

# The agent uses the entry of `model` stored under POLICY, shares the buffer
# with the learner, and gets its own PRNG key.
agent = RLAgent(
    key=cfg.agent_key,
    model=model,
    model_key=POLICY,
    buffer=buffer,
    learner=learner,
)

In [11]:
# wandb's IPython line magic; presumably renders the current run inline in the
# notebook — confirm against the W&B Jupyter documentation.
%wandb

In [12]:
# Run the agent in the environment under the settings in `cfg`.
# NOTE(review): judging by the captured output, this prints episode returns
# and periodic evaluation results; with max_timesteps=1000000 it is presumably
# long-running, and the saved output ends in a manual KeyboardInterrupt.
interact(env, agent, cfg)

Current return (episode: 74, is finished: False) with length 98: 117.12971820943532
(DeviceArray([ 0.42114586,  0.9355575 , -0.04016141], dtype=float32), DeviceArray([0.42974144, 0.24413195, 0.3877744 ], dtype=float32))
Evaluation:
Mean return: 103.91280773815738, Mean length: 80.9
Current return (episode: 97, is finished: False) with length 196: 307.73655747663776
(DeviceArray([-0.11082284, -0.10101343, -0.11712176], dtype=float32), DeviceArray([0.20232834, 0.4027467 , 0.19349763], dtype=float32))
Evaluation:
Mean return: 378.8388659294139, Mean length: 192.9
Current return (episode: 122, is finished: False) with length 78: 115.30010039124826
(DeviceArray([-0.5212943 , -0.08940071,  0.23862752], dtype=float32), DeviceArray([0.27584276, 0.30227402, 0.42868024], dtype=float32))
Evaluation:
Mean return: 371.5436614242149, Mean length: 189.9
Current return (episode: 147, is finished: False) with length 130: 215.9397846213762
(DeviceArray([-0.00227779,  0.03066964, -0.22755392], dtype=floa

KeyboardInterrupt: 

In [None]:
# Mark the active W&B run as finished once training is done or interrupted.
wandb.finish()