In [1]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
from coax.value_losses import mse
from optax import adam

from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial
from emlp_haiku import EMLPBlock, Sequential, Linear
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep

# Run label; reused below for the tensorboard log directory.
name = 'dqn'

# The Acrobot-v1 MDP, wrapped so training metrics get recorded.
env = gym.make('Acrobot-v1')
env = coax.wrappers.TrainMonitor(
    env,
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

In [5]:
# Hidden-layer spec, number of hidden layers, and the group all reps use.
ch = 300
num_layers = 4
group = Trivial(2)

# Input / output representations: 6 and 3 copies of T(0) respectively.
# NOTE(review): presumably the observation size and the number of actions
# of the environment -- confirm against the env spaces.
rep_in = 6 * T(0)(group)
rep_out = 3 * T(0)(group)

# `ch` may be a single int, a single Rep, or a sequence mixing ints and Reps.
if isinstance(ch, int):
    middle_layers = [uniform_rep(ch, group)] * num_layers
elif isinstance(ch, Rep):
    middle_layers = [ch(group)] * num_layers
else:
    middle_layers = [
        width(group) if isinstance(width, Rep) else uniform_rep(width, group)
        for width in ch
    ]
# assert all((not rep.G is None) for rep in middle_layers[0].reps)

# Full chain of representations, from the input through every hidden layer.
reps = [rep_in, *middle_layers]


# norms = 100000*jnp.array([1., 1., 1., 1.])

def func(S, is_training):
    """Forward pass of the Q-network (handed to ``coax.Q`` further down).

    Chains one ``EMLPBlock`` per consecutive pair of entries in the
    module-level ``reps`` list, followed by a final ``Linear`` layer that
    maps ``reps[-1]`` to ``rep_out``.

    Args:
        S: input passed straight into the network.
        is_training: accepted but not used by this function.

    Returns:
        Whatever the assembled ``Sequential`` network returns for ``S``.
    """
    hidden_blocks = [
        EMLPBlock(rep_a, rep_b) for rep_a, rep_b in zip(reps[:-1], reps[1:])
    ]
    output_layer = Linear(reps[-1], rep_out)
    model = Sequential(*hidden_blocks, output_layer)
    return model(S)

In [6]:
# def func(S, is_training):
#     """ type-2 q-function: s -> q(s,.) """
#     seq = hk.Sequential((
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(env.action_space.n, w_init=jnp.zeros)
#     ))
#     return seq(S)

In [7]:
# Q-function built from `func`, plus a Boltzmann policy derived from it.
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.1)

# Copy of q used as the target network.
q_targ = q.copy()

# 1-step reward tracer (discount 0.9) feeding a replay buffer.
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# Q-learning updater: MSE loss, Adam optimizer with learning rate 0.001.
qlearning = coax.td_learning.QLearning(
    q,
    q_targ=q_targ,
    loss_function=mse,
    optimizer=adam(0.001),
)

In [8]:
# Train: run up to 100 episodes, learning from replayed transitions.
metrics = None  # most recent updater metrics; None until the first update
logged_s = []   # every next-state seen during training, kept for inspection
for ep in range(100):
    s = env.reset()
    epoch_r = 0.
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)
        logged_s.append(s_next)
        epoch_r += r  # accumulate the raw env reward, before any adjustment

        # The episode was cut off by the step limit: stand in for the
        # truncated tail by assuming the current per-step reward would
        # continue forever, r + gamma*r + gamma^2*r + ... = r / (1 - gamma).
        # Scaling `r` (rather than hard-coding +1 / (1 - gamma)) keeps the
        # sign of the reward, so a negative per-step reward is not turned
        # into a positive bonus for timing out.
        if t == env.spec.max_episode_steps - 1:
            assert done
            r = r / (1 - tracer.gamma)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn once the buffer holds at least 100 transitions
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        # sync target network
        q_targ.soft_update(q, tau=0.01)

        if done:
            break

        if metrics is not None:
            print("Grad Norm = ", metrics['QLearning/grads_norm'])
            print("Loss = ", metrics['QLearning/loss'])
        s = s_next

    print("Epoch Reward =", epoch_r)
    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break


# run env one more time to render
# coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)

Grad Norm =  51795650000000.0
Loss =  3538373800000.0
Grad Norm =  22711519000000.0
Loss =  191966450000.0
Grad Norm =  35212646000000.0
Loss =  1065805400000.0
Grad Norm =  127005590000000.0
Loss =  1491777500000.0
Grad Norm =  33180434000000.0
Loss =  997955200000.0
Grad Norm =  26538756000000.0
Loss =  1939803200000.0
Grad Norm =  24157073000000.0
Loss =  1699653500000.0
Grad Norm =  28131748000000.0
Loss =  1213682400000.0
Grad Norm =  12692683000000.0
Loss =  142929950000.0
Grad Norm =  174842280000000.0
Loss =  2484938200000.0
Grad Norm =  12384249000000.0
Loss =  127232420000.0
Grad Norm =  25426870000000.0
Loss =  954509200000.0
Grad Norm =  24005896000000.0
Loss =  1402259600000.0
Grad Norm =  21067874000000.0
Loss =  1238682700000.0
Grad Norm =  22777147000000.0
Loss =  956942500000.0
Grad Norm =  23838840000000.0
Loss =  519760540000.0
Grad Norm =  49564325000000.0
Loss =  683968400000.0
Grad Norm =  42386466000000.0
Loss =  427059100000.0
Grad Norm =  2720800000000.0
Loss =

KeyboardInterrupt: 

In [None]:
# Per-feature maximum over the states logged during training.
# NOTE(review): `arr` was not defined in any visible cell; it is assumed to
# be the stacked `logged_s` list from the training cell -- confirm.
arr = jnp.asarray(logged_s)
arr.max(0)

In [None]:
# Scratch cell: run one Q-learning update by hand on a fresh batch of 32
# transitions sampled from the replay buffer, rebinding `metrics`.
transition_batch = buffer.sample(batch_size=32)
metrics = qlearning.update(transition_batch)