In [1]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from coax.value_losses import mse
from RPPRegularizer import RPPRegularizer

import matplotlib.pyplot as plt

# the name of this script
name = 'a2c'

# the cart-pole MDP
# env = gym.make('CartPole-v0')
env = gym.make("rpp_gym:InclinedCartpole-v0")

In [2]:
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")

In [3]:
# env.spec.max_episode_steps = 200
# env.spec.reward_threshold = 195.0

In [4]:
from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial,Z
import emlp.nn.haiku as ehk
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep
from math import prod
from representations import PseudoScalar
from mixed_emlp_haiku import MixedEMLP

## Trivial
# group=Trivial(2)
# rep_in = T(0)*prod(env.observation_space.shape)
# rep_out = T(0)*env.action_space.n#prod(env.action_space.shape)

## Reflection
group=Z(2)
rep_in = PseudoScalar()*prod(env.observation_space.shape)
rep_out = T(1)#*env.action_space.n#prod(env.action_space.shape)

nn_pi = ehk.EMLP(rep_in,rep_out,group,ch=100,num_layers=2)
nn_v = ehk.EMLP(rep_in,T(0),group,ch=100,num_layers=2)

# nn_pi = ehk.MLP(rep_in,rep_out(group),group,ch=100,num_layers=2)
# nn_v = ehk.MLP(rep_in,T(0),group,ch=100,num_layers=2)

# nn_pi = MixedEMLP(rep_in,rep_out(group),group,ch=100,num_layers=2)
# nn_v = MixedEMLP(rep_in,T(0),group,ch=100,num_layers=2)


def func_pi(S, is_training):
    return {'logits': nn_pi(S)}


def func_v(S, is_training):
    return nn_v(S).reshape(-1)



# def func_pi(S, is_training):
#     logits = hk.Sequential((
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(env.action_space.n, w_init=jnp.zeros)
#     ))
#     return {'logits': logits(S)}



# def func_v(S, is_training):
#     value = hk.Sequential((
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(1, w_init=jnp.zeros), jnp.ravel
#     ))
#     return value(S)


[a2c|root|INFO] Initing EMLP (Haiku)
[a2c|root|INFO] Linear W components:400 rep:96P+48P⊗V+20P⊗V²+8P⊗V³+4P⊗V⁴
[a2c|root|INFO] P cache miss
[a2c|root|INFO] Solving basis for P, for G=Z(2)
[a2c|root|INFO] P⊗V cache miss
[a2c|root|INFO] Solving basis for P⊗V, for G=Z(2)
[a2c|root|INFO] P⊗V² cache miss
[a2c|root|INFO] Solving basis for P⊗V², for G=Z(2)
[a2c|root|INFO] P⊗V³ cache miss
[a2c|root|INFO] Solving basis for P⊗V³, for G=Z(2)
[a2c|root|INFO] P⊗V⁴ cache miss
[a2c|root|INFO] Solving basis for P⊗V⁴, for G=Z(2)
[a2c|root|INFO] V cache miss
[a2c|root|INFO] Solving basis for V, for G=Z(2)
[a2c|root|INFO] V² cache miss
[a2c|root|INFO] Solving basis for V², for G=Z(2)
[a2c|root|INFO] V³ cache miss
[a2c|root|INFO] Solving basis for V³, for G=Z(2)
[a2c|root|INFO] V⁴ cache miss
[a2c|root|INFO] Solving basis for V⁴, for G=Z(2)
[a2c|root|INFO] Linear W components:10000 rep:576V⁰+576V+384V²+216V³+121V⁴+44V⁵+14V⁶+4V⁷+V⁸
[a2c|root|INFO] V⁵ cache miss
[a2c|root|INFO] Solving basis for V⁵, for G=Z(2

In [5]:
# these optimizers collect batches of grads before applying updates
optimizer_v = optax.chain(optax.apply_every(k=32), optax.adam(0.002))
optimizer_pi = optax.chain(optax.apply_every(k=32), optax.adam(0.001))


# value function and its derived policy
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

In [6]:
# experience tracer
pi_regularizer = RPPRegularizer(pi, basic_wd=1e-4, equiv_wd=1e-4)
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

# updaters
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optimizer_pi,
                                              regularizer=pi_regularizer)
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=optimizer_v)

epoch_rewards = []

# train
for ep in range(1000):
    s = env.reset()
    er = 0
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)
        
        if done and (t == env.spec.max_episode_steps - 1):
            r = 1 / (1 - tracer.gamma)
        er+=r
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
            metrics_pi = vanilla_pg.update(transition_batch, td_error)
            env.record_metrics(metrics_v)
            env.record_metrics(metrics_pi)

        if done:
            break

        s = s_next
    
    print("Epoch reward",er)
    epoch_rewards.append(er)
    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break


# run env one more time to render
#coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)

[a2c|TrainMonitor|INFO] ep: 1,	T: 12,	G: 11,	avg_r: 1,	avg_G: 11,	t: 11,	dt: 781.796ms,	SimpleTD/loss: 0.349,	VanillaPG/loss: 1


Epoch reward 11.0


[a2c|TrainMonitor|INFO] ep: 2,	T: 71,	G: 58,	avg_r: 1,	avg_G: 34.5,	t: 58,	dt: 27.695ms,	SimpleTD/loss: 0.371,	VanillaPG/loss: 0.974


Epoch reward 58.0


[a2c|TrainMonitor|INFO] ep: 3,	T: 95,	G: 23,	avg_r: 1,	avg_G: 30.7,	t: 23,	dt: 29.031ms,	SimpleTD/loss: 0.755,	VanillaPG/loss: 0.838


Epoch reward 23.0


[a2c|TrainMonitor|INFO] ep: 4,	T: 126,	G: 30,	avg_r: 1,	avg_G: 30.5,	t: 30,	dt: 32.820ms,	SimpleTD/loss: 0.658,	VanillaPG/loss: 0.631


Epoch reward 30.0


[a2c|TrainMonitor|INFO] ep: 5,	T: 142,	G: 15,	avg_r: 1,	avg_G: 27.4,	t: 15,	dt: 32.072ms,	SimpleTD/loss: 1.74,	VanillaPG/loss: 0.397


Epoch reward 15.0


[a2c|TrainMonitor|INFO] ep: 6,	T: 155,	G: 12,	avg_r: 1,	avg_G: 24.8,	t: 12,	dt: 32.732ms,	SimpleTD/loss: 2.87,	VanillaPG/loss: 0.232


Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 7,	T: 174,	G: 18,	avg_r: 1,	avg_G: 23.9,	t: 18,	dt: 32.518ms,	SimpleTD/loss: 0.878,	VanillaPG/loss: 0.37


Epoch reward 18.0


[a2c|TrainMonitor|INFO] ep: 8,	T: 184,	G: 9,	avg_r: 1,	avg_G: 22,	t: 9,	dt: 33.430ms,	SimpleTD/loss: 1.02,	VanillaPG/loss: 0.553


Epoch reward 9.0


[a2c|TrainMonitor|INFO] ep: 9,	T: 206,	G: 21,	avg_r: 1,	avg_G: 21.9,	t: 21,	dt: 31.557ms,	SimpleTD/loss: 0.5,	VanillaPG/loss: 0.745


Epoch reward 21.0


[a2c|TrainMonitor|INFO] ep: 10,	T: 244,	G: 37,	avg_r: 1,	avg_G: 23.4,	t: 37,	dt: 32.107ms,	SimpleTD/loss: 0.422,	VanillaPG/loss: 0.694


Epoch reward 37.0


[a2c|TrainMonitor|INFO] ep: 11,	T: 263,	G: 18,	avg_r: 1,	avg_G: 22.9,	t: 18,	dt: 32.423ms,	SimpleTD/loss: 0.467,	VanillaPG/loss: 0.412


Epoch reward 18.0


[a2c|TrainMonitor|INFO] ep: 12,	T: 279,	G: 15,	avg_r: 1,	avg_G: 22.1,	t: 15,	dt: 32.569ms,	SimpleTD/loss: 0.153,	VanillaPG/loss: 0.565


Epoch reward 15.0


[a2c|TrainMonitor|INFO] ep: 13,	T: 297,	G: 17,	avg_r: 1,	avg_G: 21.6,	t: 17,	dt: 32.002ms,	SimpleTD/loss: 0.508,	VanillaPG/loss: 0.522


Epoch reward 17.0


[a2c|TrainMonitor|INFO] ep: 14,	T: 310,	G: 12,	avg_r: 1,	avg_G: 20.6,	t: 12,	dt: 32.540ms,	SimpleTD/loss: 0.934,	VanillaPG/loss: 0.889


Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 15,	T: 322,	G: 11,	avg_r: 1,	avg_G: 19.6,	t: 11,	dt: 31.599ms,	SimpleTD/loss: 7.49,	VanillaPG/loss: 0.608


Epoch reward 11.0


[a2c|TrainMonitor|INFO] ep: 16,	T: 335,	G: 12,	avg_r: 1,	avg_G: 18.9,	t: 12,	dt: 33.140ms,	SimpleTD/loss: 13.1,	VanillaPG/loss: 0.676


Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 17,	T: 346,	G: 10,	avg_r: 1,	avg_G: 18,	t: 10,	dt: 34.815ms,	SimpleTD/loss: 0.282,	VanillaPG/loss: 0.408


Epoch reward 10.0


[a2c|TrainMonitor|INFO] ep: 18,	T: 366,	G: 19,	avg_r: 1,	avg_G: 18.1,	t: 19,	dt: 31.860ms,	SimpleTD/loss: 0.517,	VanillaPG/loss: 0.526


Epoch reward 19.0


[a2c|TrainMonitor|INFO] ep: 19,	T: 383,	G: 16,	avg_r: 1,	avg_G: 17.9,	t: 16,	dt: 33.490ms,	SimpleTD/loss: 0.716,	VanillaPG/loss: 0.536


Epoch reward 16.0


[a2c|TrainMonitor|INFO] ep: 20,	T: 408,	G: 24,	avg_r: 1,	avg_G: 18.5,	t: 24,	dt: 34.548ms,	SimpleTD/loss: 0.256,	VanillaPG/loss: 0.575


Epoch reward 24.0


[a2c|TrainMonitor|INFO] ep: 21,	T: 419,	G: 10,	avg_r: 1,	avg_G: 17.6,	t: 10,	dt: 33.563ms,	SimpleTD/loss: 1.27,	VanillaPG/loss: 0.286


Epoch reward 10.0


[a2c|TrainMonitor|INFO] ep: 22,	T: 436,	G: 16,	avg_r: 1,	avg_G: 17.5,	t: 16,	dt: 32.444ms,	SimpleTD/loss: 1.34,	VanillaPG/loss: 0.355


Epoch reward 16.0


[a2c|TrainMonitor|INFO] ep: 23,	T: 449,	G: 12,	avg_r: 1,	avg_G: 16.9,	t: 12,	dt: 33.376ms,	SimpleTD/loss: 0.236,	VanillaPG/loss: 0.439


Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 24,	T: 462,	G: 12,	avg_r: 1,	avg_G: 16.4,	t: 12,	dt: 33.570ms,	SimpleTD/loss: 0.399,	VanillaPG/loss: 0.549


Epoch reward 12.0


[a2c|TrainMonitor|INFO] ep: 25,	T: 473,	G: 10,	avg_r: 1,	avg_G: 15.8,	t: 10,	dt: 33.619ms,	SimpleTD/loss: 0.517,	VanillaPG/loss: 0.548


Epoch reward 10.0


[a2c|TrainMonitor|INFO] ep: 26,	T: 501,	G: 27,	avg_r: 1,	avg_G: 16.9,	t: 27,	dt: 32.895ms,	SimpleTD/loss: 0.213,	VanillaPG/loss: 0.703


Epoch reward 27.0


[a2c|TrainMonitor|INFO] ep: 27,	T: 519,	G: 17,	avg_r: 1,	avg_G: 16.9,	t: 17,	dt: 32.182ms,	SimpleTD/loss: 0.387,	VanillaPG/loss: 0.465


Epoch reward 17.0


[a2c|TrainMonitor|INFO] ep: 28,	T: 535,	G: 15,	avg_r: 1,	avg_G: 16.7,	t: 15,	dt: 32.902ms,	SimpleTD/loss: 0.711,	VanillaPG/loss: 0.287


Epoch reward 15.0


[a2c|TrainMonitor|INFO] ep: 29,	T: 623,	G: 87,	avg_r: 1,	avg_G: 23.8,	t: 87,	dt: 33.137ms,	SimpleTD/loss: 0.61,	VanillaPG/loss: 0.425


Epoch reward 87.0


[a2c|TrainMonitor|INFO] ep: 30,	T: 640,	G: 16,	avg_r: 1,	avg_G: 23,	t: 16,	dt: 32.892ms,	SimpleTD/loss: 1.81,	VanillaPG/loss: 0.176


Epoch reward 16.0


[a2c|TrainMonitor|INFO] ep: 31,	T: 682,	G: 41,	avg_r: 1,	avg_G: 24.8,	t: 41,	dt: 32.496ms,	SimpleTD/loss: 0.379,	VanillaPG/loss: 0.557


Epoch reward 41.0


[a2c|TrainMonitor|INFO] ep: 32,	T: 719,	G: 36,	avg_r: 1,	avg_G: 25.9,	t: 36,	dt: 33.911ms,	SimpleTD/loss: 0.589,	VanillaPG/loss: 0.586


Epoch reward 36.0


[a2c|TrainMonitor|INFO] ep: 33,	T: 741,	G: 21,	avg_r: 1,	avg_G: 25.4,	t: 21,	dt: 32.159ms,	SimpleTD/loss: 0.755,	VanillaPG/loss: 0.463


Epoch reward 21.0


[a2c|TrainMonitor|INFO] ep: 34,	T: 780,	G: 38,	avg_r: 1,	avg_G: 26.7,	t: 38,	dt: 34.384ms,	SimpleTD/loss: 0.63,	VanillaPG/loss: 0.257


Epoch reward 38.0


[a2c|TrainMonitor|INFO] ep: 35,	T: 820,	G: 39,	avg_r: 1,	avg_G: 27.9,	t: 39,	dt: 32.669ms,	SimpleTD/loss: 0.616,	VanillaPG/loss: 0.243


Epoch reward 39.0


[a2c|TrainMonitor|INFO] ep: 36,	T: 895,	G: 74,	avg_r: 1,	avg_G: 32.5,	t: 74,	dt: 31.926ms,	SimpleTD/loss: 0.58,	VanillaPG/loss: 0.208


Epoch reward 74.0


[a2c|TrainMonitor|INFO] ep: 37,	T: 990,	G: 94,	avg_r: 1,	avg_G: 38.7,	t: 94,	dt: 32.859ms,	SimpleTD/loss: 0.522,	VanillaPG/loss: 0.52


Epoch reward 94.0


[a2c|TrainMonitor|INFO] ep: 38,	T: 1,070,	G: 79,	avg_r: 1,	avg_G: 42.7,	t: 79,	dt: 33.031ms,	SimpleTD/loss: 0.45,	VanillaPG/loss: 0.531


Epoch reward 79.0


[a2c|TrainMonitor|INFO] ep: 39,	T: 1,168,	G: 97,	avg_r: 1,	avg_G: 48.1,	t: 97,	dt: 32.873ms,	SimpleTD/loss: 0.37,	VanillaPG/loss: 0.493


Epoch reward 97.0


[a2c|TrainMonitor|INFO] ep: 40,	T: 1,215,	G: 46,	avg_r: 1,	avg_G: 47.9,	t: 46,	dt: 32.414ms,	SimpleTD/loss: 0.473,	VanillaPG/loss: 0.488


Epoch reward 46.0


[a2c|TrainMonitor|INFO] ep: 41,	T: 1,271,	G: 55,	avg_r: 1,	avg_G: 48.6,	t: 55,	dt: 32.727ms,	SimpleTD/loss: 0.562,	VanillaPG/loss: 0.542


Epoch reward 55.0


[a2c|TrainMonitor|INFO] ep: 42,	T: 1,373,	G: 101,	avg_r: 1,	avg_G: 53.9,	t: 101,	dt: 32.076ms,	SimpleTD/loss: 0.41,	VanillaPG/loss: 0.513


Epoch reward 101.0


[a2c|TrainMonitor|INFO] ep: 43,	T: 1,459,	G: 85,	avg_r: 1,	avg_G: 57,	t: 85,	dt: 32.906ms,	SimpleTD/loss: 0.407,	VanillaPG/loss: 0.518


Epoch reward 85.0


[a2c|TrainMonitor|INFO] ep: 44,	T: 1,517,	G: 57,	avg_r: 1,	avg_G: 57,	t: 57,	dt: 32.785ms,	SimpleTD/loss: 0.409,	VanillaPG/loss: 0.456


Epoch reward 57.0


[a2c|TrainMonitor|INFO] ep: 45,	T: 1,609,	G: 91,	avg_r: 1,	avg_G: 60.4,	t: 91,	dt: 32.846ms,	SimpleTD/loss: 0.364,	VanillaPG/loss: 0.523


Epoch reward 91.0


[a2c|TrainMonitor|INFO] ep: 46,	T: 1,682,	G: 72,	avg_r: 1,	avg_G: 61.5,	t: 72,	dt: 32.691ms,	SimpleTD/loss: 0.478,	VanillaPG/loss: 0.353


Epoch reward 72.0


[a2c|TrainMonitor|INFO] ep: 47,	T: 1,727,	G: 44,	avg_r: 1,	avg_G: 59.8,	t: 44,	dt: 32.194ms,	SimpleTD/loss: 0.443,	VanillaPG/loss: 0.479


Epoch reward 44.0


[a2c|TrainMonitor|INFO] ep: 48,	T: 1,798,	G: 70,	avg_r: 1,	avg_G: 60.8,	t: 70,	dt: 32.459ms,	SimpleTD/loss: 0.354,	VanillaPG/loss: 0.527


Epoch reward 70.0


[a2c|TrainMonitor|INFO] ep: 49,	T: 1,860,	G: 61,	avg_r: 1,	avg_G: 60.8,	t: 61,	dt: 32.318ms,	SimpleTD/loss: 0.38,	VanillaPG/loss: 0.486


Epoch reward 61.0


[a2c|TrainMonitor|INFO] ep: 50,	T: 1,936,	G: 75,	avg_r: 1,	avg_G: 62.2,	t: 75,	dt: 32.248ms,	SimpleTD/loss: 0.238,	VanillaPG/loss: 0.517


Epoch reward 75.0


[a2c|TrainMonitor|INFO] ep: 51,	T: 1,990,	G: 53,	avg_r: 1,	avg_G: 61.3,	t: 53,	dt: 32.103ms,	SimpleTD/loss: 0.363,	VanillaPG/loss: 0.403


Epoch reward 53.0


[a2c|TrainMonitor|INFO] ep: 52,	T: 2,090,	G: 99,	avg_r: 1,	avg_G: 65.1,	t: 99,	dt: 32.539ms,	SimpleTD/loss: 0.042,	VanillaPG/loss: 0.48


Epoch reward 99.0


[a2c|TrainMonitor|INFO] ep: 53,	T: 2,179,	G: 88,	avg_r: 1,	avg_G: 67.4,	t: 88,	dt: 33.312ms,	SimpleTD/loss: 0.0269,	VanillaPG/loss: 0.487


Epoch reward 88.0


[a2c|TrainMonitor|INFO] ep: 54,	T: 2,255,	G: 75,	avg_r: 1,	avg_G: 68.1,	t: 75,	dt: 32.101ms,	SimpleTD/loss: 0.0442,	VanillaPG/loss: 0.496


Epoch reward 75.0


[a2c|TrainMonitor|INFO] ep: 55,	T: 2,348,	G: 92,	avg_r: 1,	avg_G: 70.5,	t: 92,	dt: 32.275ms,	SimpleTD/loss: 0.0137,	VanillaPG/loss: 0.509


Epoch reward 92.0


[a2c|TrainMonitor|INFO] ep: 56,	T: 2,416,	G: 67,	avg_r: 1,	avg_G: 70.2,	t: 67,	dt: 32.982ms,	SimpleTD/loss: 0.00824,	VanillaPG/loss: 0.512


Epoch reward 67.0


[a2c|TrainMonitor|INFO] ep: 57,	T: 2,472,	G: 55,	avg_r: 1,	avg_G: 68.7,	t: 55,	dt: 31.775ms,	SimpleTD/loss: 0.008,	VanillaPG/loss: 0.502


Epoch reward 55.0


[a2c|TrainMonitor|INFO] ep: 58,	T: 2,550,	G: 77,	avg_r: 1,	avg_G: 69.5,	t: 77,	dt: 32.824ms,	SimpleTD/loss: 0.0273,	VanillaPG/loss: 0.493


Epoch reward 77.0


[a2c|TrainMonitor|INFO] ep: 59,	T: 2,613,	G: 62,	avg_r: 1,	avg_G: 68.7,	t: 62,	dt: 32.319ms,	SimpleTD/loss: 0.0254,	VanillaPG/loss: 0.494


Epoch reward 62.0


[a2c|TrainMonitor|INFO] ep: 60,	T: 2,669,	G: 55,	avg_r: 1,	avg_G: 67.4,	t: 55,	dt: 32.414ms,	SimpleTD/loss: 0.00634,	VanillaPG/loss: 0.513


Epoch reward 55.0


[a2c|TrainMonitor|INFO] ep: 61,	T: 2,724,	G: 54,	avg_r: 1,	avg_G: 66,	t: 54,	dt: 33.210ms,	SimpleTD/loss: 0.0296,	VanillaPG/loss: 0.506


Epoch reward 54.0


[a2c|TrainMonitor|INFO] ep: 62,	T: 2,779,	G: 54,	avg_r: 1,	avg_G: 64.8,	t: 54,	dt: 32.184ms,	SimpleTD/loss: 0.00894,	VanillaPG/loss: 0.507


Epoch reward 54.0


[a2c|TrainMonitor|INFO] ep: 63,	T: 2,970,	G: 190,	avg_r: 1,	avg_G: 77.3,	t: 190,	dt: 31.868ms,	SimpleTD/loss: 0.00973,	VanillaPG/loss: 0.512


Epoch reward 190.0


[a2c|TrainMonitor|INFO] ep: 64,	T: 3,062,	G: 91,	avg_r: 1,	avg_G: 78.7,	t: 91,	dt: 32.944ms,	SimpleTD/loss: 0.00532,	VanillaPG/loss: 0.503


Epoch reward 91.0


[a2c|TrainMonitor|INFO] ep: 65,	T: 3,130,	G: 67,	avg_r: 1,	avg_G: 77.5,	t: 67,	dt: 33.128ms,	SimpleTD/loss: 0.0756,	VanillaPG/loss: 0.508


Epoch reward 67.0


[a2c|TrainMonitor|INFO] ep: 66,	T: 3,186,	G: 55,	avg_r: 1,	avg_G: 75.3,	t: 55,	dt: 31.742ms,	SimpleTD/loss: 0.031,	VanillaPG/loss: 0.508


Epoch reward 55.0


[a2c|TrainMonitor|INFO] ep: 67,	T: 3,283,	G: 96,	avg_r: 1,	avg_G: 77.4,	t: 96,	dt: 32.374ms,	SimpleTD/loss: 0.0135,	VanillaPG/loss: 0.511


Epoch reward 96.0


[a2c|TrainMonitor|INFO] ep: 68,	T: 3,353,	G: 69,	avg_r: 1,	avg_G: 76.5,	t: 69,	dt: 32.295ms,	SimpleTD/loss: 0.102,	VanillaPG/loss: 0.505


Epoch reward 69.0


[a2c|TrainMonitor|INFO] ep: 69,	T: 3,433,	G: 79,	avg_r: 1,	avg_G: 76.8,	t: 79,	dt: 31.699ms,	SimpleTD/loss: 0.0466,	VanillaPG/loss: 0.491


Epoch reward 79.0


[a2c|TrainMonitor|INFO] ep: 70,	T: 3,524,	G: 90,	avg_r: 1,	avg_G: 78.1,	t: 90,	dt: 32.349ms,	SimpleTD/loss: 0.0321,	VanillaPG/loss: 0.52


Epoch reward 90.0


[a2c|TrainMonitor|INFO] ep: 71,	T: 3,643,	G: 118,	avg_r: 1,	avg_G: 82.1,	t: 118,	dt: 31.840ms,	SimpleTD/loss: 0.196,	VanillaPG/loss: 0.509


Epoch reward 118.0


[a2c|TrainMonitor|INFO] ep: 72,	T: 3,785,	G: 141,	avg_r: 1,	avg_G: 88,	t: 141,	dt: 32.174ms,	SimpleTD/loss: 0.0143,	VanillaPG/loss: 0.523


Epoch reward 141.0


[a2c|TrainMonitor|INFO] ep: 73,	T: 3,855,	G: 69,	avg_r: 1,	avg_G: 86.1,	t: 69,	dt: 30.399ms,	SimpleTD/loss: 0.282,	VanillaPG/loss: 0.508


Epoch reward 69.0


[a2c|TrainMonitor|INFO] ep: 74,	T: 3,960,	G: 104,	avg_r: 1,	avg_G: 87.9,	t: 104,	dt: 32.014ms,	SimpleTD/loss: 0.154,	VanillaPG/loss: 0.52


Epoch reward 104.0


[a2c|TrainMonitor|INFO] ep: 75,	T: 4,042,	G: 81,	avg_r: 1,	avg_G: 87.2,	t: 81,	dt: 32.220ms,	SimpleTD/loss: 0.063,	VanillaPG/loss: 0.514


Epoch reward 81.0


[a2c|TrainMonitor|INFO] ep: 76,	T: 4,156,	G: 113,	avg_r: 1,	avg_G: 89.8,	t: 113,	dt: 32.458ms,	SimpleTD/loss: 0.0291,	VanillaPG/loss: 0.511


Epoch reward 113.0


[a2c|TrainMonitor|INFO] ep: 77,	T: 4,244,	G: 87,	avg_r: 1,	avg_G: 89.5,	t: 87,	dt: 32.232ms,	SimpleTD/loss: 0.046,	VanillaPG/loss: 0.513


Epoch reward 87.0


[a2c|TrainMonitor|INFO] ep: 78,	T: 4,363,	G: 118,	avg_r: 1,	avg_G: 92.3,	t: 118,	dt: 32.287ms,	SimpleTD/loss: 0.0303,	VanillaPG/loss: 0.517


Epoch reward 118.0


[a2c|TrainMonitor|INFO] ep: 79,	T: 4,462,	G: 98,	avg_r: 1,	avg_G: 92.9,	t: 98,	dt: 32.298ms,	SimpleTD/loss: 0.144,	VanillaPG/loss: 0.512


Epoch reward 98.0


[a2c|TrainMonitor|INFO] ep: 80,	T: 4,579,	G: 116,	avg_r: 1,	avg_G: 95.2,	t: 116,	dt: 32.232ms,	SimpleTD/loss: 0.0507,	VanillaPG/loss: 0.515


Epoch reward 116.0


[a2c|TrainMonitor|INFO] ep: 81,	T: 4,673,	G: 93,	avg_r: 1,	avg_G: 95,	t: 93,	dt: 32.958ms,	SimpleTD/loss: 0.169,	VanillaPG/loss: 0.515


Epoch reward 93.0


[a2c|TrainMonitor|INFO] ep: 82,	T: 4,767,	G: 93,	avg_r: 1,	avg_G: 94.8,	t: 93,	dt: 32.177ms,	SimpleTD/loss: 0.0975,	VanillaPG/loss: 0.516


Epoch reward 93.0


[a2c|TrainMonitor|INFO] ep: 83,	T: 4,968,	G: 200,	avg_r: 1,	avg_G: 105,	t: 200,	dt: 32.387ms,	SimpleTD/loss: 0.00355,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 84,	T: 5,063,	G: 94,	avg_r: 1,	avg_G: 104,	t: 94,	dt: 32.087ms,	SimpleTD/loss: 0.0248,	VanillaPG/loss: 0.514


Epoch reward 94.0


[a2c|TrainMonitor|INFO] ep: 85,	T: 5,139,	G: 75,	avg_r: 1,	avg_G: 101,	t: 75,	dt: 31.875ms,	SimpleTD/loss: 0.384,	VanillaPG/loss: 0.515


Epoch reward 75.0


[a2c|TrainMonitor|INFO] ep: 86,	T: 5,276,	G: 136,	avg_r: 1,	avg_G: 105,	t: 136,	dt: 32.073ms,	SimpleTD/loss: 0.123,	VanillaPG/loss: 0.521


Epoch reward 136.0


[a2c|TrainMonitor|INFO] ep: 87,	T: 5,439,	G: 162,	avg_r: 1,	avg_G: 110,	t: 162,	dt: 30.874ms,	SimpleTD/loss: 0.293,	VanillaPG/loss: 0.375


Epoch reward 162.0


[a2c|TrainMonitor|INFO] ep: 88,	T: 5,509,	G: 69,	avg_r: 1,	avg_G: 106,	t: 69,	dt: 32.926ms,	SimpleTD/loss: 0.36,	VanillaPG/loss: 0.521


Epoch reward 69.0


[a2c|TrainMonitor|INFO] ep: 89,	T: 5,571,	G: 61,	avg_r: 1,	avg_G: 102,	t: 61,	dt: 32.088ms,	SimpleTD/loss: 0.417,	VanillaPG/loss: 0.521


Epoch reward 61.0


[a2c|TrainMonitor|INFO] ep: 90,	T: 5,649,	G: 77,	avg_r: 1,	avg_G: 99.3,	t: 77,	dt: 31.544ms,	SimpleTD/loss: 0.227,	VanillaPG/loss: 0.516


Epoch reward 77.0


[a2c|TrainMonitor|INFO] ep: 91,	T: 5,694,	G: 44,	avg_r: 1,	avg_G: 93.8,	t: 44,	dt: 32.294ms,	SimpleTD/loss: 0.45,	VanillaPG/loss: 0.491


Epoch reward 44.0


[a2c|TrainMonitor|INFO] ep: 92,	T: 5,752,	G: 57,	avg_r: 1,	avg_G: 90.1,	t: 57,	dt: 27.919ms,	SimpleTD/loss: 0.161,	VanillaPG/loss: 0.516


Epoch reward 57.0


[a2c|TrainMonitor|INFO] ep: 93,	T: 5,816,	G: 63,	avg_r: 1,	avg_G: 87.4,	t: 63,	dt: 31.665ms,	SimpleTD/loss: 0.132,	VanillaPG/loss: 0.512


Epoch reward 63.0


[a2c|TrainMonitor|INFO] ep: 94,	T: 5,884,	G: 67,	avg_r: 1,	avg_G: 85.4,	t: 67,	dt: 32.821ms,	SimpleTD/loss: 0.0419,	VanillaPG/loss: 0.519


Epoch reward 67.0


[a2c|TrainMonitor|INFO] ep: 95,	T: 6,003,	G: 118,	avg_r: 1,	avg_G: 88.6,	t: 118,	dt: 32.627ms,	SimpleTD/loss: 0.0179,	VanillaPG/loss: 0.517


Epoch reward 118.0


[a2c|TrainMonitor|INFO] ep: 96,	T: 6,075,	G: 71,	avg_r: 1,	avg_G: 86.9,	t: 71,	dt: 32.158ms,	SimpleTD/loss: 0.0289,	VanillaPG/loss: 0.517


Epoch reward 71.0


[a2c|TrainMonitor|INFO] ep: 97,	T: 6,191,	G: 115,	avg_r: 1,	avg_G: 89.7,	t: 115,	dt: 31.625ms,	SimpleTD/loss: 0.000993,	VanillaPG/loss: 0.516


Epoch reward 115.0


[a2c|TrainMonitor|INFO] ep: 98,	T: 6,269,	G: 77,	avg_r: 1,	avg_G: 88.4,	t: 77,	dt: 32.261ms,	SimpleTD/loss: 0.00966,	VanillaPG/loss: 0.503


Epoch reward 77.0


[a2c|TrainMonitor|INFO] ep: 99,	T: 6,334,	G: 64,	avg_r: 1,	avg_G: 86,	t: 64,	dt: 31.304ms,	SimpleTD/loss: 0.00669,	VanillaPG/loss: 0.515


Epoch reward 64.0


[a2c|TrainMonitor|INFO] ep: 100,	T: 6,535,	G: 200,	avg_r: 1,	avg_G: 97.4,	t: 200,	dt: 31.392ms,	SimpleTD/loss: 0.000311,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 101,	T: 6,726,	G: 190,	avg_r: 1,	avg_G: 107,	t: 190,	dt: 32.375ms,	SimpleTD/loss: 0.0568,	VanillaPG/loss: 0.516


Epoch reward 190.0


[a2c|TrainMonitor|INFO] ep: 102,	T: 6,782,	G: 55,	avg_r: 1,	avg_G: 101,	t: 55,	dt: 32.108ms,	SimpleTD/loss: 0.00773,	VanillaPG/loss: 0.518


Epoch reward 55.0


[a2c|TrainMonitor|INFO] ep: 103,	T: 6,983,	G: 200,	avg_r: 1,	avg_G: 111,	t: 200,	dt: 29.822ms,	SimpleTD/loss: 0.0211,	VanillaPG/loss: 0.517


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 104,	T: 7,093,	G: 109,	avg_r: 1,	avg_G: 111,	t: 109,	dt: 28.193ms,	SimpleTD/loss: 0.0782,	VanillaPG/loss: 0.517


Epoch reward 109.0


[a2c|TrainMonitor|INFO] ep: 105,	T: 7,180,	G: 86,	avg_r: 1,	avg_G: 109,	t: 86,	dt: 27.800ms,	SimpleTD/loss: 0.0129,	VanillaPG/loss: 0.51


Epoch reward 86.0


[a2c|TrainMonitor|INFO] ep: 106,	T: 7,314,	G: 133,	avg_r: 1,	avg_G: 111,	t: 133,	dt: 31.271ms,	SimpleTD/loss: 0.0267,	VanillaPG/loss: 0.515


Epoch reward 133.0


[a2c|TrainMonitor|INFO] ep: 107,	T: 7,385,	G: 70,	avg_r: 1,	avg_G: 107,	t: 70,	dt: 32.685ms,	SimpleTD/loss: 0.0327,	VanillaPG/loss: 0.518


Epoch reward 70.0


[a2c|TrainMonitor|INFO] ep: 108,	T: 7,458,	G: 72,	avg_r: 1,	avg_G: 103,	t: 72,	dt: 32.751ms,	SimpleTD/loss: 0.296,	VanillaPG/loss: 0.516


Epoch reward 72.0


[a2c|TrainMonitor|INFO] ep: 109,	T: 7,605,	G: 146,	avg_r: 1,	avg_G: 108,	t: 146,	dt: 32.741ms,	SimpleTD/loss: 0.0228,	VanillaPG/loss: 0.517


Epoch reward 146.0


[a2c|TrainMonitor|INFO] ep: 110,	T: 7,731,	G: 125,	avg_r: 1,	avg_G: 109,	t: 125,	dt: 32.514ms,	SimpleTD/loss: 0.0205,	VanillaPG/loss: 0.514


Epoch reward 125.0


[a2c|TrainMonitor|INFO] ep: 111,	T: 7,852,	G: 120,	avg_r: 1,	avg_G: 110,	t: 120,	dt: 32.058ms,	SimpleTD/loss: 0.163,	VanillaPG/loss: 0.517


Epoch reward 120.0


[a2c|TrainMonitor|INFO] ep: 112,	T: 7,916,	G: 63,	avg_r: 1,	avg_G: 106,	t: 63,	dt: 33.083ms,	SimpleTD/loss: 0.0684,	VanillaPG/loss: 0.517


Epoch reward 63.0


[a2c|TrainMonitor|INFO] ep: 113,	T: 7,984,	G: 67,	avg_r: 1,	avg_G: 102,	t: 67,	dt: 32.664ms,	SimpleTD/loss: 0.0199,	VanillaPG/loss: 0.527


Epoch reward 67.0


[a2c|TrainMonitor|INFO] ep: 114,	T: 8,160,	G: 175,	avg_r: 1,	avg_G: 109,	t: 175,	dt: 32.588ms,	SimpleTD/loss: 0.0051,	VanillaPG/loss: 0.517


Epoch reward 175.0


[a2c|TrainMonitor|INFO] ep: 115,	T: 8,337,	G: 176,	avg_r: 1,	avg_G: 116,	t: 176,	dt: 31.733ms,	SimpleTD/loss: 0.0285,	VanillaPG/loss: 0.517


Epoch reward 176.0


[a2c|TrainMonitor|INFO] ep: 116,	T: 8,521,	G: 183,	avg_r: 1,	avg_G: 123,	t: 183,	dt: 32.438ms,	SimpleTD/loss: 0.00384,	VanillaPG/loss: 0.516


Epoch reward 183.0


[a2c|TrainMonitor|INFO] ep: 117,	T: 8,606,	G: 84,	avg_r: 1,	avg_G: 119,	t: 84,	dt: 32.266ms,	SimpleTD/loss: 0.106,	VanillaPG/loss: 0.515


Epoch reward 84.0


[a2c|TrainMonitor|INFO] ep: 118,	T: 8,727,	G: 120,	avg_r: 1,	avg_G: 119,	t: 120,	dt: 31.630ms,	SimpleTD/loss: 0.0267,	VanillaPG/loss: 0.52


Epoch reward 120.0


[a2c|TrainMonitor|INFO] ep: 119,	T: 8,845,	G: 117,	avg_r: 1,	avg_G: 119,	t: 117,	dt: 31.335ms,	SimpleTD/loss: 0.117,	VanillaPG/loss: 0.516


Epoch reward 117.0


[a2c|TrainMonitor|INFO] ep: 120,	T: 8,994,	G: 148,	avg_r: 1,	avg_G: 122,	t: 148,	dt: 27.795ms,	SimpleTD/loss: 0.0333,	VanillaPG/loss: 0.517


Epoch reward 148.0


[a2c|TrainMonitor|INFO] ep: 121,	T: 9,106,	G: 111,	avg_r: 1,	avg_G: 121,	t: 111,	dt: 31.457ms,	SimpleTD/loss: 0.0196,	VanillaPG/loss: 0.517


Epoch reward 111.0


[a2c|TrainMonitor|INFO] ep: 122,	T: 9,200,	G: 93,	avg_r: 1,	avg_G: 118,	t: 93,	dt: 32.024ms,	SimpleTD/loss: 0.137,	VanillaPG/loss: 0.517


Epoch reward 93.0


[a2c|TrainMonitor|INFO] ep: 123,	T: 9,349,	G: 148,	avg_r: 1,	avg_G: 121,	t: 148,	dt: 31.956ms,	SimpleTD/loss: 0.0178,	VanillaPG/loss: 0.512


Epoch reward 148.0


[a2c|TrainMonitor|INFO] ep: 124,	T: 9,550,	G: 200,	avg_r: 1,	avg_G: 129,	t: 200,	dt: 32.403ms,	SimpleTD/loss: 0.000969,	VanillaPG/loss: 0.517


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 125,	T: 9,712,	G: 161,	avg_r: 1,	avg_G: 132,	t: 161,	dt: 32.308ms,	SimpleTD/loss: 0.1,	VanillaPG/loss: 0.518


Epoch reward 161.0


[a2c|TrainMonitor|INFO] ep: 126,	T: 9,913,	G: 200,	avg_r: 1,	avg_G: 139,	t: 200,	dt: 32.427ms,	SimpleTD/loss: 0.0331,	VanillaPG/loss: 0.563


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 127,	T: 10,114,	G: 200,	avg_r: 1,	avg_G: 145,	t: 200,	dt: 31.035ms,	SimpleTD/loss: 0.0111,	VanillaPG/loss: 0.511


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 128,	T: 10,315,	G: 200,	avg_r: 1,	avg_G: 150,	t: 200,	dt: 29.520ms,	SimpleTD/loss: 0.000332,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 129,	T: 10,516,	G: 200,	avg_r: 1,	avg_G: 155,	t: 200,	dt: 29.264ms,	SimpleTD/loss: 6.81e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 130,	T: 10,717,	G: 200,	avg_r: 1,	avg_G: 160,	t: 200,	dt: 30.888ms,	SimpleTD/loss: 4.27e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 131,	T: 10,918,	G: 200,	avg_r: 1,	avg_G: 164,	t: 200,	dt: 30.620ms,	SimpleTD/loss: 3.99e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 132,	T: 11,119,	G: 200,	avg_r: 1,	avg_G: 167,	t: 200,	dt: 30.457ms,	SimpleTD/loss: 4.39e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 133,	T: 11,320,	G: 200,	avg_r: 1,	avg_G: 171,	t: 200,	dt: 29.840ms,	SimpleTD/loss: 4.26e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 134,	T: 11,521,	G: 200,	avg_r: 1,	avg_G: 174,	t: 200,	dt: 29.856ms,	SimpleTD/loss: 5.05e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 135,	T: 11,722,	G: 200,	avg_r: 1,	avg_G: 176,	t: 200,	dt: 32.319ms,	SimpleTD/loss: 3.77e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 136,	T: 11,923,	G: 200,	avg_r: 1,	avg_G: 179,	t: 200,	dt: 31.674ms,	SimpleTD/loss: 3.7e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 137,	T: 12,124,	G: 200,	avg_r: 1,	avg_G: 181,	t: 200,	dt: 31.746ms,	SimpleTD/loss: 4.83e-06,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 138,	T: 12,325,	G: 200,	avg_r: 1,	avg_G: 183,	t: 200,	dt: 30.411ms,	SimpleTD/loss: 6.98e-05,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 139,	T: 12,526,	G: 200,	avg_r: 1,	avg_G: 184,	t: 200,	dt: 32.387ms,	SimpleTD/loss: 0.000834,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 140,	T: 12,727,	G: 200,	avg_r: 1,	avg_G: 186,	t: 200,	dt: 32.248ms,	SimpleTD/loss: 0.000214,	VanillaPG/loss: 0.517


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 141,	T: 12,928,	G: 200,	avg_r: 1,	avg_G: 187,	t: 200,	dt: 31.847ms,	SimpleTD/loss: 5.91e-05,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 142,	T: 13,129,	G: 200,	avg_r: 1,	avg_G: 189,	t: 200,	dt: 32.620ms,	SimpleTD/loss: 0.000118,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 143,	T: 13,330,	G: 200,	avg_r: 1,	avg_G: 190,	t: 200,	dt: 30.155ms,	SimpleTD/loss: 0.000139,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 144,	T: 13,531,	G: 200,	avg_r: 1,	avg_G: 191,	t: 200,	dt: 27.240ms,	SimpleTD/loss: 0.000122,	VanillaPG/loss: 0.516


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 145,	T: 13,732,	G: 200,	avg_r: 1,	avg_G: 192,	t: 200,	dt: 29.924ms,	SimpleTD/loss: 0.000192,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 146,	T: 13,933,	G: 200,	avg_r: 1,	avg_G: 193,	t: 200,	dt: 32.106ms,	SimpleTD/loss: 0.000489,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 147,	T: 14,134,	G: 200,	avg_r: 1,	avg_G: 193,	t: 200,	dt: 29.369ms,	SimpleTD/loss: 0.000495,	VanillaPG/loss: 0.518


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 148,	T: 14,335,	G: 200,	avg_r: 1,	avg_G: 194,	t: 200,	dt: 31.016ms,	SimpleTD/loss: 0.00026,	VanillaPG/loss: 0.515


Epoch reward 209.0


[a2c|TrainMonitor|INFO] ep: 149,	T: 14,536,	G: 200,	avg_r: 1,	avg_G: 195,	t: 200,	dt: 28.119ms,	SimpleTD/loss: 0.000225,	VanillaPG/loss: 0.516


Epoch reward 209.0
Epoch reward 209.0


In [7]:
metrics_pi

{'RPPRegularizer/basic_l2': DeviceArray(0., dtype=float32),
 'RPPRegularizer/basic_wd': DeviceArray(1.e-04, dtype=float32),
 'RPPRegularizer/equiv_l2': DeviceArray(5165.331, dtype=float32),
 'RPPRegularizer/equiv_wd': DeviceArray(1.e-04, dtype=float32),
 'VanillaPG/grads_max': DeviceArray(0.08917519, dtype=float32),
 'VanillaPG/grads_norm': DeviceArray(0.48706943, dtype=float32),
 'VanillaPG/kl_div_old': DeviceArray(0.23434481, dtype=float32),
 'VanillaPG/loss': DeviceArray(0.48180765, dtype=float32),
 'VanillaPG/loss_bare': DeviceArray(-0.03472544, dtype=float32)}

In [8]:
equiv_l2 = 0.0
basic_l2 = 0.0
for k1, v1 in pi.params.items():
    if "bi" not in k1:
        for k2, v2 in v1.items():
    #         if "bi" not in k2:
    #             print(k2)
            if k2.endswith("_basic"):
                basic_l2 += (v2 ** 2).sum()
            elif k2.endswith("w"):
                equiv_l2 += (v2 ** 2).sum()
print("Eq l2 = ", equiv_l2)
print("Basic l2 = ", basic_l2)

Eq l2 =  338.9434
Basic l2 =  0.0
