In [None]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer

In [None]:
# Run settings, kept in a plain dict first and then exposed through attribute
# access (cfg.batch_size, cfg.lr, ...) via argparse.Namespace.
# The four eps* entries are forwarded to EpsilonGreedyAgent further down.
cfg_dict = dict(
    batch_size=128,
    lr=3e-4,
    max_timesteps=1_000_000,
    memory_size=1_000_000,
    env="MountainCar-v0",
    seed=0,
    init_eps=0.99,
    min_eps=0.0,
    eps_decay=0.9999,
    eps_warmup=1000,
    render=True,
)
cfg = Namespace(**cfg_dict)

In [None]:
# Build the Gym environment named in the config. `interact` reads this
# module-level `env` directly rather than taking it as an argument.
env = gym.make(cfg.env)

In [None]:
# Number of discrete actions, wrapped in a 1-tuple. The same value is later
# passed as `act_dim` to both LinearQ (output width) and the replay buffer.
# NOTE(review): the buffer may expect the shape of a stored action rather than
# the action count — confirm against NextStateNumPyBuffer.
cfg.act_dim = (env.action_space.n, )

In [None]:
# Observation shape as reported by the environment's observation space.
cfg.obs_dim = env.observation_space.shape

In [None]:
# Shapes handed to the replay buffer for the hidden state and the reward
# (one element each).
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [None]:
# Display the assembled configuration.
cfg

In [None]:
# All randomness is derived from cfg.seed: a NumPy RandomState for the replay
# buffer, and two JAX PRNG keys obtained by splitting one root key — one for the
# agent and one for model initialisation.
cfg.buffer_rng = np.random.RandomState(cfg.seed)
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)

In [None]:
def interact(agent, buffer, cfg):
    """Roll out `agent` in the module-level `env` and store every transition.

    Uses the 4-tuple Gym step API (`obs, rew, done, info`) and a `reset()`
    that returns only the observation.

    Args:
        agent: provides `reset()` (returns the initial hidden state) and
            `compute_action(obs)`.
        buffer: provides
            `push(obs, h_state, act, rew, done, info, next_obs, next_h_state)`.
        cfg: config object with `max_timesteps` (total number of environment
            steps across all episodes) and `render` (whether to call
            `env.render()` after each step).

    When a step reports `done`, the environment and the agent's hidden state
    are reset so the next transition starts a fresh episode instead of stepping
    an already-finished one.
    """
    max_timesteps = cfg.max_timesteps
    render = env.render if cfg.render else lambda: None

    obs = env.reset()
    h_state = agent.reset()
    for _ in range(max_timesteps):
        act = agent.compute_action(obs)
        next_obs, rew, done, info = env.step(act)
        render()
        # The agent never updates its hidden state, so the same value is stored
        # as both the current and the next hidden state.
        buffer.push(obs, h_state, act, rew, done, info, next_obs, h_state)
        if done:
            # Episode finished: start a new one rather than stepping a
            # terminated environment.
            obs = env.reset()
            h_state = agent.reset()
        else:
            obs = next_obs

In [None]:
class LinearQ(eqx.Module):
    """Linear Q-function: one Q-value per action, `linear(obs) + bias`.

    Args:
        obs_dim: observation shape (iterable of ints); flattened to its total size.
        act_dim: shape whose total size is the number of Q-value outputs.
        key: JAX PRNG key used to initialise the linear layer.
    """

    # Flattened input size and number of outputs.
    _obs_dim: int
    _act_dim: int
    # Weight matrix without its own bias; the bias is kept as a separate field.
    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, obs_dim, act_dim, key):
        # np.prod replaces np.product, which was deprecated and later removed
        # from NumPy.
        self._obs_dim = int(np.prod(obs_dim))
        self._act_dim = int(np.prod(act_dim))
        self.linear = eqx.nn.Linear(self._obs_dim, self._act_dim, use_bias=False, key=key)
        self.bias = jnp.zeros(self._act_dim)

    def q_values(self, input):
        """Return the Q-values for `input` as `linear(input) + bias`."""
        return self.linear(input) + self.bias

    def greedy_action(self, input):
        """Return the index of the largest Q-value along the last axis."""
        q_val = self.q_values(input)
        return jnp.argmax(q_val, axis=-1)

    def random_action(self, input, key):
        """Sample an action index from a categorical distribution whose logits
        are the Q-values (a softmax over Q-values, not a uniform draw)."""
        q_val = self.q_values(input)
        return jrandom.categorical(key=key, logits=q_val, axis=-1)

In [None]:
class EpsilonGreedyAgent:
    """Epsilon-greedy wrapper around a Q-model.

    With probability `eps` the action comes from `model.random_action`,
    otherwise from `model.greedy_action`. `eps` stays at `init_eps` for the
    first `eps_warmup` key-consuming calls and is then multiplied by
    `eps_decay` on every such call, never dropping below `min_eps`.

    Args:
        model: object exposing `greedy_action(obs)` and `random_action(obs, key)`.
        init_eps: initial exploration probability.
        min_eps: lower bound for the exploration probability.
        eps_decay: multiplicative decay applied to eps after warm-up.
        eps_warmup: number of calls during which eps is not decayed.
        key: JAX PRNG key that drives all of the agent's randomness.
    """

    def __init__(self, model, init_eps, min_eps, eps_decay, eps_warmup, key):
        self._model = model
        self._eps = init_eps
        self._init_eps = init_eps
        self._min_eps = min_eps
        self._eps_decay = eps_decay
        self._eps_warmup = eps_warmup
        self._key = key

    @property
    def model(self):
        """The wrapped Q-model."""
        return self._model

    def greedy_action(self, obs):
        """Return the model's greedy action for `obs` as a NumPy array."""
        return np.asarray(self.model.greedy_action(obs))

    def compute_action(self, obs, overwrite_rng_key=True):
        """Pick an epsilon-greedy action for `obs` and return it as a NumPy array.

        Args:
            obs: observation passed through to the model.
            overwrite_rng_key: when True, the agent's key is advanced and the
                warm-up counter / eps schedule is stepped. When False, the
                agent's state is left untouched, so repeating the call with the
                same `obs` reuses the same random draws.
        """
        # Three independent keys: one carried to the next call, one for the
        # explore/exploit coin flip, one for sampling the exploratory action.
        # Using a single key for both draws would correlate them.
        next_key, explore_key, sample_key = jrandom.split(self._key, num=3)
        if jrandom.bernoulli(key=explore_key, p=self._eps):
            action = self.model.random_action(obs, sample_key)
        else:
            action = self.model.greedy_action(obs)

        if overwrite_rng_key:
            self._key = next_key
            if self._eps_warmup > 0:
                self._eps_warmup -= 1
            else:
                self._eps = max(self._eps * self._eps_decay, self._min_eps)

        return np.asarray(action)

    def reset(self):
        """Return the initial hidden state: a single float32 NaN placeholder."""
        return np.array([np.nan], dtype=np.float32)

In [None]:
# Linear Q-function sized from the environment's observation/action spaces.
model = LinearQ(obs_dim=cfg.obs_dim, act_dim=cfg.act_dim, key=cfg.model_key)

# Epsilon-greedy policy over that Q-function, with the schedule taken from cfg.
agent = EpsilonGreedyAgent(
    model=model,
    init_eps=cfg.init_eps,
    min_eps=cfg.min_eps,
    eps_decay=cfg.eps_decay,
    eps_warmup=cfg.eps_warmup,
    key=cfg.agent_key,
)

# Replay buffer that `interact` pushes transitions into.
buffer = NextStateNumPyBuffer(
    memory_size=cfg.memory_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.buffer_rng,
)

In [None]:
# Run the data-collection loop for cfg.max_timesteps environment steps.
interact(agent, buffer, cfg)

In [None]:
# Deliberate stop: `assert 0` always raises AssertionError (unless Python runs
# with -O), presumably to keep "Run All" from continuing into the scratch cells
# below.
assert 0

# Testing JAX

In [None]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Positive entries are passed through; the rest are mapped to
  `alpha * exp(x) - alpha`. The result is then scaled by `lambda_`.
  """
  scaled_exp = alpha * jnp.exp(x)
  non_positive_branch = scaled_exp - alpha
  activated = jnp.where(x > 0, x, non_positive_branch)
  return lambda_ * activated

# Benchmark input: the integers 0..999999 as a JAX array, then time the
# un-jitted selu. block_until_ready() makes the timing wait for the result.
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()


In [None]:
# JIT-compiled version of selu.
selu_jit = jax.jit(selu)

# Warm up: call once outside the timing loop (presumably so that compilation
# is not included in the measurement).
selu_jit(x).block_until_ready()

# Time the jitted function on the same input.
%timeit selu_jit(x).block_until_ready()

In [None]:
# Same benchmark, but passing a NumPy array instead of a JAX array.
x_np = np.arange(1000000)
%timeit selu_jit(x_np).block_until_ready()

In [None]:
@partial(jax.jit, static_argnames=("n",))
def g(x, n):
  """Return `x + i`, where `i` counts up from 0 by 1 until it is no longer below `n`.

  `n` is declared static, so the Python `while` loop runs at trace time against
  a concrete value of `n`.
  """
  i = 0
  while i < n:
    i += 1
  return x + i

# Works because `n` is static (result is 30). Without static_argnames, the
# `while i < n` test on a traced `n` would raise during tracing.
g(10, 20)

In [None]:
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

class MyModule(eqx.Module):
    """Two-layer perceptron (2 -> 8 -> 2) with an extra output bias."""

    # Fields of the module.
    layers: list
    bias: jnp.ndarray

    def __init__(self, key):
        # One PRNG key per linear layer.
        first_key, second_key = jrandom.split(key)
        hidden_layer = eqx.nn.Linear(2, 8, key=first_key)
        output_layer = eqx.nn.Linear(8, 2, key=second_key)
        self.layers = [hidden_layer, output_layer]
        self.bias = jnp.ones(2)

    def __call__(self, x):
        """Forward pass: ReLU after every layer except the last, then add the bias."""
        body, head = self.layers[:-1], self.layers[-1]
        for layer in body:
            x = jnn.relu(layer(x))
        return head(x) + self.bias

@jax.jit
@jax.grad
def loss(model, x, y):
    """Mean squared error of `model` applied to each row of `x` against `y`.

    Because of the `jax.grad` decorator, calling `loss(model, x, y)` returns
    the gradient of that error with respect to `model`, not the error value.
    """
    residual = y - jax.vmap(model)(x)
    return jnp.mean(residual ** 2)

# Separate PRNG keys for the inputs, the targets and the model parameters.
x_key, y_key, model_key = jrandom.split(jrandom.PRNGKey(0), 3)

# 100 random 2-d inputs and 100 random 2-d targets.
x = jrandom.normal(x_key, (100, 2))
y = jrandom.normal(y_key, (100, 2))

model = MyModule(model_key)
# `loss` is wrapped in jax.grad, so this yields gradients with respect to `model`.
grads = loss(model, x, y)