In [None]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import sys
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence, Tuple, Dict

from jax_learning.agents.rl_agents import OfflineOnlineRLAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE, CONTINUOUS, OFFLINE, ONLINE
from jax_learning.learners.behavioural_cloning import BC
from jax_learning.learners.reinforce import REINFORCE
from jax_learning.models.policies import MLPGaussianPolicy
from jax_learning.rl_utils import interact, evaluate

In [None]:
# Configure Weights & Biases for this notebook through the project helper.
# NOTE(review): presumably forwards these arguments to `wandb.init`; with
# mode="disabled" no run data should be uploaded -- confirm in jax_learning.common.
init_wandb(
    project="test_jax_rl",
    group="hopper-bc_test",
    mode="disabled",
)

In [None]:
# Single source of truth for the experiment. Nested dicts ("pretrain", "bc",
# "reinforce", "eval_cfg") are per-component sections; later cells wrap them
# in Namespace objects for attribute access.
cfg_dict = {
    # --- Environment setup ---
    "env": "Hopper-v2",
    "seed": 0,
    "render": False,
    "clip_action": True,
    "max_action": 1.0,
    "min_action": -1.0,
    # --- Experiment progress ---
    "log_interval": 50_000,
    "checkpoint_frequency": 5_000,
    "save_path": None,
    "load_path": None,
    # --- Learning hyperparameters ---
    "max_timesteps": 1_000_000,
    # Settings for the pretraining phase.
    "pretrain": {
        "num_updates": 1_000,
        "checkpoint_frequency": 1_000,
        "log_interval": 1_000,
        "evaluation_frequency": 1_000,
    },
    # Behavioural-cloning learner settings.
    "bc": {
        "lr": 3e-4,
        "batch_size": 512,
        "max_grad_norm": 10.0,
        "expert_buffer_path": "../data/hopper_medium_expert-v2.pkl",
    },
    # REINFORCE learner settings.
    "reinforce": {
        "update_frequency": 10_000,
        "lr": 1e-5,
        "max_grad_norm": 10.0,
        "gamma": 0.99,
    },
    # --- Normalization ---
    "normalize_obs": False,
    "normalize_value": False,
    # --- Model architecture ---
    "hidden_dim": 128,
    "num_hidden": 2,
    # --- Evaluation ---
    "evaluation_frequency": 50_000,
    "eval_cfg": {
        "num_episodes": 50,
        "seed": 1,
        "render": True,
        "clip_action": True,
        "max_action": 1.0,
        "min_action": -1.0,
    },
}

# Attribute-style views: `cfg` over the top level, `eval_cfg` over the
# evaluation section (the very same dict object that `cfg.eval_cfg` holds).
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg_dict["eval_cfg"])

# NOTE(review): this rebinds the module attribute `wandb.config` instead of
# updating the active run's config -- confirm the values actually reach the run.
wandb.config = cfg_dict

In [None]:
# Seed NumPy's global (legacy) random state from the configured experiment seed.
np.random.seed(cfg.seed)

In [None]:
# Instantiate the Gym environment named in the config ("Hopper-v2" by default).
env = gym.make(cfg.env)

In [None]:
# Copy the environment's observation/action shapes onto the config so the
# buffers and the policy built later can be sized from `cfg` alone.
cfg.obs_dim, cfg.act_dim = env.observation_space.shape, env.action_space.shape
# This notebook treats the action space as continuous.
cfg.action_space = CONTINUOUS

In [None]:
# Hidden-state and reward entries are both given shape (1,); these are passed
# to the replay buffers as `h_state_dim` and `rew_dim`.
cfg.h_state_dim = cfg.rew_dim = (1,)

In [None]:
def _as_namespace(section):
    """Return ``section`` as a ``Namespace``, converting only when needed.

    The original cell called ``Namespace(**cfg.bc)`` unconditionally, so running
    the cell a second time raised ``TypeError`` (a ``Namespace`` is not a
    mapping). Passing an existing ``Namespace`` through unchanged makes the
    cell safe to re-run.
    """
    if isinstance(section, Namespace):
        return section
    return Namespace(**section)


# NumPy random states stored on the config, both seeded from the experiment seed.
# NOTE(review): presumably consumed by the buffer / interaction loop -- the
# consumers are not visible in this notebook.
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)

# Two JAX PRNG keys derived from the same seed: `agent_key` is handed to the
# agent and `model_key` to the policy network when they are constructed.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)

# Evaluation gets its own random state (seeded from the evaluation seed) and
# is attached to the main config.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

# Give the nested sections attribute access (cfg.bc.lr, cfg.reinforce.gamma, ...).
cfg.bc = _as_namespace(cfg.bc)
cfg.reinforce = _as_namespace(cfg.reinforce)
cfg.pretrain = _as_namespace(cfg.pretrain)

In [None]:
# Display the fully assembled configuration as the cell output for inspection.
cfg

In [None]:
# Key under which the policy network and its optimizers are stored.
POLICY = "policy"

# Shape arguments common to both buffers.
_buffer_dims = dict(
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=cfg.act_dim,
    rew_dim=cfg.rew_dim,
)

# Buffer for behavioural cloning, loaded from the expert dataset on disk.
# NOTE(review): buffer_size=0 presumably defers sizing to the loaded file --
# confirm against NextStateNumPyBuffer.
bc_buffer = NextStateNumPyBuffer(
    buffer_size=0,
    **_buffer_dims,
    load_buffer=cfg.bc.expert_buffer_path,
)

# Buffer for REINFORCE, with capacity equal to the configured update frequency.
reinforce_buffer = NextStateNumPyBuffer(
    buffer_size=cfg.reinforce.update_frequency,
    **_buffer_dims,
)

# Gaussian MLP policy sized from the config and initialised with `cfg.model_key`.
# NOTE(review): min_std=1e-7 presumably lower-bounds the policy's standard
# deviation -- confirm in MLPGaussianPolicy.
_policy_net = MLPGaussianPolicy(
    obs_dim=cfg.obs_dim,
    act_dim=cfg.act_dim,
    hidden_dim=cfg.hidden_dim,
    num_hidden=cfg.num_hidden,
    key=cfg.model_key,
    min_std=1e-7,
)

# The learners and the agent all receive this same model dictionary.
model = {POLICY: _policy_net}

def _make_optimizer(lr, max_grad_norm):
    """Build the optax gradient transformation shared by both learners.

    Args:
        lr: Learning rate; applied as ``optax.scale(-lr)`` after RMS scaling.
        max_grad_norm: Global-norm clipping threshold. Any falsy value
            (``None``, ``0``) disables clipping.

    Returns:
        An ``optax.chain`` of ``[clip_by_global_norm,] scale_by_rms, scale``.
    """
    transforms = [optax.scale_by_rms(), optax.scale(-lr)]
    if max_grad_norm:
        # Clipping goes first so it acts on the raw gradients.
        transforms.insert(0, optax.clip_by_global_norm(max_grad_norm))
    return optax.chain(*transforms)


# One optimizer per learner, keyed like `model`; each uses its own learning
# rate and clipping threshold from the matching config section.
bc_opt = {POLICY: _make_optimizer(cfg.bc.lr, cfg.bc.max_grad_norm)}
reinforce_opt = {
    POLICY: _make_optimizer(cfg.reinforce.lr, cfg.reinforce.max_grad_norm)
}

# Both learners are built around the same `model` dict; each has its own
# optimizer, buffer and config section.
_bc_learner = BC(model=model, opt=bc_opt, buffer=bc_buffer, cfg=cfg.bc)
_reinforce_learner = REINFORCE(
    model=model,
    opt=reinforce_opt,
    buffer=reinforce_buffer,
    cfg=cfg.reinforce,
)
learners = {OFFLINE: _bc_learner, ONLINE: _reinforce_learner}

# The agent is given the REINFORCE buffer and both learners.
agent = OfflineOnlineRLAgent(
    model=model,
    model_key=POLICY,
    buffer=reinforce_buffer,
    learners=learners,
    key=cfg.agent_key,
)

In [None]:
# Restore agent state only when a checkpoint path is configured; the default
# `load_path` of None skips this.
_checkpoint_path = cfg.load_path
if _checkpoint_path:
    agent.load(_checkpoint_path)

In [None]:
# IPython line magic from wandb; presumably embeds the run dashboard in the
# notebook -- TODO confirm it renders anything while wandb mode is "disabled".
%wandb

In [None]:
# Hand the environment, agent and full config to the library's interaction loop.
# NOTE(review): presumably runs up to cfg.max_timesteps steps with periodic
# logging, checkpointing and evaluation -- confirm in jax_learning.rl_utils.
interact(env, agent, cfg)

In [None]:
# Close out the Weights & Biases run now that the experiment is done.
wandb.finish()