In [1]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from coax.value_losses import mse
from RPPRegularizer import RPPRegularizer

import matplotlib.pyplot as plt

# the name of this script (also used for the TrainMonitor / tensorboard run name)
name = 'a2c'

# the cart-pole MDP
# env = gym.make('CartPole-v0')
env = gym.make("rpp_gym:InclinedCartpole-v0")

# gym.make may hand back the environment wrapped (e.g. in TimeLimit when the spec
# defines max_episode_steps). gym wrappers forward attribute *reads* to the inner
# env but not attribute *writes*, so `env.alpha = ...` would only set the value on
# the outermost wrapper. Assign on the unwrapped env so the environment itself
# sees the setting; `env.alpha` still reads back the same value through the wrapper.
env.unwrapped.alpha = 0.2

In [2]:
# Wrap the env in coax's TrainMonitor; metrics recorded through it are written
# to a tensorboard directory named after this script.
env = coax.wrappers.TrainMonitor(
    env,
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

In [3]:
# env.spec.max_episode_steps = 200
# env.spec.reward_threshold = 195.0

In [4]:
from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial,Z
import emlp.nn.haiku as ehk
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep
from math import prod
from representations import PseudoScalar
from mixed_emlp_haiku import MixedEMLP

## Trivial
# group = Trivial(2)
# rep_in = T(0) * prod(env.observation_space.shape)
# rep_out = T(0) * env.action_space.n  # prod(env.action_space.shape)

## Reflection
group = Z(2)
# one PseudoScalar copy per (flattened) observation component
obs_dim = prod(env.observation_space.shape)
rep_in = PseudoScalar() * obs_dim
rep_out = T(1)  # * env.action_space.n  # prod(env.action_space.shape)

# hidden width and depth shared by the policy and value networks
ch = 100
num_layers = 3
net_kwargs = dict(ch=ch, num_layers=num_layers)

# Alternative network families, kept for quick switching:
# nn_pi = ehk.EMLP(rep_in, rep_out, group, ch=ch, num_layers=num_layers)
# nn_v = ehk.EMLP(rep_in, T(0), group, ch=ch, num_layers=num_layers)

# nn_pi = ehk.MLP(rep_in, rep_out(group), group, ch=ch, num_layers=num_layers)
# nn_v = ehk.MLP(rep_in, T(0), group, ch=ch, num_layers=num_layers)

# policy head outputs rep_out(group); value head outputs a single T(0) scalar
nn_pi = MixedEMLP(rep_in, rep_out(group), group, **net_kwargs)
nn_v = MixedEMLP(rep_in, T(0), group, **net_kwargs)


def func_pi(S, is_training):
    """Policy forward pass: map a batch of states to categorical logits.

    Args:
        S: batch of observations fed straight into ``nn_pi``.
        is_training: accepted for the function-approximator signature; not used here.

    Returns:
        A dict with the single key ``'logits'`` holding the output of ``nn_pi(S)``.
    """
    logits = nn_pi(S)
    return dict(logits=logits)


def func_v(S, is_training):
    """State-value forward pass: map a batch of states to a flat vector of values.

    Args:
        S: batch of observations fed straight into ``nn_v``.
        is_training: accepted for the function-approximator signature; not used here.

    Returns:
        The output of ``nn_v(S)`` flattened to one dimension.
    """
    values = nn_v(S)
    return values.reshape(-1)



# def func_pi(S, is_training):
#     logits = hk.Sequential((
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(env.action_space.n, w_init=jnp.zeros)
#     ))
#     return {'logits': logits(S)}



# def func_v(S, is_training):
#     value = hk.Sequential((
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(1, w_init=jnp.zeros), jnp.ravel
#     ))
#     return value(S)


[a2c|root|INFO] Initing EMLP
[a2c|root|INFO] Linear W components:400 rep:96P+48P⊗V+20P⊗V²+8P⊗V³+4P⊗V⁴
[a2c|root|INFO] P cache miss
[a2c|root|INFO] Solving basis for P, for G=Z(2)
[a2c|root|INFO] P⊗V cache miss
[a2c|root|INFO] Solving basis for P⊗V, for G=Z(2)
[a2c|root|INFO] P⊗V² cache miss
[a2c|root|INFO] Solving basis for P⊗V², for G=Z(2)
[a2c|root|INFO] P⊗V³ cache miss
[a2c|root|INFO] Solving basis for P⊗V³, for G=Z(2)
[a2c|root|INFO] P⊗V⁴ cache miss
[a2c|root|INFO] Solving basis for P⊗V⁴, for G=Z(2)
[a2c|root|INFO] V cache miss
[a2c|root|INFO] Solving basis for V, for G=Z(2)
[a2c|root|INFO] V² cache miss
[a2c|root|INFO] Solving basis for V², for G=Z(2)
[a2c|root|INFO] V³ cache miss
[a2c|root|INFO] Solving basis for V³, for G=Z(2)
[a2c|root|INFO] V⁴ cache miss
[a2c|root|INFO] Solving basis for V⁴, for G=Z(2)
[a2c|root|INFO] Linear W components:10000 rep:576V⁰+576V+384V²+216V³+121V⁴+44V⁵+14V⁶+4V⁷+V⁸
[a2c|root|INFO] V⁵ cache miss
[a2c|root|INFO] Solving basis for V⁵, for G=Z(2)
[a2c|r

In [5]:
# Both optimizers chain apply_every(k=32) with Adam, so grads are collected
# in batches of 32 before an update is applied.
optimizer_v = optax.chain(
    optax.apply_every(k=32),
    optax.adam(0.002),
)
optimizer_pi = optax.chain(
    optax.apply_every(k=32),
    optax.adam(0.001),
)


# state-value function and policy, both built against the same env
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

In [6]:
# Debugging aid: toggles IPython's automatic post-mortem debugger on uncaught exceptions.
# NOTE(review): leftover from a debugging session — consider dropping this cell before sharing.
%pdb

Automatic pdb calling has been turned ON


In [7]:
# regularizer on the policy; only the `basic_wd` term is active (equiv_wd is zero)
pi_regularizer = RPPRegularizer(pi, basic_wd=1000., equiv_wd=0.)
# v_regularizer = RPPRegularizer(v, basic_wd=1e-5, equiv_wd=0.)

# experience tracer: single-step transitions, discount 0.9
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

# updaters
vanilla_pg = coax.policy_objectives.VanillaPG(
    pi,
    optimizer=optimizer_pi,
    regularizer=pi_regularizer,
)
# NOTE(review): in coax, SimpleTD's `policy_regularizer` is added to the TD target
# (typically an entropy bonus) — confirm that passing the weight-decay style
# regularizer here is intended.
simple_td = coax.td_learning.SimpleTD(
    v,
    loss_function=mse,
    optimizer=optimizer_v,
    policy_regularizer=pi_regularizer,
)

# per-episode reward totals, filled in by the training loop
epoch_rewards = []

# train
# The spec values do not change during training, so look them up once.
max_steps = env.spec.max_episode_steps
# reward_threshold is optional in a gym spec (defaults to None); comparing a
# float against None raises TypeError, so early stopping is only used when set.
reward_threshold = env.spec.reward_threshold

for ep in range(100):
    s = env.reset()
    er = 0  # total (shaped) reward collected in this episode
    for t in range(max_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # Episode ended on the very last allowed step: replace the reward by
        # 1 / (1 - gamma), i.e. the discounted sum of receiving reward 1 forever.
        if done and (t == max_steps - 1):
            r = 1 / (1 - tracer.gamma)
        er += r

        tracer.add(s, a, r, done)
        # drain every transition the tracer has ready and update critic, then actor
        while tracer:
            transition_batch = tracer.pop()
            metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
            metrics_pi = vanilla_pg.update(transition_batch, td_error)
            env.record_metrics(metrics_v)
            env.record_metrics(metrics_pi)

        if done:
            break

        s = s_next

    print("Epoch reward", er)
    epoch_rewards.append(er)
    # early stopping (skipped when the env spec defines no reward threshold)
    if reward_threshold is not None and env.avg_G > reward_threshold:
        break


# run env one more time to render
#coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)

NameError: name 'test' is not defined

> [0;32m/home/greg_b/residual-pathway-priors/RL/RPPRegularizer.py[0m(25)[0;36mfunction[0;34m()[0m
[0;32m     23 [0;31m        [0;32mdef[0m [0mfunction[0m[0;34m([0m[0mdist_params[0m[0;34m,[0m [0mequiv_wd[0m[0;34m,[0m [0mbasic_wd[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m[0;31m#             print(dist_params)[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m            [0;32mif[0m [0mtest[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m                [0mprint[0m[0;34m([0m[0mtemperatjflkdasjfklas[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m            [0mequiv_l2[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  dist_params


{'logits': Traced<ShapedArray(float32[1,2])>with<DynamicJaxprTrace(level=0/3)>}


ipdb>  u


> [0;32m/home/greg_b/miniconda3/envs/rpp/lib/python3.8/site-packages/jax/linear_util.py[0m(166)[0;36mcall_wrapped[0;34m()[0m
[0;32m    164 [0;31m[0;34m[0m[0m
[0m[0;32m    165 [0;31m    [0;32mtry[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 166 [0;31m      [0mans[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mf[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mdict[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mparams[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    167 [0;31m    [0;32mexcept[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    168 [0;31m      [0;31m# Some transformations yield from inside context managers, so we have to[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/greg_b/miniconda3/envs/rpp/lib/python3.8/site-packages/jax/interpreters/partial_eval.py[0m(1188)[0;36mtrace_to_subjaxpr_dynamic[0;34m()[0m
[0;32m   1186 [0;31m    [0mtrace[0m [0;34m=[0m [0mDynamicJaxprTrace[0m[0;34m([0m[0mmain[0m[0;34m,[0m [0mcore[0m[0;34m.[0m[0mcur_sublevel[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1187 [0;31m    [0min_tracers[0m [0;34m=[0m [0mmap[0m[0;34m([0m[0mtrace[0m[0;34m.[0m[0mnew_arg[0m[0;34m,[0m [0min_avals[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1188 [0;31m    [0mans[0m [0;34m=[0m [0mfun[0m[0;34m.[0m[0mcall_wrapped[0m[0;34m([0m[0;34m*[0m[0min_tracers[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1189 [0;31m    [0mout_tracers[0m [0;34m=[0m [0mmap[0m[0;34m([0m[0mtrace[0m[0;34m.[0m[0mfull_raise[0m[0;34m,[0m [0mans[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1190 [0;31m    [0mjaxpr[0m[0;34m,[0m 

ipdb>  u


> [0;32m/home/greg_b/miniconda3/envs/rpp/lib/python3.8/site-packages/jax/interpreters/partial_eval.py[0m(1061)[0;36mprocess_call[0;34m()[0m
[0;32m   1059 [0;31m  [0;32mdef[0m [0mprocess_call[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mcall_primitive[0m[0;34m,[0m [0mf[0m[0;34m,[0m [0mtracers[0m[0;34m,[0m [0mparams[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1060 [0;31m    [0min_avals[0m [0;34m=[0m [0;34m[[0m[0mt[0m[0;34m.[0m[0maval[0m [0;32mfor[0m [0mt[0m [0;32min[0m [0mtracers[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1061 [0;31m    [0mjaxpr[0m[0;34m,[0m [0mout_avals[0m[0;34m,[0m [0mconsts[0m [0;34m=[0m [0mtrace_to_subjaxpr_dynamic[0m[0;34m([0m[0mf[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmain[0m[0;34m,[0m [0min_avals[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1062 [0;31m    [0;32mif[0m [0;32mnot[0m [0mjaxpr[0m[0;34m.[0m[0meqns[0m[0;34m:[0m[0;34m[0m[0;34

ipdb>  u


> [0;32m/home/greg_b/miniconda3/envs/rpp/lib/python3.8/site-packages/jax/core.py[0m(1405)[0;36mprocess[0;34m()[0m
[0;32m   1403 [0;31m[0;34m[0m[0m
[0m[0;32m   1404 [0;31m  [0;32mdef[0m [0mprocess[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mtrace[0m[0;34m,[0m [0mfun[0m[0;34m,[0m [0mtracers[0m[0;34m,[0m [0mparams[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1405 [0;31m    [0;32mreturn[0m [0mtrace[0m[0;34m.[0m[0mprocess_call[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mfun[0m[0;34m,[0m [0mtracers[0m[0;34m,[0m [0mparams[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1406 [0;31m[0;34m[0m[0m
[0m[0;32m   1407 [0;31m  [0;32mdef[0m [0mpost_process[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mtrace[0m[0;34m,[0m [0mout_tracers[0m[0;34m,[0m [0mparams[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/greg_b/miniconda3/envs/rpp/lib/python3.8/site-packages/jax/core.py[0m(1393)[0;36mcall_bind[0;34m()[0m
[0;32m   1391 [0;31m  [0mtracers[0m [0;34m=[0m [0mmap[0m[0;34m([0m[0mtop_trace[0m[0;34m.[0m[0mfull_raise[0m[0;34m,[0m [0margs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1392 [0;31m  [0;32mwith[0m [0mmaybe_new_sublevel[0m[0;34m([0m[0mtop_trace[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1393 [0;31m    [0mouts[0m [0;34m=[0m [0mprimitive[0m[0;34m.[0m[0mprocess[0m[0;34m([0m[0mtop_trace[0m[0;34m,[0m [0mfun[0m[0;34m,[0m [0mtracers[0m[0;34m,[0m [0mparams[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1394 [0;31m  [0;32mreturn[0m [0mmap[0m[0;34m([0m[0mfull_lower[0m[0;34m,[0m [0mapply_todos[0m[0;34m([0m[0menv_trace_todo[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mouts[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1395 [0;31m[0;34m[0m[0m


ipdb>  q


In [None]:
# Inspect the policy-gradient updater object (bare expression, displayed as the cell output).
vanilla_pg

In [None]:

# Inspect the metrics returned by the most recent vanilla_pg.update call in the training loop.
metrics_pi

In [None]:
# Diagnostic: sum of squared policy weights, split by parameter-name suffix.
# Parameters ending in "w_basic" go to basic_l2, other parameters ending in "w"
# go to equiv_l2; modules whose name contains "bi" are left out entirely.
equiv_l2 = 0.0
basic_l2 = 0.0
for module_name, module_params in pi.params.items():
    if "bi" in module_name:
        continue
    for param_name, weights in module_params.items():
        if param_name.endswith("w_basic"):
            basic_l2 += (weights ** 2).sum()
        elif param_name.endswith("w"):
            equiv_l2 += (weights ** 2).sum()
print("Eq l2 = ", equiv_l2)
print("Basic l2 = ", basic_l2)

In [None]:
# coax.utils.dump(pi.params, "./emlp_pi_params.lz4")
# coax.utils.dump(v.params, "./emlp_v_params.lz4")