This notebook trains a PPO agent on Atari Pong using coax. Adapted from https://coax.readthedocs.io/en/latest/examples/getting_started/third_agent.html

Import some libraries:

In [1]:
import os

In [2]:
import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam

Set up some environment variables:

In [3]:
# Optional JAX/XLA settings, left disabled; uncomment the ones you need.
# NOTE(review): this cell runs after `import jax` in the cell above. These
# variables are presumably only honoured when set before JAX/XLA initialise --
# confirm, and consider moving this cell above the imports before enabling them.
# os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')     # tell JAX to use GPU
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet

Name the script:

In [4]:
# Run label: reused as the TrainMonitor name and inside the TensorBoard and GIF output paths.
name = "ppo"

Set up the environment:

In [5]:
# Build the environment as one nested expression; wrappers are applied from
# the innermost call outwards, in the order: base env -> Atari preprocessing
# -> 3-frame stacking -> training monitor (logs to TensorBoard).
env = coax.wrappers.TrainMonitor(
    coax.wrappers.FrameStacking(
        gym.wrappers.AtariPreprocessing(
            gym.make('PongNoFrameskip-v4'),
        ),
        num_frames=3,
    ),
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

Define the models:

In [6]:
def shared(S, is_training):
    """Feature extractor shared by the policy and value functions.

    The entries of ``S`` are stacked along a new trailing axis and divided by
    255, then passed through ``coax.utils.diff_transform``, two ReLU-activated
    convolutions and a flatten layer.

    ``is_training`` is accepted for signature compatibility but is not used
    here.
    """
    layers = [
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
    ]
    torso = hk.Sequential(layers)

    # stack frames on the last axis and rescale by 1/255
    frames = jnp.stack(S, axis=-1) / 255.
    return torso(frames)


def func_pi(S, is_training):
    """Policy forward pass: returns ``{'logits': ...}`` with one logit per action.

    The head is a 256-unit ReLU layer followed by a zero-initialised linear
    layer of width ``env.action_space.n``, applied to the output of
    ``shared(S, is_training)``.
    """
    head = hk.Sequential((
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    features = shared(S, is_training)
    action_logits = head(features)
    return {'logits': action_logits}


def func_v(S, is_training):
    """State-value forward pass.

    The head is a 256-unit ReLU layer followed by a zero-initialised
    single-output linear layer, applied to the output of
    ``shared(S, is_training)``; the result is flattened with ``jnp.ravel``.
    """
    head = hk.Sequential((
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros),
        jnp.ravel,
    ))
    features = shared(S, is_training)
    return head(features)

Set up the RL problem:

In [7]:
# Function approximators: the policy and the state-value function.
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)

# Copies of both approximators: the training loop acts with `pi_behavior`
# and soft-updates both copies after each learning round.
pi_behavior = pi.copy()
v_targ = v.copy()

# Entropy regularizer on the policy (avoid premature exploitation).
entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)

# Updaters, each with its own Adam optimizer.
simpletd = coax.td_learning.SimpleTD(
    v,
    v_targ,
    optimizer=adam(3e-4),
)
ppo_clip = coax.policy_objectives.PPOClip(
    pi,
    regularizer=entropy,
    optimizer=adam(3e-4),
)

# 5-step reward tracer and a replay buffer holding 256 transitions.
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)



Train:

In [None]:
%%time
while env.T < 3000000:
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
            for _ in range(num_batches):
                transition_batch = buffer.sample(32)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            buffer.clear()

            # sync target networks
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
        T = env.T - env.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=env, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")

[ppo|MainThread|TrainMonitor|INFO] ep: 1,	T: 807,	G: -21,	avg_G: -21,	t: 806,	dt: 182.260ms,	SimpleTD/loss: 0.0565,	EntropyRegularizer/entropy: 1.78,	PPOClip/loss: 0.00654
[ppo|MainThread|TrainMonitor|INFO] ep: 2,	T: 1,896,	G: -19,	avg_G: -20,	t: 1088,	dt: 104.867ms,	SimpleTD/loss: 0.037,	EntropyRegularizer/entropy: 1.79,	PPOClip/loss: -0.00446
[ppo|MainThread|TrainMonitor|INFO] ep: 3,	T: 2,802,	G: -21,	avg_G: -20.3,	t: 905,	dt: 80.122ms,	SimpleTD/loss: 0.0145,	EntropyRegularizer/entropy: 1.79,	PPOClip/loss: -0.00215
[ppo|MainThread|TrainMonitor|INFO] ep: 4,	T: 3,901,	G: -19,	avg_G: -20,	t: 1098,	dt: 118.034ms,	SimpleTD/loss: 0.0111,	EntropyRegularizer/entropy: 1.78,	PPOClip/loss: -0.0017
[ppo|MainThread|TrainMonitor|INFO] ep: 5,	T: 4,861,	G: -20,	avg_G: -20,	t: 959,	dt: 115.876ms,	SimpleTD/loss: 0.01,	EntropyRegularizer/entropy: 1.78,	PPOClip/loss: -0.00467
[ppo|MainThread|TrainMonitor|INFO] ep: 6,	T: 5,685,	G: -21,	avg_G: -20.2,	t: 823,	dt: 122.553ms,	SimpleTD/loss: 0.0054,	EntropyRe