In [1]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from coax.value_losses import mse



# the name of this script
name = 'a2c'

# the cart-pole MDP
env = gym.make('CartPole-v0')
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")

In [2]:
# ## Baseline MLPS

# def func_pi(S, is_training):
#     logits = hk.Sequential((
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(env.action_space.n, w_init=jnp.zeros)
#     ))
#     return {'logits': logits(S)}


# def func_v(S, is_training):
#     value = hk.Sequential((
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(8), jax.nn.relu,
#         hk.Linear(1, w_init=jnp.zeros), jnp.ravel
#     ))
#     return value(S)



In [5]:
from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial,Z
from emlp_haiku import EMLPBlock, Sequential, Linear,EMLP
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep
from math import prod
from representations import PseudoScalar

## Trivial
# group=Trivial(2)
# rep_in = T(0)*prod(env.observation_space.shape)
# rep_out = T(0)*env.action_space.n#prod(env.action_space.shape)

## Reflection
group=Z(2)
rep_in = PseudoScalar()*prod(env.observation_space.shape)
rep_out = T(1)#*env.action_space.n#prod(env.action_space.shape)

nn_pi = EMLP(rep_in,rep_out,group,ch=100,num_layers=2)
nn_v = EMLP(rep_in,T(0),group,ch=100,num_layers=2)

def func_pi(S, is_training):
    return {'logits': nn_pi(S)}


def func_v(S, is_training):
    return nn_v(S).reshape(-1)

[a2c|root|INFO] Initing EMLP
[a2c|root|INFO] Linear W components:400 rep:96P+48P⊗V+20P⊗V²+8P⊗V³+4P⊗V⁴
[a2c|root|INFO] Linear W components:10000 rep:576V⁰+576V+384V²+216V³+121V⁴+44V⁵+14V⁶+4V⁷+V⁸
[a2c|root|INFO] Linear W components:200 rep:24V+12V²+5V³+2V⁴+V⁵
[a2c|root|INFO] Initing EMLP
[a2c|root|INFO] Linear W components:400 rep:96P+48P⊗V+20P⊗V²+8P⊗V³+4P⊗V⁴
[a2c|root|INFO] Linear W components:10000 rep:576V⁰+576V+384V²+216V³+121V⁴+44V⁵+14V⁶+4V⁷+V⁸
[a2c|root|INFO] Linear W components:100 rep:24V⁰+12V+5V²+2V³+V⁴


In [None]:

# these optimizers collect batches of grads before applying updates
optimizer_v = optax.chain(optax.apply_every(k=32), optax.adam(0.002))
optimizer_pi = optax.chain(optax.apply_every(k=32), optax.adam(0.001))


# value function and its derived policy
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

# experience tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

# updaters
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optimizer_pi)
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=optimizer_v)


# train
for ep in range(1000):
    s = env.reset()
    er = 0
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)
        
        if done and (t == env.spec.max_episode_steps - 1):
            r = 1 / (1 - tracer.gamma)
        er+=r
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
            metrics_pi = vanilla_pg.update(transition_batch, td_error)
            env.record_metrics(metrics_v)
            env.record_metrics(metrics_pi)

        if done:
            break

        s = s_next
    
    print("Epoch reward",er)
    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break


# run env one more time to render
#coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)

[a2c|TrainMonitor|INFO] ep: 1,	T: 13,	G: 12,	avg_r: 1,	avg_G: 12,	t: 12,	dt: 666.033ms,	SimpleTD/loss: 0.612,	VanillaPG/loss: 0.758
[a2c|TrainMonitor|INFO] ep: 2,	T: 26,	G: 12,	avg_r: 1,	avg_G: 12,	t: 12,	dt: 15.533ms,	SimpleTD/loss: 0.612,	VanillaPG/loss: 0.756


Epoch reward 12.0
Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 3,	T: 42,	G: 15,	avg_r: 1,	avg_G: 13,	t: 15,	dt: 14.677ms,	SimpleTD/loss: 0.529,	VanillaPG/loss: 0.718


Epoch reward 15.0


[a2c|TrainMonitor|INFO] ep: 4,	T: 67,	G: 24,	avg_r: 1,	avg_G: 15.8,	t: 24,	dt: 14.339ms,	SimpleTD/loss: 0.426,	VanillaPG/loss: 0.611


Epoch reward 24.0


[a2c|TrainMonitor|INFO] ep: 5,	T: 186,	G: 118,	avg_r: 1,	avg_G: 36.2,	t: 118,	dt: 14.087ms,	SimpleTD/loss: 0.602,	VanillaPG/loss: 0.258


Epoch reward 118.0


[a2c|TrainMonitor|INFO] ep: 6,	T: 207,	G: 20,	avg_r: 1,	avg_G: 33.5,	t: 20,	dt: 14.587ms,	SimpleTD/loss: 1.37,	VanillaPG/loss: -0.51


Epoch reward 20.0


[a2c|TrainMonitor|INFO] ep: 7,	T: 283,	G: 75,	avg_r: 1,	avg_G: 39.4,	t: 75,	dt: 14.075ms,	SimpleTD/loss: 0.406,	VanillaPG/loss: 0.194


Epoch reward 75.0


[a2c|TrainMonitor|INFO] ep: 8,	T: 352,	G: 68,	avg_r: 1,	avg_G: 43,	t: 68,	dt: 13.971ms,	SimpleTD/loss: 0.587,	VanillaPG/loss: 0.028


Epoch reward 68.0


[a2c|TrainMonitor|INFO] ep: 9,	T: 472,	G: 119,	avg_r: 1,	avg_G: 51.4,	t: 119,	dt: 13.906ms,	SimpleTD/loss: 0.646,	VanillaPG/loss: -0.0497


Epoch reward 119.0


[a2c|TrainMonitor|INFO] ep: 10,	T: 507,	G: 34,	avg_r: 1,	avg_G: 49.7,	t: 34,	dt: 13.985ms,	SimpleTD/loss: 0.398,	VanillaPG/loss: 0.0163


Epoch reward 34.0


[a2c|TrainMonitor|INFO] ep: 11,	T: 648,	G: 140,	avg_r: 1,	avg_G: 58.7,	t: 140,	dt: 13.991ms,	SimpleTD/loss: 0.0747,	VanillaPG/loss: 0.0924


Epoch reward 140.0


[a2c|TrainMonitor|INFO] ep: 12,	T: 787,	G: 138,	avg_r: 1,	avg_G: 66.7,	t: 138,	dt: 13.608ms,	SimpleTD/loss: 0.311,	VanillaPG/loss: -0.0803


Epoch reward 138.0


[a2c|TrainMonitor|INFO] ep: 13,	T: 879,	G: 91,	avg_r: 1,	avg_G: 69.1,	t: 91,	dt: 14.788ms,	SimpleTD/loss: 0.355,	VanillaPG/loss: -0.144


Epoch reward 91.0


[a2c|TrainMonitor|INFO] ep: 14,	T: 1,080,	G: 200,	avg_r: 1,	avg_G: 82.2,	t: 200,	dt: 13.738ms,	SimpleTD/loss: 0.00657,	VanillaPG/loss: 0.00404


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 15,	T: 1,199,	G: 118,	avg_r: 1,	avg_G: 85.8,	t: 118,	dt: 13.916ms,	SimpleTD/loss: 0.221,	VanillaPG/loss: -0.0165


Epoch reward 118.0


[a2c|TrainMonitor|INFO] ep: 16,	T: 1,321,	G: 121,	avg_r: 1,	avg_G: 89.3,	t: 121,	dt: 14.869ms,	SimpleTD/loss: 0.242,	VanillaPG/loss: -0.044


Epoch reward 121.0


[a2c|TrainMonitor|INFO] ep: 17,	T: 1,473,	G: 151,	avg_r: 1,	avg_G: 95.5,	t: 151,	dt: 13.777ms,	SimpleTD/loss: 0.127,	VanillaPG/loss: 0.0508


Epoch reward 151.0


[a2c|TrainMonitor|INFO] ep: 18,	T: 1,628,	G: 154,	avg_r: 1,	avg_G: 101,	t: 154,	dt: 14.604ms,	SimpleTD/loss: 0.129,	VanillaPG/loss: -0.00902


Epoch reward 154.0


[a2c|TrainMonitor|INFO] ep: 19,	T: 1,798,	G: 169,	avg_r: 1,	avg_G: 108,	t: 169,	dt: 14.025ms,	SimpleTD/loss: 0.0353,	VanillaPG/loss: 0.0285


Epoch reward 169.0


[a2c|TrainMonitor|INFO] ep: 20,	T: 1,999,	G: 200,	avg_r: 1,	avg_G: 117,	t: 200,	dt: 13.824ms,	SimpleTD/loss: 0.00728,	VanillaPG/loss: -0.0111


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 21,	T: 2,200,	G: 200,	avg_r: 1,	avg_G: 126,	t: 200,	dt: 13.666ms,	SimpleTD/loss: 5.69e-05,	VanillaPG/loss: -0.000748


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 22,	T: 2,401,	G: 200,	avg_r: 1,	avg_G: 133,	t: 200,	dt: 13.909ms,	SimpleTD/loss: 0.000882,	VanillaPG/loss: 0.00169


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 23,	T: 2,602,	G: 200,	avg_r: 1,	avg_G: 140,	t: 200,	dt: 13.988ms,	SimpleTD/loss: 0.00176,	VanillaPG/loss: -0.00571


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 24,	T: 2,803,	G: 200,	avg_r: 1,	avg_G: 146,	t: 200,	dt: 14.095ms,	SimpleTD/loss: 5.86e-05,	VanillaPG/loss: -0.00213


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 25,	T: 2,988,	G: 184,	avg_r: 1,	avg_G: 150,	t: 184,	dt: 13.803ms,	SimpleTD/loss: 0.205,	VanillaPG/loss: -0.00664


Epoch reward 184.0
