In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp

In [2]:
wandb.init(project="test_jax_rl", group="cartpole-dqn_test")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchan[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


In [3]:
cfg_dict = {
    # Environment setup
    "env": "CartPole-v0",
    "seed": 0,
    "render": False,
    
    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,
    
    # Learning hyperparameters
    "max_timesteps": 100000,
    "buffer_size": 100000,
    "buffer_warmup": 10000,
    "num_gradient_steps": 1,
    "batch_size": 128,
    "lr": 3e-4,
    "gamma": 0.99,
    "update_frequency": 1,
    "target_update_frequency": 1000,
    "tau": 1.,
    
    # Epsilon greedy hyperparameters
    "init_eps": 0.5,
    "min_eps": 0.02,
    "eps_decay": 0.9995,
    "eps_warmup": 10000,
    
    # Model architecture
    "hidden_dim": 64,
    "num_hidden": 2,
    
    # Evaluation
    "eval_cfg": {
        "max_episodes": 10,
        "seed": 1,
        "render": True,
    }
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
wandb.config = cfg_dict

In [4]:
np.random.seed(cfg.seed)

In [5]:
env = gym.make(cfg.env)

  logger.warn(
  logger.warn(


In [6]:
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = (env.action_space.n,)
cfg.action_space = "discrete"

In [7]:
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [8]:
cfg

Namespace(env='CartPole-v0', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=100000, buffer_size=100000, buffer_warmup=10000, num_gradient_steps=1, batch_size=128, lr=0.0003, gamma=0.99, update_frequency=1, target_update_frequency=1000, tau=1.0, init_eps=0.5, min_eps=0.02, eps_decay=0.9995, eps_warmup=10000, hidden_dim=64, num_hidden=2, eval_cfg={'max_episodes': 10, 'seed': 1, 'render': True}, obs_dim=(4,), act_dim=(2,), action_space='discrete', h_state_dim=(1,), rew_dim=(1,))

In [9]:
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.env_rng = np.random.RandomState(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)



In [10]:
def interact(env, agent, buffer, cfg):
    max_timesteps = cfg.max_timesteps
    log_interval = cfg.log_interval
    render = env.render if cfg.render else lambda: None
    env_rng = cfg.env_rng
    
    obs = env.reset(seed=env_rng.randint(0, float("inf")))
    h_state = agent.reset()
    ep_return = 0.
    ep_i = 0
    ep_len = 0
    tic = timeit.default_timer()
    for timestep_i in range(max_timesteps):
        timestep_dict = {
            "timestep": timestep_i
        }
        act, next_h_state = agent.compute_action(obs, h_state)
        next_obs, rew, done, info = env.step(act)
        render()
        buffer.push(obs, h_state, act, rew, done, info, next_obs, next_h_state)
        agent.learner.update(next_obs, next_h_state, timestep_dict)
        obs = next_obs
        ep_return += rew
        ep_len += 1
        
        if done:
            obs = env.reset(seed=env_rng.randint(0, float("inf")))
            h_state = agent.reset()
            timestep_dict["episodic_return"] = ep_return
            timestep_dict["episode"] = ep_i
            timestep_dict["episode_length"] = ep_len
            ep_return = 0.
            ep_len = 0
            ep_i += 1
            
        if (timestep_i + 1) % log_interval == 0:
            toc = timeit.default_timer()
            timestep_dict["time_diff"] = toc - tic
            tic = timeit.default_timer()
            
        wandb.log(timestep_dict)
        
def evaluate(env, agent, cfg):
    max_episodes = cfg.max_episodes
    render = env.render if cfg.render else lambda: None

    for ep_i in range(max_episodes):
        tic = timeit.default_timer()
        obs = env.reset(seed=env_rng.randint(0, float("inf")))
        h_state = agent.reset()
        ep_return = 0.
        ep_len = 0
        done = False
        while not done:
            timestep_dict = {
                "timestep": ep_len
            }
            act, next_h_state = agent.deterministic_action(obs, h_state)
            next_obs, rew, done, info = env.step(act)
            render()
            obs = next_obs
            ep_return += rew
            ep_len += 1

            if done:
                toc = timeit.default_timer()
                timestep_dict["eval_episodic_return"] = ep_return
                timestep_dict["eval_episode"] = ep_i
                timestep_dict["eval_episode_length"] = ep_len
                timestep_dict["eval_per_episode_time"] = toc - tic

            wandb.log(timestep_dict)

In [11]:
class ParameterizedSoftmaxQ(eqx.Module):
    obs_dim: int
    act_dim: int

    def __init__(self, obs_dim, act_dim):
        self.obs_dim = int(np.product(obs_dim))
        self.act_dim = int(np.product(act_dim))

    def greedy_action(self, obs, h_state):
        q_val, h_state = self.q_values(obs, h_state)
        acts = jnp.argmax(q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state
    
    def random_action(self, obs, h_state, key):
        q_val, h_state = self.q_values(obs, h_state)
        acts = jrandom.categorical(key=key, logits=q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state
    

class LinearQ(ParameterizedSoftmaxQ):
    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, obs_dim, act_dim, key):
        super().__init__(obs_dim, act_dim)
        self.linear = eqx.nn.Linear(self.obs_dim, self.act_dim, use_bias=False, key=key)
        self.bias = jnp.zeros(self.act_dim)

    def q_values(self, obs, h_state):
        return self.linear(obs) + self.bias, h_state
    

class MLPQ(ParameterizedSoftmaxQ):
    weights: Sequence[eqx.nn.Linear]
    biases: Sequence[jnp.ndarray]

    @property
    def num_hidden(self):
        return len(self.weights) - 1

    def __init__(self, obs_dim, act_dim, hidden_dim, num_hidden, key):
        super().__init__(obs_dim, act_dim)
        self.weights = [eqx.nn.Linear(self.obs_dim, hidden_dim, use_bias=False, key=key)]
        self.biases = [jnp.zeros(hidden_dim)]
        for _ in range(num_hidden - 1):
            key, _ = jrandom.split(key, num=2)
            self.weights.append(eqx.nn.Linear(hidden_dim, hidden_dim, use_bias=False, key=key))
            self.biases.append(jnp.zeros(hidden_dim))
            
        key, _ = jrandom.split(key, num=2)
        self.weights.append(eqx.nn.Linear(hidden_dim, self.act_dim, use_bias=False, key=key))
        self.biases.append(jnp.zeros(self.act_dim))

    def q_values(self, obs, h_state):
        x = obs
        for layer_i in range(self.num_hidden):
            x = jax.nn.relu(self.weights[layer_i](x) + self.biases[layer_i])
        x = self.weights[-1](x) + self.biases[-1]
        return x, h_state

In [12]:
class EpsilonGreedyAgent:
    def __init__(self, model, learner, init_eps, min_eps, eps_decay, eps_warmup, key):
        self._model = model
        self._learner = learner
        self._eps = init_eps
        self._init_eps = init_eps
        self._min_eps = min_eps
        self._eps_decay = eps_decay
        self._eps_warmup = eps_warmup
        self._key = key

    @property
    def model(self):
        return self._model
        
    @property
    def learner(self):
        return self._learner
        
    def deterministic_action(self, obs, h_state):
        return self.greedy_action(obs, h_state)
        
    def greedy_action(self, obs, h_state):
        action, val, next_h_state = self.model.greedy_action(obs, h_state)
        return np.asarray(action), np.asarray(next_h_state)
        
    def compute_action(self, obs, h_state, overwrite_rng_key=True):
        new_key, curr_key = jrandom.split(self._key)
        if jrandom.bernoulli(key=curr_key, p=self._eps):
            val, next_h_state = self.model.q_values(obs, h_state)
            action = jrandom.randint(curr_key, shape=(1,), minval=0, maxval=val.shape[-1]).item()
        else:
            action, val, next_h_state = self.model.greedy_action(obs, h_state)

        if overwrite_rng_key:
            self._key = new_key
            if self._eps_warmup > 0:
                self._eps_warmup -= 1
            else:
                self._eps = max(self._eps * self._eps_decay, self._min_eps)

        return np.asarray(action), np.asarray(next_h_state)
    
    def reset(self):
        if hasattr(self.model, "reset"):
            return self.model.reset()
        return np.array([0.], dtype=np.float32)

In [13]:
@eqx.filter_grad(has_aux=True)
def compute_one_step_td_loss(model, obss, h_states, acts, targets):
    curr_q_values, _ = jax.vmap(model.q_values)(obss, h_states)
    loss = jnp.mean((jnp.sum(curr_q_values * acts, axis=-1, keepdims=True) - targets) ** 2)
    return loss, {
        "loss": loss,
        "max_target": jnp.max(targets),
        "min_target": jnp.min(targets),
        "mean_target": jnp.mean(targets),
        "max_q": jnp.max(curr_q_values),
        "min_q": jnp.min(curr_q_values),
        "mean_q": jnp.mean(curr_q_values),
    }

@eqx.filter_jit
def make_step(model, target_model, opt, opt_state, obss, h_states, acts, rews, dones, next_obss, next_h_states, gamma):
    _, next_q_values, _ = jax.vmap(target_model.greedy_action)(next_obss, next_h_states)
    targets = rews + (1 - dones) * (gamma * next_q_values)
    
    grads, update_info = compute_one_step_td_loss(model,
                                                  obss,
                                                  h_states,
                                                  acts,
                                                  targets)
    updates, opt_state = opt.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, grads, update_info


class QLearning:
    def __init__(self, model, target_model, opt, buffer, cfg):
        self._model = model
        self._target_model = target_model
        self._opt = opt
        self._opt_state = self._opt.init(model)
        self._buffer = buffer
        self._cfg = cfg
        
        self._step = cfg.load_step
        
        self._batch_size = cfg.batch_size
        self._buffer_warmup = cfg.buffer_warmup
        self._num_gradient_steps = cfg.num_gradient_steps
        self._gamma = cfg.gamma
        self._tau = cfg.tau
        self._update_frequency = cfg.update_frequency
        self._target_update_frequency = cfg.target_update_frequency

    @property
    def model(self):
        return self._model
    
    @property
    def target_model(self):
        return self._target_model
        
    @property
    def buffer(self):
        return self._buffer
    
    @property
    def opt(self):
        return self._opt
    
    @property
    def opt_state(self):
        return self._opt_state
        
    def update(self, next_obs, next_h_state, update_info):
        self._step += 1
        
        if self._step <= self._buffer_warmup or (self._step - 1 - self._buffer_warmup) % self._update_frequency != 0:
            return

        update_info["mean_loss"] = 0.
        update_info["mean_q"] = 0.
        update_info["mean_target"] = 0.
        update_info["max_q"] = -np.inf
        update_info["max_target"] = -np.inf
        update_info["min_q"] = np.inf
        update_info["min_target"] = np.inf
        for update_i in range(self._num_gradient_steps):
            obss, h_states, acts, rews, dones, next_obss, next_h_states, _, _, _ = self.buffer.sample_with_next_obs(batch_size=self._batch_size,
                                                                                                                    next_obs=next_obs,
                                                                                                                    next_h_state=next_h_state)

            acts = np.eye(self.model.act_dim)[acts.astype(np.int64)]
            
            (obss, h_states, acts, rews, dones, next_obss, next_h_states) = to_jnp(*batch_flatten(obss,
                                                                                                  h_states,
                                                                                                  acts,
                                                                                                  rews,
                                                                                                  dones,
                                                                                                  next_obss,
                                                                                                  next_h_states))
            model, opt_state, grads, curr_update_info = make_step(model=self.model,
                                                                  target_model=self.target_model,
                                                                  opt=self.opt,
                                                                  opt_state=self.opt_state,
                                                                  obss=obss,
                                                                  h_states=h_states,
                                                                  acts=acts,
                                                                  rews=rews,
                                                                  dones=dones,
                                                                  next_obss=next_obss,
                                                                  next_h_states=next_h_states,
                                                                  gamma=self._gamma)
            self._model = model
            self._opt_state = opt_state
            
            if self._step % self._target_update_frequency == 0:
                self._target_model = jax.tree_map(lambda p, tp: p * self._tau + tp * (1 - self._tau),
                                                  self.model,
                                                  self.target_model)
            
            update_info["mean_loss"] += curr_update_info["loss"].item() / self._num_gradient_steps
            update_info["mean_q"] += curr_update_info["mean_q"].item() / self._num_gradient_steps
            update_info["mean_target"] += curr_update_info["mean_target"].item() / self._num_gradient_steps
            update_info["max_q"] = max(update_info["max_q"], curr_update_info["max_q"].item())
            update_info["max_target"] = max(update_info["max_target"], curr_update_info["max_target"].item())
            update_info["min_q"] = min(update_info["min_q"], curr_update_info["min_q"].item())
            update_info["min_target"] = min(update_info["min_target"], curr_update_info["min_target"].item())

In [14]:
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == "discrete" else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

model = MLPQ(obs_dim=cfg.obs_dim,
                act_dim=cfg.act_dim,
                hidden_dim=cfg.hidden_dim,
                num_hidden=cfg.num_hidden,
                key=cfg.model_key)

target_model = MLPQ(obs_dim=cfg.obs_dim,
                    act_dim=cfg.act_dim,
                    hidden_dim=cfg.hidden_dim,
                    num_hidden=cfg.num_hidden,
                    key=cfg.model_key)

opt = optax.adam(cfg.lr)

learner = QLearning(model=model,
                    target_model=target_model,
                    opt=opt,
                    buffer=buffer,
                    cfg=cfg)

agent = EpsilonGreedyAgent(model=model,
                           learner=learner,
                           init_eps=cfg.init_eps,
                           min_eps=cfg.min_eps,
                           eps_decay=cfg.eps_decay,
                           eps_warmup=cfg.eps_warmup,
                           key=cfg.agent_key)

In [15]:
interact(env, agent, buffer, cfg)

OverflowError: cannot convert float infinity to integer

In [None]:
evaluate(eval_env, agent, eval_cfg)