In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import sys
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence

from jax_learning.agents.rl_agents import EpsilonGreedyAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.constants import DISCRETE
from jax_learning.learners.q_learning import QLearning
from jax_learning.models.q_functions import MLPSoftmaxQ
from jax_learning.rl_utils import interact, evaluate

In [2]:
# Start a Weights & Biases run for this notebook under the given project and group.
wandb.init(project="test_jax_rl", group="cartpole-dqn_test")
# Ask W&B to summarise the "episodic_return" metric by its maximum.
wandb.define_metric("episodic_return", summary="max")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchan[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


<wandb.sdk.wandb_metric.Metric at 0x1069f5600>

In [3]:
# All tunables for the run are collected in one plain dictionary; the
# Namespace objects built from it below give attribute-style access.
cfg_dict = {
    # --- Environment ---
    "env": "CartPole-v1",
    "seed": 0,
    "render": False,

    # --- Progress / logging ---
    "load_step": 0,
    "log_interval": 5_000,

    # --- Learning ---
    "max_timesteps": 1_000_000,
    "buffer_size": 1_000_000,
    "buffer_warmup": 1_000,
    "num_gradient_steps": 1,
    "batch_size": 64,
    "lr": 3e-4,
    "max_grad_norm": 10.0,
    "gamma": 0.99,
    "update_frequency": 4,
    "target_update_frequency": 1,
    "tau": 0.005,  # Polyak-averaging coefficient for the target network
    "omega": 1.0,  # residual-gradient weight; 1 corresponds to semi-gradient

    # --- Epsilon-greedy exploration ---
    "init_eps": 1.0,
    "min_eps": 0.02,
    "eps_decay": 0.9999,
    "eps_warmup": 1_000,

    # --- Network architecture ---
    "hidden_dim": 64,
    "num_hidden": 1,

    # --- Evaluation (kept as a nested dict, unpacked separately below) ---
    "eval_cfg": dict(max_episodes=100, seed=1, render=True),
}

# Training configuration; later cells attach further derived attributes to it.
cfg = Namespace(**cfg_dict)
# Evaluation settings get their own namespace built from the nested dict.
eval_cfg = Namespace(**cfg_dict["eval_cfg"])
# Record the hyperparameters on the active W&B run. Assigning a dict to
# `wandb.config` only rebinds the module attribute and logs nothing, so the
# run's config object is updated instead. `allow_val_change=True` keeps this
# cell re-runnable after a value in `cfg_dict` has been edited.
wandb.config.update(cfg_dict, allow_val_change=True)

In [4]:
# Seed NumPy's global RNG with the configured seed for repeatability.
np.random.seed(cfg.seed)

In [5]:
# Instantiate the Gym environment named in the config.
env = gym.make(cfg.env)

  logger.warn(


In [6]:
# Derive dimensions from the environment's spaces and store them on the config.
cfg.obs_dim = env.observation_space.shape  # observation shape tuple
cfg.act_dim = (env.action_space.n,)  # number of discrete actions, wrapped in a 1-tuple
cfg.action_space = DISCRETE  # action-space type flag; read when sizing the buffer's action slot

In [7]:
# Shapes later passed to the replay buffer for its hidden-state and reward slots.
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [8]:
# Display the assembled configuration for inspection.
cfg

Namespace(env='CartPole-v1', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=64, lr=0.0003, max_grad_norm=10.0, gamma=0.99, update_frequency=4, target_update_frequency=1, tau=0.005, omega=1.0, init_eps=1.0, min_eps=0.02, eps_decay=0.9999, eps_warmup=1000, hidden_dim=64, num_hidden=1, eval_cfg={'max_episodes': 100, 'seed': 1, 'render': True}, obs_dim=(4,), act_dim=(2,), action_space='discrete', h_state_dim=(1,), rew_dim=(1,))

In [9]:
# NumPy generators stored on the config; `buffer_rng` is handed to the replay buffer,
# `env_rng` is presumably consumed by the interaction loop — TODO confirm.
# NOTE(review): both are seeded with cfg.seed, so they produce identical streams — confirm this is intended.
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
# Split one JAX PRNG key into two: one for the agent, one for model initialisation.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
# Evaluation gets its own generator, seeded from the evaluation config.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)



In [10]:
# Key under which the Q-network and its optimiser are registered in the dicts below.
Q = "q"

# Replay buffer. With a discrete action space the action slot is (1,)
# (presumably a single action index — TODO confirm); otherwise it takes the
# configured action shape.
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=cfg.act_dim if cfg.action_space != DISCRETE else (1,),
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# Online Q-network.
model = {
    Q: MLPSoftmaxQ(
        obs_dim=cfg.obs_dim,
        act_dim=cfg.act_dim,
        hidden_dim=cfg.hidden_dim,
        num_hidden=cfg.num_hidden,
        key=cfg.model_key,
    ),
}

# Target Q-network: same architecture and the same PRNG key as the online
# network (presumably so both start from identical parameters — TODO confirm).
target_model = {
    Q: MLPSoftmaxQ(
        obs_dim=cfg.obs_dim,
        act_dim=cfg.act_dim,
        hidden_dim=cfg.hidden_dim,
        num_hidden=cfg.num_hidden,
        key=cfg.model_key,
    ),
}

# Optimiser pipeline, applied in order: clip gradients by global norm, turn
# them into Adam updates, then negate so the update descends the loss.
# NOTE(review): cfg.lr is not used in this chain — verify that QLearning
# applies the learning rate itself, otherwise the step size is effectively 1.
opt = {
    Q: optax.chain(
        optax.clip_by_global_norm(cfg.max_grad_norm),
        optax.scale_by_adam(),
        optax.scale(-1.0),
    ),
}

# Learner that ties the networks, optimiser, buffer and config together.
learner = QLearning(
    model=model,
    target_model=target_model,
    opt=opt,
    buffer=buffer,
    cfg=cfg,
)

# Epsilon-greedy agent acting with the online Q-network.
agent = EpsilonGreedyAgent(
    model=model[Q],
    buffer=buffer,
    learner=learner,
    init_eps=cfg.init_eps,
    min_eps=cfg.min_eps,
    eps_decay=cfg.eps_decay,
    eps_warmup=cfg.eps_warmup,
    key=cfg.agent_key,
)

In [11]:
# Show the current W&B run inline in the notebook output.
%wandb

In [None]:
# Run the training interaction between the agent and the environment using the training config
# (presumably long-running, bounded by cfg.max_timesteps — TODO confirm).
interact(env, agent, cfg)

In [None]:
# Evaluate the agent on the same environment object, using the evaluation settings in eval_cfg.
evaluate(env, agent, eval_cfg)

In [None]:
# NOTE(review): always raises AssertionError — presumably a deliberate stop so that
# "Run All" does not continue into the ad-hoc buffer-inspection cells below; confirm.
assert 0

In [None]:
# Scratch: display the replay buffer object.
buffer

In [None]:
# Scratch: display the buffer's stored observations.
buffer.observations

In [None]:
# Scratch: shift next_observations by one position along axis 0 (the last entry wraps to the front),
# so position i holds the next-observation stored at position i-1.
np.roll(buffer.next_observations, 1, axis=0)

In [None]:
# Scratch: difference between each observation and the previous entry's next-observation,
# with the `dones` array appended along axis 1.
# Presumably a consistency check (differences expected to be zero except across episode boundaries) — TODO confirm.
np.concatenate((buffer.observations - np.roll(buffer.next_observations, 1, axis=0), buffer.dones), axis=1)

In [None]:
# Scratch: call the buffer's sample_with_next_obs with 3 as the first argument (presumably the batch size),
# the next-observation stored at index 19, and the hidden state at index 0 — argument semantics TODO confirm.
buffer.sample_with_next_obs(3, buffer.next_observations[19], buffer.hidden_states[0])