In [None]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
from coax.value_losses import mse
from optax import adam
import sys
sys.path.append("..")
# from soft_emlp import MixedEMLP
from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial
from emlp_haiku import EMLPBlock, Sequential, Linear
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep
# the name of this script

In [None]:
# --- Network configuration -------------------------------------------------
# `ch` may be a single int, a single Rep, or a sequence mixing ints and Reps.
ch = 384
num_layers = 3

# Symmetry group plus the input/output representations built on it.
group = S(2)
rep_in = T(0)(group)
rep_out = T(1)(group)

# Normalise `ch` into one representation per hidden layer.
if isinstance(ch, int):
    # One width shared by every hidden layer.
    middle_layers = [uniform_rep(ch, group)] * num_layers
elif isinstance(ch, Rep):
    # One Rep, bound to the group once and shared by every hidden layer.
    middle_layers = [ch(group)] * num_layers
else:
    # Per-layer specification: bind Reps to the group, expand ints via uniform_rep.
    middle_layers = [
        c(group) if isinstance(c, Rep) else uniform_rep(c, group)
        for c in ch
    ]
# assert all((not rep.G is None) for rep in middle_layers[0].reps)

# Layer-by-layer representations: the input followed by the hidden layers.
reps = [rep_in, *middle_layers]

def func(S, is_training):
    """Forward pass of the equivariant network handed to ``coax.Q``.

    One ``EMLPBlock`` is created for each consecutive pair of entries in the
    notebook-level ``reps`` list, followed by a ``Linear`` layer mapping
    ``reps[-1]`` to ``rep_out``.

    Args:
        S: input passed straight to the network. Inside this function the
            name shadows the ``S`` group imported from ``emlp.groups``.
        is_training: accepted but not used by this function.

    Returns:
        The result of applying the assembled ``Sequential`` network to ``S``.
    """
    layers = []
    # Hidden blocks first, in order, then the output projection.
    for rin, rout in zip(reps[:-1], reps[1:]):
        layers.append(EMLPBlock(rin, rout))
    layers.append(Linear(reps[-1], rep_out))
    return Sequential(*layers)(S)

In [None]:
# `network` only exists as a local variable inside `func`, so referencing it
# here raises NameError on a fresh kernel. Wrap the forward function itself;
# hk.to_module turns a function into a Haiku module class.
module = hk.to_module(func)

In [None]:
# def func(S, is_training):
#     ch=384
#     num_layers=3
#     """ type-2 q-function: s -> q(s,.) """
#     mlp = Sequential(
#         *[Sequential(hk.Linear(ch),jax.nn.swish) for _ in range(num_layers)],
#         hk.Linear(2)
#     )
#     return mlp(S)

In [7]:
# Experiment label passed to the TrainMonitor wrapper.
name = 'dqn'
# Create the CartPole-v0 environment and wrap it with coax's TrainMonitor.
env = coax.wrappers.TrainMonitor(gym.make('CartPole-v0'), name=name)

In [8]:
# Build the q-function from the forward-pass function `func` and the wrapped `env`.
q = coax.Q(func, env)

In [9]:
name = 'dqn'

# the cart-pole MDP
# NOTE: this previously built 'Pendulum-v0', contradicting the comment above.
# The rest of the notebook assumes CartPole: the network emits a 2-dim output
# (rep_out), the training loop extends the final reward as 1 / (1 - gamma), and
# early stopping reads env.spec.reward_threshold.
env = gym.make('CartPole-v0')
env = coax.wrappers.TrainMonitor(env, name=name)

# q-function built from `func`, and a Boltzmann policy derived from it
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.1)

# target network
q_targ = q.copy()

# experience tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, loss_function=mse, optimizer=adam(0.001))

In [10]:
# train
# Train for at most 10 episodes, stopping early once the running average
# return reported by the TrainMonitor exceeds the environment's threshold.
for ep in range(10):
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # extend last reward as asymptotic best-case return
        if t == env.spec.max_episode_steps - 1:
            assert done
            r = 1 / (1 - tracer.gamma)  # 1 + gamma + gamma^2 + ... = 1 / (1 - gamma)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn (only once the buffer holds enough transitions)
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        # sync target network
        q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next

    # early stopping; some environments register no reward threshold (None),
    # in which case the comparison is skipped instead of raising TypeError
    reward_threshold = env.spec.reward_threshold
    if reward_threshold is not None and env.avg_G > reward_threshold:
        break

{'QLearning/grads_max': DeviceArray(555352.9, dtype=float32), 'QLearning/grads_norm': DeviceArray(5698575., dtype=float32), 'QLearning/loss': DeviceArray(18938.887, dtype=float32), 'QLearning/td_error': DeviceArray(189.21268, dtype=float32), 'QLearning/td_error_targ': DeviceArray(0.71166587, dtype=float32)}
{'QLearning/grads_max': DeviceArray(9325980., dtype=float32), 'QLearning/grads_norm': DeviceArray(1.32853e+08, dtype=float32), 'QLearning/loss': DeviceArray(4706313., dtype=float32), 'QLearning/td_error': DeviceArray(-3056.62, dtype=float32), 'QLearning/td_error_targ': DeviceArray(102.08308, dtype=float32)}
{'QLearning/grads_max': DeviceArray(243226.66, dtype=float32), 'QLearning/grads_norm': DeviceArray(2835196.8, dtype=float32), 'QLearning/loss': DeviceArray(21723.2, dtype=float32), 'QLearning/td_error': DeviceArray(-193.78238, dtype=float32), 'QLearning/td_error_targ': DeviceArray(12.792564, dtype=float32)}
{'QLearning/grads_max': DeviceArray(109226.46, dtype=float32), 'QLearning

FlatMapping({
  'sequential/_hk_bi_linear': FlatMapping({
                                'w': DeviceArray([-1.4687121,  0.2002634, -1.6090823, ...,  0.6305096,
                                                   0.5523067,  0.6574904], dtype=float32),
                              }),
  'sequential/_hk_bi_linear_1': FlatMapping({
                                  'w': DeviceArray([ 1.5327713 ,  1.1088597 ,  0.32760307, ...,  0.5139265 ,
                                                     1.762129  , -1.0986638 ], dtype=float32),
                                }),
  'sequential/_hk_bi_linear_2': FlatMapping({
                                  'w': DeviceArray([-1.1811807 , -0.86260676,  0.32020476, ..., -0.54447836,
                                                     0.3602686 , -0.2499646 ], dtype=float32),
                                }),
  'sequential/_hk_linear': FlatMapping({
                             'b': DeviceArray([1.0034429 , 1.0027238 , 1.0049691 , 0.99364674, 1.0150

In [None]:
# Norm of the final linear layer's weights. The parameter dump in this
# notebook lists that layer as 'sequential/_hk_linear' and shows no bare
# 'linear' key, so the old lookup would raise KeyError.
# NOTE(review): the dump is truncated after 'b'; confirm the layer also has 'w'.
jnp.linalg.norm(q.params['sequential/_hk_linear']['w'])

In [None]:
# Draw a two-transition batch from the replay buffer and run a single
# Q-learning update on it; `metrics` holds whatever the updater returns.
transition_batch = buffer.sample(batch_size=2)
metrics = qlearning.update(transition_batch)