In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp

In [2]:
# Start a Weights & Biases run; every `wandb.log` call made by `interact` is sent to this run.
wandb.init(project="test_jax_rl")

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mchanb[0m. Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


In [3]:
# All tunable settings live here; derived values (env shapes, RNGs) are attached to `cfg` in later cells.
cfg_dict = {
    # Environment setup
    "env": "MountainCar-v0",
    "seed": 0,
    "render": False,

    # Experiment progress
    "load_step": 0,  # initial value of the learner's step counter
    "log_interval": 5000,  # wall-clock time per this many env steps is logged

    # Learning hyperparameters
    "max_timesteps": 1000000,  # total environment steps run by `interact`
    "buffer_size": 1000000,
    "buffer_warmup": 1000,  # learner steps skipped before gradient updates start
    "num_gradient_steps": 1,  # gradient steps per environment step
    "batch_size": 256,
    "lr": 3e-2,  # Adam learning rate
    "gamma": 0.99,  # discount factor
    "tau": 0.005,  # Polyak averaging rate for the target network

    # Epsilon greedy hyperparameters
    "init_eps": 0.9999,
    "min_eps": 0.0,
    "eps_decay": 0.9999,  # multiplicative decay per action once warmup is over
    "eps_warmup": 1000,  # actions taken before epsilon starts decaying

    # Model architecture
    "hidden_dim": 256,
    "num_hidden": 2,
}
cfg = Namespace(**cfg_dict)
# Record the hyperparameters on the active run. `wandb.config = cfg_dict` would only rebind
# the module attribute to a plain dict and leave the run's config empty.
wandb.config.update(cfg_dict)

In [4]:
# Create the environment and seed it for reproducibility.
# NOTE(review): `env.seed` appears to trigger the deprecation warning captured below; newer gym
# versions seed via `env.reset(seed=...)` — confirm the installed version before switching.
env = gym.make(cfg.env)
env.seed(cfg.seed)

  logger.warn(
  deprecation(


[0]

In [5]:
# Read space sizes off the environment. The action space is assumed Discrete (it exposes `.n`),
# so the action "dimension" is the number of discrete actions.
vars(cfg).update(
    obs_dim=env.observation_space.shape,
    act_dim=(env.action_space.n,),
    action_space="discrete",
)

In [6]:
# Hidden-state and reward shapes for the replay buffer. The Q-networks here pass the hidden
# state through unchanged, so it is a single placeholder value.
cfg.h_state_dim, cfg.rew_dim = (1,), (1,)

In [7]:
# Display the full configuration (hyperparameters plus the derived environment shapes).
cfg

Namespace(env='MountainCar-v0', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=256, lr=0.03, gamma=0.99, tau=0.005, init_eps=0.9999, min_eps=0.0, eps_decay=0.9999, eps_warmup=1000, hidden_dim=256, num_hidden=2, obs_dim=(2,), act_dim=(3,), action_space='discrete', h_state_dim=(1,), rew_dim=(1,))

In [8]:
# NumPy RNG handed to the replay buffer; JAX keys for the agent (exploration) and model init.
cfg.buffer_rng = np.random.RandomState(seed=cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed))



In [9]:
def interact(agent, buffer, cfg):
    """Run `agent` in the notebook-global `env` for `cfg.max_timesteps` steps, training online.

    Every transition is pushed into `buffer`, then `agent.learner.update` is called once per
    environment step. Each step's metrics dict is logged with `wandb.log`. It holds the step
    index, whatever the learner writes into it, the episodic return/index when an episode ends,
    and the wall-clock time of the last `cfg.log_interval` steps.

    Args:
        agent: object with `reset()`, `compute_action(obs, h_state)` and a `learner` attribute.
        buffer: replay buffer exposing
            `push(obs, h_state, act, rew, done, info, next_obs, next_h_state)`.
        cfg: namespace providing `max_timesteps`, `log_interval` and `render`.
    """
    max_timesteps = cfg.max_timesteps
    log_interval = cfg.log_interval
    render = env.render if cfg.render else lambda: None

    obs = env.reset()
    h_state = agent.reset()
    ep_return = 0.
    ep_i = 0
    tic = timeit.default_timer()
    for timestep_i in range(max_timesteps):
        timestep_dict = {"timestep": timestep_i}
        act, next_h_state = agent.compute_action(obs, h_state)
        # TODO: Action repeat for simpler environment for now (currently a single env step).
        curr_rew = 0.
        for _ in range(1):
            next_obs, rew, done, info = env.step(act)
            curr_rew += rew
            if done:
                break
        rew = curr_rew
        render()
        buffer.push(obs, h_state, act, rew, done, info, next_obs, next_h_state)
        # The learner writes its training statistics into `timestep_dict` in place.
        agent.learner.update(next_obs, next_h_state, timestep_dict)
        obs = next_obs
        # Advance the recurrent state so the next action is conditioned on it (a no-op for
        # models that return `h_state` unchanged).
        h_state = next_h_state
        ep_return += rew

        if done:
            obs = env.reset()
            h_state = agent.reset()
            timestep_dict["episodic_return"] = ep_return
            timestep_dict["episode"] = ep_i
            ep_return = 0.
            ep_i += 1

        if (timestep_i + 1) % log_interval == 0:
            toc = timeit.default_timer()
            timestep_dict["time_diff"] = toc - tic
            tic = timeit.default_timer()

        wandb.log(timestep_dict)

In [10]:
class ParameterizedSoftmaxQ(eqx.Module):
    """Base class for Q-networks over a discrete action set.

    Subclasses implement `q_values(obs, h_state) -> (q_values, h_state)` for a single
    (unbatched) observation; batched callers wrap the methods in `jax.vmap`.
    """

    obs_dim: int  # flattened observation size
    act_dim: int  # number of discrete actions

    def __init__(self, obs_dim, act_dim):
        # `np.prod` instead of the deprecated `np.product` alias (removed in NumPy 2.0).
        self.obs_dim = int(np.prod(obs_dim))
        self.act_dim = int(np.prod(act_dim))

    def q_values(self, obs, h_state):
        """Return (Q-values over actions, next hidden state). Implemented by subclasses."""
        raise NotImplementedError

    def greedy_action(self, obs, h_state):
        """Return (argmax action, its Q-value as shape (1,), next hidden state)."""
        q_val, h_state = self.q_values(obs, h_state)
        acts = jnp.argmax(q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state

    def random_action(self, obs, h_state, key):
        """Sample an action from softmax(Q) and return (action, its Q-value as (1,), next hidden state)."""
        q_val, h_state = self.q_values(obs, h_state)
        acts = jrandom.categorical(key=key, logits=q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state


class LinearQ(ParameterizedSoftmaxQ):
    """Q(s, .) = W s + b with a zero-initialized bias."""

    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, obs_dim, act_dim, key):
        super().__init__(obs_dim, act_dim)
        self.linear = eqx.nn.Linear(self.obs_dim, self.act_dim, use_bias=False, key=key)
        self.bias = jnp.zeros(self.act_dim)

    def q_values(self, obs, h_state):
        return self.linear(obs) + self.bias, h_state


class MLPQ(ParameterizedSoftmaxQ):
    """ReLU MLP Q-network with `num_hidden` hidden layers of width `hidden_dim`.

    Biases are kept as separate zero-initialized arrays next to bias-free linear layers.
    """

    weights: Sequence[eqx.nn.Linear]
    biases: Sequence[jnp.ndarray]

    @property
    def num_hidden(self):
        return len(self.weights) - 1

    def __init__(self, obs_dim, act_dim, hidden_dim, num_hidden, key):
        super().__init__(obs_dim, act_dim)
        # Layer widths, e.g. [obs, hidden, hidden, act]; num_hidden=0 gives a single linear layer,
        # so the `num_hidden` property always equals the constructor argument.
        sizes = [self.obs_dim] + [hidden_dim] * num_hidden + [self.act_dim]
        # One independent key per layer. Using `key` directly for one layer while also
        # splitting it for the others would reuse randomness across layers.
        layer_keys = jrandom.split(key, len(sizes) - 1)
        self.weights = [
            eqx.nn.Linear(in_size, out_size, use_bias=False, key=layer_key)
            for in_size, out_size, layer_key in zip(sizes[:-1], sizes[1:], layer_keys)
        ]
        self.biases = [jnp.zeros(out_size) for out_size in sizes[1:]]

    def q_values(self, obs, h_state):
        x = obs
        for weight, bias in zip(self.weights[:-1], self.biases[:-1]):
            x = jax.nn.relu(weight(x) + bias)
        # Output layer is linear (no activation): one Q-value per action.
        return self.weights[-1](x) + self.biases[-1], h_state

In [11]:
class EpsilonGreedyAgent:
    """Epsilon-greedy policy over a Q-network with exponentially decaying epsilon.

    With probability `eps` a uniformly random action is taken; otherwise the greedy action under
    the current Q-network is taken. Each `compute_action` call that advances the RNG key first
    uses up `eps_warmup`. After that it multiplies epsilon by `eps_decay`, floored at `min_eps`.
    """

    def __init__(self, model, learner, init_eps, min_eps, eps_decay, eps_warmup, key):
        self._model = model
        self._learner = learner
        self._eps = init_eps
        self._init_eps = init_eps
        self._min_eps = min_eps
        self._eps_decay = eps_decay
        self._eps_warmup = eps_warmup
        self._key = key

    @property
    def model(self):
        """Q-network used for acting.

        The learner rebinds its `model` after every gradient step (equinox modules are immutable),
        so the model passed at construction goes stale. Prefer the learner's current model, and
        fall back to the construction-time model when the learner exposes none.
        """
        learner_model = getattr(self._learner, "model", None)
        return self._model if learner_model is None else learner_model

    @property
    def learner(self):
        return self._learner

    def greedy_action(self, obs, h_state):
        """Return (greedy action, next hidden state) as NumPy arrays; no exploration, no RNG use."""
        action, val, next_h_state = self.model.greedy_action(obs, h_state)
        return np.asarray(action), np.asarray(next_h_state)

    def compute_action(self, obs, h_state, overwrite_rng_key=True):
        """Pick an epsilon-greedy action.

        Args:
            obs: single (unbatched) observation.
            h_state: current hidden state.
            overwrite_rng_key: when True, advance the RNG key and the epsilon schedule; when False,
                repeated calls reproduce the same draw and leave epsilon unchanged.

        Returns:
            (action, next hidden state) as NumPy arrays.
        """
        # Separate keys for the explore/exploit coin flip and for the random action, so the
        # two draws are not derived from the same key.
        new_key, coin_key, action_key = jrandom.split(self._key, 3)
        if bool(jrandom.bernoulli(key=coin_key, p=self._eps)):
            val, next_h_state = self.model.q_values(obs, h_state)
            action = jrandom.randint(action_key, shape=(1,), minval=0, maxval=val.shape[-1]).item()
        else:
            action, val, next_h_state = self.model.greedy_action(obs, h_state)

        if overwrite_rng_key:
            self._key = new_key
            if self._eps_warmup > 0:
                self._eps_warmup -= 1
            else:
                self._eps = max(self._eps * self._eps_decay, self._min_eps)

        return np.asarray(action), np.asarray(next_h_state)

    def reset(self):
        """Initial hidden state: the model's own `reset()` if it has one, else a single zero."""
        if hasattr(self.model, "reset"):
            return self.model.reset()
        return np.array([0.], dtype=np.float32)

In [12]:
@eqx.filter_grad(has_aux=True)
def compute_one_step_td_loss(model, obss, h_states, acts, targets):
    """Mean-squared one-step TD error; the decorator makes calls return (grads w.r.t. model, stats).

    `acts` selects Q(s, a) by multiplying with the batched Q-values and summing the last axis.
    NOTE(review): this assumes `acts` is one-hot with the same trailing shape as the Q-values.
    """
    q_all, _ = jax.vmap(model.q_values)(obss, h_states)
    q_taken = jnp.sum(q_all * acts, axis=-1, keepdims=True)
    td_error = q_taken - targets
    loss = jnp.mean(td_error ** 2)
    stats = dict(
        loss=loss,
        max_target=jnp.max(targets),
        min_target=jnp.min(targets),
        mean_target=jnp.mean(targets),
        max_q=jnp.max(q_all),
        min_q=jnp.min(q_all),
        mean_q=jnp.mean(q_all),
    )
    return loss, stats

@eqx.filter_jit
def make_step(model, target_model, opt, opt_state, obss, h_states, acts, rews, dones, next_obss, next_h_states, gamma):
    """One jitted gradient step of one-step Q-learning.

    Bootstraps from the target network's greedy (max) Q-value at the next state, masked by
    `1 - dones`. It then regresses `model`'s Q(s, a) onto that target and applies the optimizer.

    Returns:
        (updated model, updated optimizer state, gradients, loss/Q/target statistics dict).
    """
    _, bootstrap_q, _ = jax.vmap(target_model.greedy_action)(next_obss, next_h_states)
    continuation = 1 - dones
    targets = rews + continuation * (gamma * bootstrap_q)

    grads, stats = compute_one_step_td_loss(model, obss, h_states, acts, targets)
    updates, new_opt_state = opt.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, grads, stats


class QLearning:
    """Off-policy Q-learning with a replay buffer and a Polyak-averaged target network.

    `update` is meant to be called once per environment step. After `cfg.buffer_warmup` calls it
    runs `cfg.num_gradient_steps` gradient steps, each on a minibatch of `cfg.batch_size`
    transitions. It writes averaged and extremal training statistics into the caller's dict.
    """

    def __init__(self, model, target_model, opt, buffer, cfg):
        self._model = model
        self._target_model = target_model
        self._opt = opt
        self._opt_state = self._opt.init(model)
        self._buffer = buffer
        self._cfg = cfg

        # Number of `update` calls so far; starts from cfg.load_step.
        self._step = cfg.load_step

        self._batch_size = cfg.batch_size
        self._buffer_warmup = cfg.buffer_warmup
        self._num_gradient_steps = cfg.num_gradient_steps
        self._gamma = cfg.gamma
        self._tau = cfg.tau

    @property
    def model(self):
        return self._model

    @property
    def target_model(self):
        return self._target_model

    @property
    def buffer(self):
        return self._buffer

    @property
    def opt(self):
        return self._opt

    @property
    def opt_state(self):
        return self._opt_state

    def update(self, next_obs, next_h_state, update_info):
        """Run this environment step's scheduled gradient steps.

        Args:
            next_obs: latest observation, forwarded to the buffer's `sample_with_next_obs`.
            next_h_state: latest hidden state, forwarded likewise.
            update_info: dict updated in place with mean loss/Q/target and min/max Q/target over
                this call's gradient steps. Left untouched while the buffer is warming up.
        """
        self._step += 1
        if self._step <= self._buffer_warmup:
            return

        update_info.update(
            mean_loss=0.,
            mean_q=0.,
            mean_target=0.,
            max_q=-np.inf,
            max_target=-np.inf,
            min_q=np.inf,
            min_target=np.inf,
        )
        for _ in range(self._num_gradient_steps):
            (obss, h_states, acts, rews, dones, next_obss, next_h_states,
             _, _, _) = self.buffer.sample_with_next_obs(batch_size=self._batch_size,
                                                         next_obs=next_obs,
                                                         next_h_state=next_h_state)

            # One-hot encode integer actions for the TD loss.
            # NOTE(review): the buffer in this notebook is built with act_dim=(1,), so this yields a
            # trailing (1, act_dim). Confirm `batch_flatten` removes that singleton axis; otherwise
            # it broadcasts against the (batch, act_dim) Q-values inside the loss.
            acts = np.eye(self.model.act_dim)[acts.astype(np.int64)]

            (obss, h_states, acts, rews, dones, next_obss, next_h_states) = to_jnp(
                *batch_flatten(obss, h_states, acts, rews, dones, next_obss, next_h_states))
            self._model, self._opt_state, _, step_info = make_step(model=self.model,
                                                                   target_model=self.target_model,
                                                                   opt=self.opt,
                                                                   opt_state=self.opt_state,
                                                                   obss=obss,
                                                                   h_states=h_states,
                                                                   acts=acts,
                                                                   rews=rews,
                                                                   dones=dones,
                                                                   next_obss=next_obss,
                                                                   next_h_states=next_h_states,
                                                                   gamma=self._gamma)
            self._soft_update_target()
            self._accumulate_stats(update_info, step_info)

    def _soft_update_target(self):
        """Polyak-average the online parameters into the target network (rate `tau`).

        Only array leaves are blended. Non-array leaves, such as the integer `obs_dim`/`act_dim`
        fields, are copied from the target unchanged so they never turn into floats.
        """
        tau = self._tau
        self._target_model = jax.tree_util.tree_map(
            lambda p, tp: p * tau + tp * (1 - tau) if eqx.is_array(p) else tp,
            self.model,
            self.target_model,
        )

    def _accumulate_stats(self, update_info, step_info):
        """Fold one gradient step's statistics into the running aggregates in `update_info`."""
        num_steps = self._num_gradient_steps
        update_info["mean_loss"] += step_info["loss"].item() / num_steps
        for name in ("mean_q", "mean_target"):
            update_info[name] += step_info[name].item() / num_steps
        for name in ("max_q", "max_target"):
            update_info[name] = max(update_info[name], step_info[name].item())
        for name in ("min_q", "min_target"):
            update_info[name] = min(update_info[name], step_info[name].item())

In [13]:
# Replay buffer; discrete actions are stored as a single integer index per transition.
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=(1,) if cfg.action_space == "discrete" else cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

model = MLPQ(obs_dim=cfg.obs_dim,
             act_dim=cfg.act_dim,
             hidden_dim=cfg.hidden_dim,
             num_hidden=cfg.num_hidden,
             key=cfg.model_key)

# The target network starts as an exact copy of the online network. Equinox modules are immutable
# and the learner replaces its models instead of mutating them, so sharing the initial instance is
# safe. It also removes a second construction that would have to be kept in sync by hand.
target_model = model

opt = optax.adam(cfg.lr)

learner = QLearning(model=model,
                    target_model=target_model,
                    opt=opt,
                    buffer=buffer,
                    cfg=cfg)

agent = EpsilonGreedyAgent(model=model,
                           learner=learner,
                           init_eps=cfg.init_eps,
                           min_eps=cfg.min_eps,
                           eps_decay=cfg.eps_decay,
                           eps_warmup=cfg.eps_warmup,
                           key=cfg.agent_key)

In [14]:
# Train: run the interaction/learning loop for cfg.max_timesteps environment steps (long-running).
interact(agent, buffer, cfg)

TypeError: dot_general requires contracting dimensions to have the same shape, got (256,) and (3,).

In [None]:
# Deliberate stop so "Run All" halts before the scratch JAX/Equinox experiments below.
# An explicit raise (same AssertionError type) is not stripped under `python -O`, unlike `assert`.
raise AssertionError("Stop here: the cells below are scratch JAX experiments, not part of training.")

# Testing JAX

In [None]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled ELU: lambda_ * x where x > 0, else lambda_ * alpha * (exp(x) - 1)."""
    negative_branch = alpha * jnp.exp(x) - alpha
    return lambda_ * jnp.where(x > 0, x, negative_branch)

# Baseline: time the un-jitted function (ops execute eagerly, one at a time).
# `block_until_ready()` waits for JAX's asynchronous dispatch so the timing includes the compute.
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()


In [None]:
# Compile `selu` with XLA. The first call traces and compiles, so run it once before timing
# to measure only the compiled execution.
selu_jit = jax.jit(selu)

# Warm up
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

In [None]:
# Same jitted function on a NumPy input: the host array has to be copied to the device on each
# call, so this timing presumably includes that transfer in addition to the compute.
x_np = np.arange(1000000)
%timeit selu_jit(x_np).block_until_ready()

In [None]:
@partial(jax.jit, static_argnames=('n',))
def g(x, n):
    """Return x + n, computed with a Python while-loop over the static argument `n`.

    Because `n` is static, it is a concrete Python int at trace time and the loop runs during
    tracing. Without `static_argnames`, `i < n` on a tracer would raise a concretization error.
    """
    i = 0
    while i < n:
        i += 1
    return x + i


g(10, 20)  # Works because `n` is static: returns 30. Each distinct `n` triggers a recompile.

In [None]:
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

class MyModule(eqx.Module):
    """Small MLP (2 -> 8 -> 2) with a ReLU hidden layer and an extra learned output bias."""

    layers: list  # Linear layers applied in order
    bias: jnp.ndarray  # added to the last layer's output

    def __init__(self, key):
        first_key, second_key = jrandom.split(key)
        self.layers = [
            eqx.nn.Linear(2, 8, key=first_key),
            eqx.nn.Linear(8, 2, key=second_key),
        ]
        self.bias = jnp.ones(2)

    def __call__(self, x):
        *hidden_layers, output_layer = self.layers
        h = x
        for hidden in hidden_layers:
            h = jnn.relu(hidden(h))
        return output_layer(h) + self.bias

@jax.jit
@jax.grad
def loss(model, x, y):
    """MSE of `model` on (x, y); because of the decorators, calls return d(loss)/d(model)."""
    predictions = jax.vmap(model)(x)
    return jnp.mean((y - predictions) ** 2)


x_key, y_key, model_key = jrandom.split(jrandom.PRNGKey(0), 3)
x = jrandom.normal(x_key, (100, 2))
y = jrandom.normal(y_key, (100, 2))
model = MyModule(model_key)
grads = loss(model, x, y)

In [None]:
# `grads` mirrors the structure of `model`: one gradient array per parameter.
print(grads)

In [None]:
# Gradient of the loss with respect to the module's extra output bias.
print(grads.bias)