In [116]:
import coax
from coax.regularizers import Regularizer
import gym

from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial,Z
import emlp.nn.haiku as ehk
from mixed_emlp_haiku import MixedEMLP
from representations import PseudoScalar
from math import prod
from coax.value_losses import mse

import optax

In [117]:
from coax.utils import jit

In [118]:
# Create the environment with id "InclinedCartpole-v0"; the "rpp_gym:" prefix
# names the package that provides it.
env = gym.make("rpp_gym:InclinedCartpole-v0")

In [119]:
# Symmetry group handed to both networks below.
group=Z(2)
# Input representation: PseudoScalar() multiplied by the number of observation
# entries (the product of the observation-space shape).
# NOTE(review): presumably this gives one PseudoScalar copy per observation
# component -- confirm against the local `representations` module.
rep_in = PseudoScalar()*prod(env.observation_space.shape)
# Output representation of the policy network. The trailing comment keeps
# alternative scalings by the action space that are currently disabled.
rep_out = T(1)#*env.action_space.n#prod(env.action_space.shape)

# Policy and value networks share the input representation, the group, the
# width (ch=100) and the depth (num_layers=2); they differ only in the output
# representation.
# NOTE(review): rep_out is bound to the group here (rep_out(group)) while
# rep_in and T(0) are passed unbound -- confirm MixedEMLP accepts both forms.
nn_pi = MixedEMLP(rep_in,rep_out(group),group,ch=100,num_layers=2)
nn_v = MixedEMLP(rep_in,T(0),group,ch=100,num_layers=2)


def func_pi(S, is_training):
    """Forward function of the policy.

    Feeds the observations ``S`` through ``nn_pi`` and returns the result as
    the ``'logits'`` entry of a distribution-parameter dict. ``is_training``
    is part of the signature but is not used.
    """
    logits = nn_pi(S)
    return dict(logits=logits)


def func_v(S, is_training):
    """Forward function of the state-value estimator.

    Feeds the observations ``S`` through ``nn_v`` and flattens the output to
    one dimension. ``is_training`` is part of the signature but is not used.
    """
    values = nn_v(S)
    return values.reshape(-1)

In [120]:
# Wrap the two forward functions as coax function approximators for this
# environment: a state-value function and a policy. They are built
# independently of each other (func_v / nn_v vs. func_pi / nn_pi).
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

In [132]:

class RPPRegularizer(Regularizer):
    r"""

    Weight-decay regularizer with separate coefficients for two parameter groups.

    The parameters of ``f`` are read as a two-level mapping
    (``f.params[outer_key][leaf_name]``). Leaves whose name ends with ``"_basic"`` form
    the *basic* group; every other leaf forms the *equiv* group. The regularization term is::

        equiv_wd * sum(w ** 2 for w in equiv) + basic_wd * sum(w ** 2 for w in basic)

    Parameters
    ----------
    f : stochastic function approximator

        The function approximator (e.g. :class:`coax.Policy`) whose parameters are penalized.

    equiv_wd : float, optional

        Coefficient of the summed squares of all leaves whose name does *not* end with
        ``"_basic"``.

    basic_wd : float, optional

        Coefficient of the summed squares of all leaves whose name ends with ``"_basic"``.

    """
    def __init__(self, f, equiv_wd=1e-3, basic_wd=1e-3):
        # Run the base-class initialisation (in coax this checks that ``f`` is
        # stochastic and stores it as ``self.f``).
        super().__init__(f)
        self.equiv_wd = equiv_wd
        self.basic_wd = basic_wd

        # NOTE(review): ``self.f.params`` is read from the closure instead of being
        # passed in as an argument. If ``jit`` behaves like ``jax.jit``, those values
        # are captured when the function is traced, so the penalty would neither
        # follow later parameter updates nor contribute a gradient with respect to
        # the parameters -- confirm this is what is intended.
        def function(dist_params, equiv_wd, basic_wd):
            # ``dist_params`` is part of the regularizer call signature but is not
            # used: the penalty depends on the parameters only.
            equiv_l2, basic_l2 = self._squared_l2_by_group()
            return (equiv_wd * equiv_l2) + (basic_wd * basic_l2)

        def metrics(dist_params, equiv_wd, basic_wd):
            equiv_l2, basic_l2 = self._squared_l2_by_group()
            return {
                'RPPRegularizer/equiv_l2': equiv_l2,
                'RPPRegularizer/basic_l2': basic_l2,
                'RPPRegularizer/equiv_wd': equiv_wd,
                'RPPRegularizer/basic_wd': basic_wd,
            }

        self._function = jit(function)
        self._metrics_func = jit(metrics)

    def _squared_l2_by_group(self):
        """Return ``(equiv_l2, basic_l2)``: summed squared entries of each parameter group."""
        equiv_l2 = 0.0
        basic_l2 = 0.0
        for leaves in self.f.params.values():
            for leaf_name, weights in leaves.items():
                squared = (weights ** 2).sum()
                if leaf_name.endswith("_basic"):
                    basic_l2 += squared
                else:
                    equiv_l2 += squared
        return equiv_l2, basic_l2

    @property
    def hyperparams(self):
        """Coefficients passed as keyword arguments to ``function`` and ``metrics_func``."""
        return {'equiv_wd': self.equiv_wd, 'basic_wd': self.basic_wd}

    @property
    def function(self):
        """Jitted ``(dist_params, equiv_wd, basic_wd) -> penalty`` function."""
        return self._function

    @property
    def metrics_func(self):
        """Jitted ``(dist_params, equiv_wd, basic_wd) -> dict`` of diagnostic metrics."""
        return self._metrics_func



In [133]:
# Regularizer on the policy's parameters, using the default coefficients
# (equiv_wd=1e-3, basic_wd=1e-3).
pi_regularizer = RPPRegularizer(pi)
# Reward tracer configured for n=1 step with discount factor gamma=0.9.
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

In [134]:
# Both optimizers chain optax.apply_every(k=32) with Adam; they differ only in
# the learning rate (0.002 for the value function, 0.001 for the policy).
# NOTE(review): apply_every presumably accumulates updates over 32 steps before
# handing them to Adam -- confirm against the optax documentation.
optimizer_v = optax.chain(optax.apply_every(k=32), optax.adam(0.002))
optimizer_pi = optax.chain(optax.apply_every(k=32), optax.adam(0.001))

# Updaters: a VanillaPG policy objective for pi (with the RPP regularizer
# attached) and a SimpleTD updater for v that uses the mse loss.
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optimizer_pi, regularizer=pi_regularizer)
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=optimizer_v)

In [137]:
# Run one episode of at most env.spec.max_episode_steps steps, updating v and
# pi from every transition batch the tracer hands back.
s = env.reset()
for t in range(env.spec.max_episode_steps):
    # Sample an action from the current policy and advance the environment.
    a = pi(s)
    s_next, r, done, info = env.step(a)
    # Debug output: the termination flag of this step.
    print(done)

    # When the episode ends exactly on the last allowed step, replace the
    # reward by 1 / (1 - gamma), the value of the geometric series
    # 1 + gamma + gamma^2 + ...
    # NOTE(review): presumably this treats reaching the step limit as if a
    # reward of 1 continued indefinitely -- confirm.
    if done and (t == env.spec.max_episode_steps - 1):
        r = 1 / (1 - tracer.gamma)
        
    tracer.add(s, a, r, done)
    # Consume everything the tracer currently has available.
    while tracer:
        transition_batch = tracer.pop()
        # Update v first; the TD error it returns is fed into the policy update.
        metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
        metrics_pi = vanilla_pg.update(transition_batch, td_error)
#         env.record_metrics(metrics_v)
#         env.record_metrics(metrics_pi)
    if done:
        break
    s = s_next

False
False
False
False
False
False
False
False
False
False
False
True


In [125]:
# Ad-hoc check: print "yes" if the tracer is truthy, i.e. the same condition
# the episode loop tests before popping a transition batch.
if tracer:
    print("yes")