In [None]:
# Install coax directly from the `main` branch of its GitHub repository.
# NOTE(review): this is not pinned to a release tag or commit, so re-running
# the notebook later may pull a different version.
%pip install git+https://github.com/coax-dev/coax.git@main

In [1]:
import os

# Environment configuration for JAX/XLA. This cell sits ahead of the cell that
# imports jax, so the values are in place before that import runs.

# Prefer the GPU backend, but keep whatever the caller already exported.
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')

# These two are always overwritten.
os.environ.update({
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.1',  # cap the memory fraction at 0.1 instead of grabbing all GPU memory
    'TF_CPP_MIN_LOG_LEVEL': '3',              # silence XLA's log output
})

In [None]:
# Load the TensorBoard notebook extension and point it at ./data/tensorboard,
# the parent of the directory that the TrainMonitor wrapper in this notebook
# is given as its tensorboard_dir.
%load_ext tensorboard
%tensorboard --logdir ./data/tensorboard

In [2]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
from optax import adam


# experiment identifier; reused for the monitor name, the tensorboard
# sub-directory and the output gif
name = 'iqn'

# number of cosine features used to embed each quantile fraction in func()
quantile_embedding_dim = 16

# the cart-pole MDP, wrapped in a TrainMonitor that is given a tensorboard
# directory under ./data/tensorboard
env = coax.wrappers.TrainMonitor(
    gym.make('CartPole-v0'),
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)


def func(S, quantiles, is_training):
    """
    Type-2 quantile q-function: s -> q(s, .), evaluated at the given quantile fractions.

    Parameters
    ----------
    S : array
        Batch of state observations. The first axis is kept as the leading axis
        throughout (a quantile axis is inserted right after it).
    quantiles : array
        Quantile fractions to evaluate the q-function at. Presumably shaped
        (batch, num_quantiles) -- TODO confirm against how coax.QuantileQ calls this.
    is_training : bool
        Accepted as part of the signature; this network does not use it.

    Returns
    -------
    array
        Output of the final linear layer with its last two axes swapped, so the
        action axis (size ``env.action_space.n``) precedes the quantile axis.
    """
    def quantile(features, fractions):
        """ scale the state features by a learned cosine embedding of the fractions """
        width = features.shape[-1]

        # angular frequencies pi * i for i = 1 .. quantile_embedding_dim
        frequencies = jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32) * onp.pi

        # a new trailing axis on the fractions broadcasts against the frequencies,
        # giving quantile_embedding_dim cosine features per fraction
        embedding = jnp.cos(frequencies * fractions[..., None])
        embedding = jax.nn.relu(hk.Linear(width)(embedding))

        # insert a quantile axis after the first axis of the features so that
        # they multiply every fraction's embedding
        mixed = features[:, None, ...] * embedding
        return jax.nn.relu(hk.Linear(width)(mixed))

    # state encoder: two 8-unit dense layers with relu activations
    encoder = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
    ))
    features = quantile(encoder(S), fractions=quantiles)

    # zero-initialised output layer: one value per action for each quantile fraction
    head = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    return jnp.moveaxis(head(features), -1, -2)


# quantile value function built from func(), plus the Boltzmann policy derived from it
q = coax.QuantileQ(
    func,
    env,
    num_quantiles=1,
)
pi = coax.BoltzmannPolicy(
    q,
    temperature=0.1,
)

# target network, created as a copy of q
q_targ = q.copy()

# single-step reward tracer (discount 0.9) and the replay buffer it feeds
tracer = coax.reward_tracing.NStep(
    n=1,
    gamma=0.9,
)
buffer = coax.experience_replay.SimpleReplayBuffer(
    capacity=100000,
)

# updater: quantile q-learning against the target network, optimised with adam
qlearning = coax.td_learning.QuantileQLearning(
    q,
    q_targ=q_targ,
    optimizer=adam(0.001),
)


# train: run at most 1000 episodes, stopping early once the monitor's average
# return (env.avg_G) exceeds the environment's reward threshold
for ep in range(1000):
    s = env.reset()
    # NOTE(review): the two disabled lines below are an epsilon-decay schedule.
    # `pi` in this notebook is a BoltzmannPolicy configured with a temperature,
    # so this looks like a leftover from an epsilon-greedy variant -- confirm
    # and remove.
    # pi.epsilon = max(0.01, pi.epsilon * 0.95)
    # env.record_metrics({'EpsilonGreedy/epsilon': pi.epsilon})

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)  # `info` is not used below

        # extend last reward as asymptotic best-case return: on the final
        # allowed step the reward is replaced by the discounted sum of an
        # endless stream of unit rewards
        if t == env.spec.max_episode_steps - 1:
            assert done
            r = 1 / (1 - tracer.gamma)  # 1 + gamma + gamma^2 + ... = 1 / (1 - gamma)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn: only once the buffer holds at least 100 transitions
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        # sync target network: soft update with tau=0.01, performed on every
        # environment step (also before learning has started)
        q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next

    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break


# run env one more time to render; the gif is written to ./data/<name>.gif
coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)


[iqn|absl|INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
[iqn|absl|INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[iqn|TrainMonitor|INFO] ep: 1,	T: 21,	G: 20,	avg_r: 1,	avg_G: 20,	t: 20,	dt: 86.062ms
[iqn|TrainMonitor|INFO] ep: 2,	T: 33,	G: 11,	avg_r: 1,	avg_G: 15.5,	t: 11,	dt: 33.212ms
[iqn|TrainMonitor|INFO] ep: 3,	T: 54,	G: 20,	avg_r: 1,	avg_G: 17,	t: 20,	dt: 33.507ms
[iqn|TrainMonitor|INFO] ep: 4,	T: 83,	G: 28,	avg_r: 1,	avg_G: 19.8,	t: 28,	dt: 33.215ms
[iqn|TrainMonitor|INFO] ep: 5,	T: 108,	G: 24,	avg_r: 1,	avg_G: 20.6,	t: 24,	dt: 195.327ms,	QuantileQLearning/loss: 2.3
[iqn|TrainMonitor|INFO] ep: 6,	T: 123,	G: 14,	avg_r: 1,	avg_G: 19.5,	t: 14,	dt: 290.164ms,	QuantileQLearning/loss: 2.22
[iqn|TrainMonitor|INFO] ep: 7,	T: 159,	G: 35,	avg_r: 1,	avg_G: 21.7,	t: 35,	dt: 290.025ms,	QuantileQLearning/loss: 2.21
[iqn|TrainMonitor|INFO] ep: 8,	T: 173,	G: 13,	avg_r: 1,	avg_G: 20.6,	t: 13