In [1]:
import gym
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial

from jax_learning.agents.rl_agents import RLAgent
from jax_learning.buffers.ram_buffers import TrajectoryNumPyBuffer
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate, random_exploration_generator
from jax_learning.learners.path_consistency import PCL
from jax_learning.models import Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy
from jax_learning.models.value_functions import MLPValue

  import imp


In [2]:
# Weights & Biases logging mode: one of "online", "offline" or "disabled".
wandb_mode = "disabled"
init_wandb(
    project="test_jax_rl",
    group="hopper-pcl_test",
    mode=wandb_mode,
)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchanb[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


In [3]:
# Experiment configuration: every tunable knob for this PCL run lives here.
cfg_dict = {
    # Environment setup
    "env": "Hopper-v2",
    "seed": 0,
    "render": False,

    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,

    # Learning hyperparameters
    "max_timesteps": 1000000,
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 256,
    "max_grad_norm": False,  # falsy disables global-norm gradient clipping
    "gamma": 0.99,
    "update_frequency": 1,

    "exploration_steps": 1000,
    "exploration_strategy": "standard_gaussian",

    # Actor
    "actor_lr": 3e-4,

    # Critic
    "critic_lr": 3e-4,

    # Normalization
    "normalize_obs": False,
    "normalize_value": False,

    # Temperature
    "alpha_lr": 3e-4,
    "init_alpha": 0.2,
    "target_entropy": None,  # "auto" derives -|A| from the action space later

    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,

    # Evaluation
    "evaluation_frequency": 5000,
    "eval_cfg": {
        "num_episodes": 10,
        "seed": 1,
        # NOTE(review): rendering opens a GLFW window during evaluation;
        # set to False on headless machines.
        "render": True,
        "clip_action": True,
        "max_action": 1.,
        "min_action": -1.,
    }
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)

# Record hyperparameters on the active W&B run. Assigning `wandb.config = ...`
# would only rebind the module attribute to a plain dict, so nothing is logged.
wandb.config.update(cfg_dict)

In [4]:
# Seed NumPy's legacy global RNG from the experiment seed.
np.random.seed(seed=cfg.seed)

In [5]:
# Instantiate the training environment from its registered id.
env_id = cfg.env
env = gym.make(env_id)

  logger.warn(
  logger.warn(
  deprecation(
  deprecation(


In [6]:
# Derive observation/action shapes from the environment's spaces.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape

# Optional automatic entropy target: the negative action dimensionality.
# `np.prod` replaces the deprecated `np.product` alias (removed in NumPy 2.0).
if cfg.target_entropy == "auto":
    cfg.target_entropy = -float(np.prod(env.action_space.shape))

# NOTE(review): hard-coded; this notebook assumes a continuous action space.
cfg.action_space = CONTINUOUS

# Build the random-exploration sampler only when exploration steps are enabled.
cfg.random_exploration = None
if getattr(cfg, "exploration_steps", False):
    cfg.random_exploration = random_exploration_generator(
        cfg.exploration_strategy,
        cfg.act_dim,
        # cfg_dict sets no top-level min/max_action, so these defaults apply.
        getattr(cfg, "min_action", -1.),
        getattr(cfg, "max_action", 1.),
    )

In [7]:
# Per-step hidden-state and reward shapes stored in the buffer.
# NOTE(review): the hidden state is presumably a size-1 placeholder since the
# MLP policy keeps no recurrent state — confirm.
cfg.h_state_dim, cfg.rew_dim = (1,), (1,)

In [8]:
# Independent NumPy generators for the buffer and the environment.
# NOTE(review): both use the same seed, so they produce identical streams.
cfg.buffer_rng, cfg.env_rng = (np.random.RandomState(cfg.seed) for _ in range(2))

# JAX keys for the agent and the model parameters, split from one root key.
root_key = jrandom.PRNGKey(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(root_key, num=2)

# Evaluation gets its own seed so eval episodes don't share the training stream.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

In [9]:
# Echo the fully-resolved config (including fields derived in earlier cells).
cfg

Namespace(env='Hopper-v2', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=256, max_grad_norm=False, gamma=0.99, update_frequency=1, exploration_steps=1000, exploration_strategy='standard_gaussian', actor_lr=0.0003, actor_update_frequency=1, critic_lr=0.0003, target_update_frequency=1, tau=0.005, normalize_obs=False, normalize_value=False, alpha_lr=0.0003, init_alpha=0.2, target_entropy=None, hidden_dim=256, num_hidden=2, evaluation_frequency=5000, eval_cfg={'num_episodes': 10, 'seed': 1, 'render': True, 'clip_action': True, 'max_action': 1.0, 'min_action': -1.0}, obs_dim=(11,), act_dim=(3,), action_space='continuous', random_exploration=<function random_exploration_generator.<locals>.sample_standard_gaussian at 0x299c1fac0>, h_state_dim=(1,), rew_dim=(1,), buffer_rng=RandomState(MT19937) at 0x290A78E40, env_rng=RandomState(MT19937) at 0x290A79140, agent_key=DeviceArray([4146024105,  

In [10]:
# Keys naming each component in the model / optimizer dicts.
POLICY = "policy"
V = "v"
TEMPERATURE = "temperature"

# Replay storage. Uses `TrajectoryNumPyBuffer`, the buffer this notebook imports;
# the previous `NextStateNumPyBuffer` was never imported and raised NameError.
# Path consistency learning trains on multi-step sub-trajectories, which is
# presumably why the trajectory buffer was imported.
# NOTE(review): assumes TrajectoryNumPyBuffer accepts the same constructor
# arguments as the old buffer — confirm against jax_learning.buffers.ram_buffers.
buffer = TrajectoryNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == DISCRETE else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# One PRNG key per parameterised network (the second key was previously unused).
policy_key, v_key = jrandom.split(cfg.model_key)
policy = MLPSquashedGaussianPolicy(
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
    key=policy_key,
)

temperature = Temperature(init_alpha=cfg.init_alpha)

# State-value network. Uses the imported `MLPValue`: the old `MLPV` name was
# never imported, and the unused SAC-style `MLPQ` constructor is dropped.
# NOTE(review): assumes MLPValue takes in_dim/out_dim/hidden_dim/num_hidden and
# a `key` like the policy does — confirm its signature.
v = MLPValue(
    in_dim=cfg.obs_dim,
    out_dim=(1,),
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
    key=v_key,
)

model = {
    POLICY: policy,
    TEMPERATURE: temperature,
    V: v,
}


def _make_optimizer(learning_rate, max_grad_norm) -> optax.GradientTransformation:
    """Build an Adam-style optimizer with step size `learning_rate`.

    When `max_grad_norm` is truthy, gradients are first clipped to that global norm.
    """
    transforms = [
        optax.scale_by_adam(),
        optax.scale(-learning_rate),  # negative scale: descend the loss
    ]
    if max_grad_norm:
        transforms.insert(0, optax.clip_by_global_norm(max_grad_norm))
    return optax.chain(*transforms)


opt = {
    V: _make_optimizer(cfg.critic_lr, cfg.max_grad_norm),
    POLICY: _make_optimizer(cfg.actor_lr, cfg.max_grad_norm),
    TEMPERATURE: _make_optimizer(cfg.alpha_lr, cfg.max_grad_norm),
}

learner = PCL(model=model,
              opt=opt,
              buffer=buffer,
              cfg=cfg)

agent = RLAgent(model=model,
                model_key=POLICY,
                buffer=buffer,
                learner=learner,
                key=cfg.agent_key)

In [11]:
# Embed the active W&B run in the notebook output via wandb's IPython magic.
# NOTE(review): with mode="disabled" passed to init_wandb, this presumably shows nothing.
%wandb

In [None]:
# Main training loop: `interact` (from jax_learning.rl_utils) drives `agent` in `env`
# according to `cfg`.
# NOTE(review): presumably runs for cfg.max_timesteps with evaluation every
# cfg.evaluation_frequency steps — confirm in rl_utils.
interact(env, agent, cfg)

If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


Creating window glfw


In [None]:
# Mark the W&B run as finished so any buffered logs are flushed.
wandb.finish()