In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp

In [2]:
# All run settings in one place, grouped by purpose. Key order is preserved.
cfg_dict = dict(
    # --- Environment setup ---
    env="MountainCar-v0",
    seed=0,
    render=False,
    # --- Experiment progress ---
    load_step=0,            # starting value of the learner's step counter
    log_interval=5_000,     # timesteps between progress printouts
    # --- Learning hyperparameters ---
    max_timesteps=1_000_000,
    buffer_size=1_000_000,
    buffer_warmup=1_000,    # learner steps that pass before gradient updates begin
    num_gradient_steps=1,   # gradient updates per learner step
    batch_size=128,
    lr=3e-4,
    gamma=0.99,
    # --- Epsilon-greedy hyperparameters ---
    init_eps=0.99,
    min_eps=0.0,
    eps_decay=0.9999,       # multiplicative decay applied per action once warmup is over
    eps_warmup=1_000,       # actions taken at init_eps before decay starts
)
# Attribute-style access (cfg.lr, cfg.gamma, ...); later cells add more fields.
cfg = Namespace(**cfg_dict)

In [3]:
# Build the Gym environment named in the config. `interact` reads this
# module-level `env` directly rather than taking it as an argument.
env = gym.make(cfg.env)

  logger.warn(


In [4]:
# Record the environment's space sizes on the config so the buffer and the
# model can be built from `cfg` alone.
cfg.obs_dim = env.observation_space.shape
cfg.act_dim = (env.action_space.n,)  # 1-tuple holding the number of discrete actions
cfg.action_space = "discrete"

In [5]:
# Shapes of the hidden state and reward entries handed to the replay buffer.
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [6]:
# Display the assembled configuration.
cfg

Namespace(env='MountainCar-v0', seed=0, render=False, load_step=0, log_interval=5000, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=128, lr=0.0003, gamma=0.99, init_eps=0.99, min_eps=0.0, eps_decay=0.9999, eps_warmup=1000, obs_dim=(2,), act_dim=(3,), action_space='discrete', h_state_dim=(1,), rew_dim=(1,))

In [7]:
# Derive every random stream from cfg.seed: a NumPy RandomState that is passed
# to the replay buffer, and two independent JAX keys (agent and model).
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)



In [14]:
def interact(agent, buffer, cfg):
    """Run the agent/environment loop, storing transitions and training online.

    The environment is the module-level ``env``; it is stepped with the old
    Gym API (``reset()`` returns an observation, ``step()`` returns a 4-tuple).

    Args:
        agent: Object exposing ``reset()``, ``compute_action(obs, h_state)``
            (returning ``(action, next_h_state)``) and
            ``learner.update(next_obs, next_h_state)``.
        buffer: Replay buffer exposing
            ``push(obs, h_state, act, rew, done, info, next_obs, next_h_state)``.
        cfg: Namespace providing ``max_timesteps``, ``log_interval`` and
            ``render``.

    Returns:
        None. Progress is printed every ``cfg.log_interval`` timesteps.
    """
    max_timesteps = cfg.max_timesteps
    log_interval = cfg.log_interval
    render = env.render if cfg.render else lambda: None

    obs = env.reset()
    h_state = agent.reset()
    # The last entry is the running return of the episode still in progress.
    eps_returns = [0.]
    for timestep_i in range(max_timesteps):
        act, next_h_state = agent.compute_action(obs, h_state)
        next_obs, rew, done, info = env.step(act)
        render()
        buffer.push(obs, h_state, act, rew, done, info, next_obs, next_h_state)
        agent.learner.update(next_obs, next_h_state)

        # Advance BOTH parts of the agent's input. Without carrying the hidden
        # state forward, the agent would be fed the reset state on every step
        # and the stored (h_state, next_h_state) pairs would not chain.
        obs = next_obs
        h_state = next_h_state
        eps_returns[-1] += rew

        if done:
            obs = env.reset()
            h_state = agent.reset()
            eps_returns.append(0.)

        if (timestep_i + 1) % log_interval == 0:
            finished = eps_returns[:-1]
            # No finished episode yet: report nan without numpy's empty-mean warning.
            mean_return = np.mean(finished) if finished else float("nan")
            print("Number of episodes: {}, Mean return: {}".format(len(finished), mean_return))
            # Keep only the unfinished episode's running return for the next window.
            eps_returns = eps_returns[-1:]

In [15]:
@eqx.filter_value_and_grad
def compute_one_step_td_loss(model, obss, h_states, acts, targets):
    """Mean squared error between the selected Q-values and ``targets``.

    Because of the decorator, calling this returns ``(loss, grads)`` where the
    gradients are taken with respect to ``model``.

    Args:
        model: Object whose ``q_values(obs, h_state)`` returns
            ``(q_values, h_state)`` for a single sample.
        obss, h_states: Batched inputs; ``q_values`` is vmapped over axis 0.
        acts: Multiplicative mask over the action axis (the caller in this
            notebook passes one-hot actions).
        targets: Regression targets subtracted from the selected Q-values.
    """
    batched_q_fn = jax.vmap(model.q_values)
    q_all, _ = batched_q_fn(obss, h_states)
    # Masking and summing over the last axis picks out Q(s, a) per sample.
    q_taken = jnp.sum(q_all * acts, axis=-1, keepdims=True)
    td_error = q_taken - targets
    return jnp.mean(td_error ** 2)

@eqx.filter_jit
def make_step(model, opt, opt_state, obss, h_states, acts, rews, dones, next_obss, next_h_states, gamma):
    """Perform one Q-learning gradient step and return the updated objects.

    The bootstrap target is ``rews + (1 - dones) * gamma * Q(next, greedy)``,
    computed with the same ``model`` and excluded from differentiation.

    Args:
        model: Q-function exposing ``q_values`` and ``greedy_action``.
        opt: Optimizer exposing ``update(grads, opt_state)``.
        opt_state: Current optimizer state.
        obss, h_states, acts, rews, dones, next_obss, next_h_states: One batch
            of transitions.
        gamma: Discount factor.

    Returns:
        ``(loss, model, opt_state)`` where ``model`` and ``opt_state`` are the
        NEW values; the inputs are not modified, so the caller must keep them.
    """
    batched_greedy = jax.vmap(model.greedy_action)
    _, bootstrap_q, _ = batched_greedy(next_obss, next_h_states)
    td_targets = jax.lax.stop_gradient(rews + (1 - dones) * (gamma * bootstrap_q))

    loss, grads = compute_one_step_td_loss(model, obss, h_states, acts, td_targets)

    param_updates, new_opt_state = opt.update(grads, opt_state)
    new_model = eqx.apply_updates(model, param_updates)
    return loss, new_model, new_opt_state

class QLearning:
    """One-step Q-learning learner that trains a model from replay samples.

    The learner owns the current model and optimizer state. ``make_step`` is
    functional (it returns a new model and optimizer state instead of mutating
    its inputs), so ``update`` stores the returned values back on the learner.
    Read the latest parameters through the ``model`` property.

    Args:
        model: Q-function passed to ``make_step``; must expose ``act_dim``.
        opt: Optimizer exposing ``init(model)`` (and whatever ``make_step`` needs).
        buffer: Replay buffer exposing ``sample_with_next_obs``.
        cfg: Namespace providing ``load_step``, ``batch_size``,
            ``buffer_warmup``, ``num_gradient_steps`` and ``gamma``.
    """

    def __init__(self, model, opt, buffer, cfg):
        self._model = model
        self._opt = opt
        self._opt_state = self._opt.init(model)
        self._buffer = buffer
        self._cfg = cfg

        self._step = cfg.load_step

        self._batch_size = cfg.batch_size
        self._buffer_warmup = cfg.buffer_warmup
        self._num_gradient_steps = cfg.num_gradient_steps
        self._gamma = cfg.gamma

    @property
    def model(self):
        """The most recently updated model."""
        return self._model

    @property
    def buffer(self):
        """The replay buffer sampled by ``update``."""
        return self._buffer

    @property
    def opt(self):
        """The optimizer."""
        return self._opt

    @property
    def opt_state(self):
        """The most recently updated optimizer state."""
        return self._opt_state

    def update(self, next_obs, next_h_state):
        """Advance the step counter and, after warmup, run gradient steps.

        Args:
            next_obs: Latest observation, forwarded to the buffer's sampler.
            next_h_state: Latest hidden state, forwarded to the buffer's sampler.

        Returns:
            dict: Empty while ``step <= buffer_warmup``; otherwise
            ``{"loss": [float, ...]}`` with one entry per gradient step.
        """
        self._step += 1
        update_info = dict()

        if self._step <= self._buffer_warmup:
            return update_info

        update_info["loss"] = []
        for _ in range(self._num_gradient_steps):
            (obss, h_states, acts, rews, dones, next_obss, next_h_states, _, _, _) = self.buffer.sample_with_next_obs(
                batch_size=self._batch_size,
                next_obs=next_obs,
                next_h_state=next_h_state,
            )

            # One-hot encode the integer actions so the loss can mask Q-values.
            acts = np.eye(self.model.act_dim)[acts.astype(np.int64)]

            (obss, h_states, acts, rews, dones, next_obss, next_h_states) = to_jnp(
                *batch_flatten(obss, h_states, acts, rews, dones, next_obss, next_h_states)
            )
            loss, model, opt_state = make_step(model=self.model,
                                               opt=self.opt,
                                               opt_state=self.opt_state,
                                               obss=obss,
                                               h_states=h_states,
                                               acts=acts,
                                               rews=rews,
                                               dones=dones,
                                               next_obss=next_obss,
                                               next_h_states=next_h_states,
                                               gamma=self._gamma)
            # Keep the results: dropping them would leave the parameters and
            # optimizer state frozen at their initial values forever.
            self._model = model
            self._opt_state = opt_state
            update_info["loss"].append(loss.item())
        return update_info

In [16]:
class LinearQ(eqx.Module):
    """Linear state-action value function: ``Q(obs) = W @ obs + bias``.

    The hidden state is unused and passed through unchanged by every method.
    """

    obs_dim: int
    act_dim: int
    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, obs_dim, act_dim, key):
        """Build the model.

        Args:
            obs_dim: Observation shape; flattened to its total size.
            act_dim: Action shape; flattened to its total size, which is the
                number of Q-value outputs.
            key: JAX PRNG key used to initialise the linear layer.
        """
        # np.prod replaces np.product, which was deprecated and then removed
        # in NumPy 2.0.
        self.obs_dim = int(np.prod(obs_dim))
        self.act_dim = int(np.prod(act_dim))
        self.linear = eqx.nn.Linear(self.obs_dim, self.act_dim, use_bias=False, key=key)
        self.bias = jnp.zeros(self.act_dim)

    def q_values(self, obs, h_state):
        """Return ``(q_values, h_state)`` for a single observation."""
        return self.linear(obs) + self.bias, h_state

    def greedy_action(self, obs, h_state):
        """Return ``(argmax action, its Q-value with shape (1,), h_state)``."""
        q_val, h_state = self.q_values(obs, h_state)
        acts = jnp.argmax(q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state

    def random_action(self, obs, h_state, key):
        """Sample an action and return ``(action, its Q-value with shape (1,), h_state)``.

        The action is drawn from a categorical distribution that uses the
        Q-values as logits, so it is not uniform over actions.
        """
        q_val, h_state = self.q_values(obs, h_state)
        acts = jrandom.categorical(key=key, logits=q_val, axis=-1)
        return acts, q_val[acts].reshape((1,)), h_state

In [17]:
class EpsilonGreedyAgent:
    """Epsilon-greedy policy wrapper around a Q-function model.

    With probability ``eps`` the action comes from ``model.random_action``;
    otherwise it comes from ``model.greedy_action``. ``eps`` stays at
    ``init_eps`` for ``eps_warmup`` calls and is then multiplied by
    ``eps_decay`` on every call, never going below ``min_eps``.

    Args:
        model: Initial model; used as a fallback when ``learner`` has no
            ``model`` attribute.
        learner: Object that trains the model (exposed as ``agent.learner``).
        init_eps: Initial exploration probability.
        min_eps: Lower bound for the exploration probability.
        eps_decay: Multiplicative decay factor applied after warmup.
        eps_warmup: Number of calls made before decay starts.
        key: JAX PRNG key driving exploration.
    """

    def __init__(self, model, learner, init_eps, min_eps, eps_decay, eps_warmup, key):
        self._model = model
        self._learner = learner
        self._eps = init_eps
        self._init_eps = init_eps
        self._min_eps = min_eps
        self._eps_decay = eps_decay
        self._eps_warmup = eps_warmup
        self._key = key

    @property
    def model(self):
        """The model used to act.

        The learner's copy is preferred so that the policy follows training
        instead of acting on the reference captured at construction time.
        """
        return getattr(self._learner, "model", self._model)

    @property
    def learner(self):
        """The learner passed at construction."""
        return self._learner

    def greedy_action(self, obs, h_state):
        """Return ``(action, next_h_state)`` from the greedy policy as NumPy arrays."""
        action, val, next_h_state = self.model.greedy_action(obs, h_state)
        return np.asarray(action), np.asarray(next_h_state)

    def compute_action(self, obs, h_state, overwrite_rng_key=True):
        """Return ``(action, next_h_state)`` from the epsilon-greedy policy.

        Args:
            obs: Current observation.
            h_state: Current hidden state.
            overwrite_rng_key: When True (default), advance the stored PRNG key
                and the epsilon schedule. When False, both are left untouched,
                so an identical call repeats the same random draws.

        Returns:
            Tuple of NumPy arrays ``(action, next_h_state)``.
        """
        new_key, curr_key = jrandom.split(self._key)
        # A JAX key must be consumed only once: use separate keys for the
        # explore/exploit coin flip and for sampling the exploratory action.
        explore_key, act_key = jrandom.split(curr_key)
        if jrandom.bernoulli(key=explore_key, p=self._eps):
            action, val, next_h_state = self.model.random_action(obs, h_state, act_key)
        else:
            action, val, next_h_state = self.model.greedy_action(obs, h_state)

        if overwrite_rng_key:
            self._key = new_key
            if self._eps_warmup > 0:
                self._eps_warmup -= 1
            else:
                self._eps = max(self._eps * self._eps_decay, self._min_eps)

        return np.asarray(action), np.asarray(next_h_state)

    def reset(self):
        """Return the initial hidden state.

        Delegates to ``model.reset()`` when the model defines it; otherwise
        returns a float32 array holding a single zero.
        """
        if hasattr(self.model, "reset"):
            return self.model.reset()
        return np.array([0.], dtype=np.float32)

In [18]:
# Replay buffer. Discrete actions are stored as a single integer index, so the
# buffer's action slot has shape (1,) rather than cfg.act_dim.
buffer = NextStateNumPyBuffer(
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=cfg.act_dim if cfg.action_space != "discrete" else (1,),
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

# Linear Q-function initialised from the model key.
model = LinearQ(obs_dim=cfg.obs_dim, act_dim=cfg.act_dim, key=cfg.model_key)

# Adam with the configured learning rate.
opt = optax.adam(cfg.lr)

# The learner samples from `buffer` and trains `model` with `opt`.
learner = QLearning(model=model, opt=opt, buffer=buffer, cfg=cfg)

# The acting policy is constructed with the same `model` object as the learner.
agent = EpsilonGreedyAgent(
    model=model,
    learner=learner,
    init_eps=cfg.init_eps,
    min_eps=cfg.min_eps,
    eps_decay=cfg.eps_decay,
    eps_warmup=cfg.eps_warmup,
    key=cfg.agent_key,
)

In [19]:
# Start training: the loop runs for cfg.max_timesteps environment steps and
# prints progress every cfg.log_interval steps.
interact(agent, buffer, cfg)

-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0


KeyboardInterrupt: 

In [None]:
# NOTE(review): looks like a deliberate stop so that "Run All" halts before the
# scratch cells below — confirm. Be aware that `assert` statements are skipped
# when Python runs with -O.
assert 0

# Testing JAX

In [None]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    """Elementwise SELU: ``lambda_ * x`` where ``x > 0``, else ``lambda_ * (alpha * exp(x) - alpha)``.

    Args:
        x: Input array.
        alpha: Scale of the exponential branch.
        lambda_: Overall output scale.

    Returns:
        Array with the same shape as ``x``.
    """
    negative_branch = alpha * jnp.exp(x) - alpha
    return lambda_ * jnp.where(x > 0, x, negative_branch)

# Baseline timing of the un-jitted function on a one-million-element input.
# block_until_ready() waits for the result, so the measurement covers the
# computation and not just the dispatch of the call.
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()


In [None]:
# Compiled version of selu for comparison with the baseline timing.
selu_jit = jax.jit(selu)

# Warm up: run once before timing so one-off compilation cost is excluded.
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

In [None]:
# Same benchmark, but feeding the jitted function a NumPy array instead of a
# JAX array (presumably to expose the cost of converting/transferring the
# input on every call — confirm).
x_np = np.arange(1000000)
%timeit selu_jit(x_np).block_until_ready()

In [None]:
@partial(jax.jit, static_argnames=('n'))
def g(x, n):
    """Return ``x`` plus the number of iterations needed to count up to ``n``.

    ``n`` is declared static, so the ``while`` condition is evaluated on a
    concrete Python value while the function is being traced.
    """
    count = 0
    while count < n:
        count += 1
    return x + count

g(10, 20)  # `n` is static (static_argnames), so the loop runs on a concrete value: no error, evaluates to 30.

In [None]:
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

class MyModule(eqx.Module):
    """Small MLP (Linear 2->8, ReLU, Linear 8->2) plus an extra bias initialised to ones."""

    # Fields of the module.
    layers: list
    bias: jnp.ndarray

    def __init__(self, key):
        """Create both linear layers, giving each its own half of ``key``."""
        hidden_key, out_key = jrandom.split(key)
        self.layers = [
            eqx.nn.Linear(2, 8, key=hidden_key),
            eqx.nn.Linear(8, 2, key=out_key),
        ]
        self.bias = jnp.ones(2)

    def __call__(self, x):
        """Forward pass for a single input: ReLU after every layer except the last."""
        *hidden_layers, output_layer = self.layers
        for hidden in hidden_layers:
            x = jnn.relu(hidden(x))
        return output_layer(x) + self.bias

@jax.jit
@jax.grad
def loss(model, x, y):
    """Mean squared error of ``model`` over a batch.

    Because of the ``jax.grad`` decorator, calling ``loss(model, x, y)``
    returns the gradient with respect to ``model`` (the first argument), not
    the scalar loss value.

    Args:
        model: Callable applied to one sample; vmapped over axis 0 of ``x``.
        x: Batched inputs.
        y: Batched targets.
    """
    batched_model = jax.vmap(model)
    residual = y - batched_model(x)
    return jnp.mean(residual ** 2)

# One independent key each for the inputs, the targets and the model weights.
x_key, y_key, model_key = jrandom.split(jrandom.PRNGKey(0), 3)

# Random regression data: 100 samples with 2 features and 2 targets.
x = jrandom.normal(x_key, (100, 2))
y = jrandom.normal(y_key, (100, 2))

model = MyModule(model_key)

# `loss` is wrapped in jax.grad, so this is the gradient w.r.t. the model.
grads = loss(model, x, y)