From https://coax.readthedocs.io/en/latest/examples/getting_started/second_agent.html

Import some libraries:

In [1]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
from coax.value_losses import mse
from optax import adam

Setup the environment:

In [2]:
# identifier reused for the monitor name and for the output paths below
name = 'dqn'

# the cart-pole MDP, wrapped in a TrainMonitor that is pointed at a
# per-run tensorboard directory
env = coax.wrappers.TrainMonitor(
    gym.make('CartPole-v0'),
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

Define the Q-function (the function approximator that Q-learning will train):

In [3]:
def func(S, is_training):
    """Type-2 q-function: maps a state batch ``S`` to q(s, .).

    The network is three ``hk.Linear(8)`` layers, each followed by a ReLU,
    and a final linear layer with one output per discrete action of ``env``
    whose weights are initialised to zero.

    ``is_training`` is accepted but not used by this body.
    """
    hidden = []
    for _ in range(3):
        hidden += [hk.Linear(8), jax.nn.relu]

    # one output per action; zero-initialised weights
    head = hk.Linear(env.action_space.n, w_init=jnp.zeros)

    network = hk.Sequential((*hidden, head))
    return network(S)

Setup the RL problem:

In [4]:
# q-function built from `func`, and a Boltzmann policy derived from it
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.1)

# copy of q that is handed to the updater as its target network
q_targ = q.copy()

# n=1 reward tracer (gamma=0.9) and the replay buffer its transitions go into
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# Q-learning updater: mse loss, adam optimizer
qlearning = coax.td_learning.QLearning(
    q,
    q_targ=q_targ,
    loss_function=mse,
    optimizer=adam(0.001),
)



Train the model:

In [5]:
%%time
# train for at most 1000 episodes; each episode is capped at the
# environment's max_episode_steps
for ep in range(1000):
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # extend last reward as asymptotic best-case return: reaching the
        # final allowed step is credited as if a per-step reward of 1
        # continued forever
        if t == env.spec.max_episode_steps - 1:
            assert done
            r = 1 / (1 - tracer.gamma)  # 1 + gamma + gamma^2 + ... = 1 / (1 - gamma)

        # trace rewards and add every transition the tracer has ready to the
        # replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn: only once the buffer holds at least 100 transitions
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        # sync target network: soft update from q on every step (tau=0.01)
        q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next

    # early stopping: quit as soon as env.avg_G exceeds the environment's
    # reward threshold
    if env.avg_G > env.spec.reward_threshold:
        break

[dqn|MainThread|TrainMonitor|INFO] ep: 1,	T: 18,	G: 17,	avg_G: 17,	t: 17,	dt: 141.744ms
[dqn|MainThread|TrainMonitor|INFO] ep: 2,	T: 35,	G: 16,	avg_G: 16.5,	t: 16,	dt: 41.245ms
[dqn|MainThread|TrainMonitor|INFO] ep: 3,	T: 51,	G: 15,	avg_G: 16,	t: 15,	dt: 25.956ms
[dqn|MainThread|TrainMonitor|INFO] ep: 4,	T: 71,	G: 19,	avg_G: 16.8,	t: 19,	dt: 21.720ms
[dqn|MainThread|TrainMonitor|INFO] ep: 5,	T: 116,	G: 44,	avg_G: 22.2,	t: 44,	dt: 205.420ms,	QLearning/loss: 0.494
[dqn|MainThread|TrainMonitor|INFO] ep: 6,	T: 130,	G: 13,	avg_G: 20.7,	t: 13,	dt: 45.933ms,	QLearning/loss: 0.479
[dqn|MainThread|TrainMonitor|INFO] ep: 7,	T: 158,	G: 27,	avg_G: 21.6,	t: 27,	dt: 47.998ms,	QLearning/loss: 0.442
[dqn|MainThread|TrainMonitor|INFO] ep: 8,	T: 194,	G: 35,	avg_G: 23.2,	t: 35,	dt: 58.058ms,	QLearning/loss: 0.333
[dqn|MainThread|TrainMonitor|INFO] ep: 9,	T: 225,	G: 30,	avg_G: 24,	t: 30,	dt: 89.077ms,	QLearning/loss: 0.152
[dqn|MainThread|TrainMonitor|INFO] ep: 10,	T: 235,	G: 9,	avg_G: 22.5,	t: 9,	dt: 89.

CPU times: user 2h 15min 42s, sys: 1h 21min 13s, total: 3h 36min 55s
Wall time: 2h 15min 48s


Visualize the trained policy:

In [6]:
# record an episode driven by the trained policy `pi` to ./data/<name>.gif
coax.utils.generate_gif(
    env,
    policy=pi,
    filepath=f"./data/{name}.gif",
    duration=25,
)

[dqn|MainThread|generate_gif|INFO] recorded episode to: ./data/dqn.gif


Visualize a random policy:

In [7]:
# No `policy` argument is passed here; per the heading this records the
# random-policy baseline (NOTE(review): relies on generate_gif's default
# behaviour when no policy is given; confirm against the coax docs).
# The path has no placeholders, so it is a plain string rather than an f-string.
coax.utils.generate_gif(env, filepath="./data/random.gif", duration=25)

[dqn|MainThread|generate_gif|INFO] recorded episode to: ./data/random.gif
