In [None]:
import gym
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial

from jax_learning.agents.rl_agents import EpsilonGreedyAgent, RLAgent
from jax_learning.buffers.ram_buffers import TrajectoryNumPyBuffer
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE, CONTINUOUS
from jax_learning.rl_utils import interact, evaluate, random_exploration_generator
from jax_learning.learners.path_consistency import PCL
from jax_learning.models import Temperature
from jax_learning.models.policies import MLPSquashedGaussianPolicy, MLPGaussianPolicy
from jax_learning.models.value_functions import MLPValue

In [None]:
# Weights & Biases logging for this run. `mode` accepts "online", "offline"
# or "disabled"; it is set to "disabled" here.
init_wandb(
    project="test_jax_rl",
    group="hopper-pcl_test",
    mode="disabled",
)

In [None]:
# Every tunable for the run lives in this one mapping. Later cells wrap it in
# an argparse.Namespace and attach derived fields (shapes, RNGs, ...) to it.
cfg_dict = dict(
    # --- Environment setup ---
    env="Hopper-v2",
    seed=0,
    render=False,
    # --- Experiment progress ---
    load_step=0,
    log_interval=5_000,
    # --- Learning hyperparameters ---
    max_timesteps=1_000_000,
    buffer_size=1_000_000,
    buffer_warmup=1_000,
    num_gradient_steps=1,
    batch_size=32,
    # A falsy value skips gradient clipping when the optimizers are built.
    max_grad_norm=False,
    gamma=0.99,
    update_frequency=1,
    # A falsy value means no random-exploration generator is created.
    exploration_steps=0,
    exploration_strategy="standard_gaussian",
    # --- Actor ---
    actor_lr=3e-4,
    # --- Critic ---
    critic_lr=3e-4,
    # --- Rollout ---
    horizon_length=20,
    # --- Epsilon greedy hyperparameters ---
    init_eps=1.0,
    min_eps=0.02,
    eps_decay=0.9,
    eps_warmup=1_000,
    # --- Normalization ---
    normalize_obs=False,
    normalize_value=False,
    # --- Temperature ---
    alpha_lr=3e-4,
    init_alpha=1.0,
    # Set to "auto" to have it derived from the action dimensionality instead.
    target_entropy=-3,
    # --- Model architecture ---
    hidden_dim=256,
    num_hidden=2,
    # --- Evaluation ---
    evaluation_frequency=5_000,
    eval_cfg=dict(
        num_episodes=10,
        seed=1,
        render=False,
        clip_action=True,
        max_action=1.0,
        min_action=-1.0,
    ),
)
# Attribute-style views over the settings: `cfg` for training and `eval_cfg`
# for the nested evaluation settings (cfg.eval_cfg itself stays a plain dict).
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg_dict["eval_cfg"])

# NOTE(review): this rebinds the module attribute `wandb.config` to the raw
# dict rather than calling `wandb.config.update(cfg_dict)` — confirm the
# hyperparameters actually reach the run when logging is enabled.
wandb.config = cfg_dict

In [None]:
# Seed NumPy's legacy global RNG with the run seed so that any code relying on
# the module-level `np.random.*` functions is reproducible.
np.random.seed(cfg.seed)

In [None]:
# Instantiate the training environment named in the config.
# NOTE(review): the environment is not seeded in this cell — confirm that
# interact() seeds it (e.g. from cfg.env_rng) if exact reproducibility matters.
env = gym.make(cfg.env)

In [None]:
# Record the environment's observation/action shapes on the config so the
# buffer and the models can be sized from them.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = env.action_space.shape

# "auto" replaces the target entropy with minus the number of action
# dimensions. np.prod is used because its alias np.product was deprecated in
# NumPy 1.25 and removed in NumPy 2.0.
if cfg.target_entropy == "auto":
    cfg.target_entropy = -float(np.prod(env.action_space.shape))
cfg.action_space = CONTINUOUS

# A random-exploration generator is only built when a truthy number of
# exploration steps is configured; otherwise the field stays None.
cfg.random_exploration = None
if getattr(cfg, "exploration_steps", False):
    cfg.random_exploration = random_exploration_generator(
        cfg.exploration_strategy,
        cfg.act_dim,
        # Fall back to [-1, 1] when the config carries no explicit action bounds.
        getattr(cfg, "min_action", -1.0),
        getattr(cfg, "max_action", 1.0),
    )

In [None]:
# Both per-step fields are given shape (1,); they are passed to
# TrajectoryNumPyBuffer as `h_state_dim` and `rew_dim` when the buffer is built.
# NOTE(review): h_state presumably means a (recurrent) hidden state — confirm.
cfg.h_state_dim = cfg.rew_dim = (1,)

In [None]:
# JAX side: derive the agent key and the model key from one root key.
_root_key = jrandom.PRNGKey(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(_root_key, num=2)

# NumPy side: the buffer and the training environment each get their own
# RandomState, both seeded with the run seed.
cfg.env_rng = np.random.RandomState(cfg.seed)
cfg.buffer_rng = np.random.RandomState(cfg.seed)

# Evaluation uses its own seed and is attached to the training config.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

In [None]:
# Bare expression: shows the fully assembled configuration as the cell output.
cfg

In [None]:
# Keys used to index the `model` and `opt` dictionaries.
POLICY, Q, V, TEMPERATURE = "policy", "q", "v", "temperature"

# ---------------------------------------------------------------------------
# Trajectory buffer
# ---------------------------------------------------------------------------
# Discrete actions are stored with shape (1,); continuous ones keep the
# environment's action shape.
if cfg.action_space == DISCRETE:
    buffer_act_dim = (1,)
else:
    buffer_act_dim = cfg.act_dim

buffer = TrajectoryNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=buffer_act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# ---------------------------------------------------------------------------
# Models: Gaussian policy, temperature, state-value function
# ---------------------------------------------------------------------------
policy_key, v_key = jrandom.split(cfg.model_key)

# Both MLPs share the same hidden width and depth.
mlp_shape = dict(hidden_dim=cfg.hidden_dim, num_hidden=cfg.num_hidden)

# Alternative policy kept from the original notebook: MLPSquashedGaussianPolicy
# called with exactly the same keyword arguments as MLPGaussianPolicy below.
policy = MLPGaussianPolicy(
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    key=policy_key,
    **mlp_shape,
)

temperature = Temperature(init_alpha=cfg.init_alpha)

v = MLPValue(
    in_dim=cfg.obs_dim,
    out_dim=(1,),
    key=v_key,
    **mlp_shape,
)

model = {POLICY: policy, TEMPERATURE: temperature, V: v}


# ---------------------------------------------------------------------------
# Optimizers
# ---------------------------------------------------------------------------
def _adam_transforms(learning_rate):
    """Return the optax transform list for one Adam optimizer.

    The list is [scale_by_adam, scale(-learning_rate)], preceded by a
    global-norm clip when cfg.max_grad_norm is truthy.
    """
    transforms = []
    if cfg.max_grad_norm:
        transforms.append(optax.clip_by_global_norm(cfg.max_grad_norm))
    transforms.extend([optax.scale_by_adam(), optax.scale(-learning_rate)])
    return transforms


v_opt_transforms = _adam_transforms(cfg.critic_lr)
policy_opt_transforms = _adam_transforms(cfg.actor_lr)
temperature_opt_transforms = _adam_transforms(cfg.alpha_lr)

opt = {
    V: optax.chain(*v_opt_transforms),
    POLICY: optax.chain(*policy_opt_transforms),
    TEMPERATURE: optax.chain(*temperature_opt_transforms),
}

# ---------------------------------------------------------------------------
# Learner and acting agent
# ---------------------------------------------------------------------------
learner = PCL(model=model, opt=opt, buffer=buffer, cfg=cfg)

# Alternative agent kept from the original notebook:
#   RLAgent(model=model, model_key=POLICY, buffer=buffer, learner=learner,
#           key=cfg.agent_key)
epsilon_schedule = dict(
    init_eps=cfg.init_eps,
    min_eps=cfg.min_eps,
    eps_decay=cfg.eps_decay,
    eps_warmup=cfg.eps_warmup,
)

agent = EpsilonGreedyAgent(
    model=model,
    model_key=POLICY,
    buffer=buffer,
    learner=learner,
    action_space=CONTINUOUS,
    action_dim=cfg.act_dim[0],
    key=cfg.agent_key,
    **epsilon_schedule,
)

In [None]:
# IPython line magic registered by wandb.
# NOTE(review): presumably embeds the current run's dashboard in the notebook;
# with mode="disabled" in the init_wandb call it may render nothing — confirm.
%wandb

In [None]:
# Hand the environment, agent and config to jax_learning's interact().
# NOTE(review): presumably this runs the full training loop (cfg.max_timesteps
# steps, with periodic evaluation via cfg.evaluation_cfg) — confirm in rl_utils.
interact(env, agent, cfg)

In [None]:
# Close out the wandb run once training has returned.
wandb.finish()