In [None]:
import coax
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from coax.value_losses import mse

import matplotlib.pyplot as plt

# Environment under study. The stock cart-pole task is kept below for reference:
# env = gym.make('CartPole-v0')
env = gym.make("rpp_gym:InclinedCartpole-v0")

# Label for this run; reused later for the monitor name and the tensorboard directory.
name = "a2c"

In [None]:
# Wrap the environment in coax's TrainMonitor; tensorboard output goes to a
# per-run directory derived from `name`.
# Note: `env` is rebound to the wrapper, so re-running this cell wraps the
# already-wrapped environment again.
env = coax.wrappers.TrainMonitor(
    env,
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

In [None]:
# env.spec.max_episode_steps = 200
# env.spec.reward_threshold = 195.0

In [None]:
from emlp import T, Scalar
from emlp.groups import SO, S, O, Trivial,Z
import emlp.nn.haiku as ehk
from emlp.reps import Rep
from emlp.nn import gated,gate_indices,uniform_rep
from math import prod
from representations import PseudoScalar
from mixed_emlp_haiku import MixedEMLP

# --- Symmetry group and input/output representations -------------------------
# Exactly one of the configurations below should be active.

## Trivial
# group=Trivial(2)
# rep_in = T(0)*prod(env.observation_space.shape)
# rep_out = T(0)*env.action_space.n#prod(env.action_space.shape)

## Reflection
group = Z(2)
# one PseudoScalar per entry of the (flattened) observation
rep_in = PseudoScalar()*prod(env.observation_space.shape)
rep_out = T(1)#*env.action_space.n#prod(env.action_space.shape)

# --- Networks ------------------------------------------------------------------
# Only the MixedEMLP pair at the bottom is used. The plain EMLP pair used to be
# constructed here as well and was then immediately overwritten, which only
# cost construction time; it is now parked with the other alternative.
# To switch architecture, uncomment one pair and comment out the active one.

# nn_pi = ehk.EMLP(rep_in,rep_out,group,ch=100,num_layers=2)
# nn_v = ehk.EMLP(rep_in,T(0),group,ch=100,num_layers=2)

# nn_pi = ehk.MLP(rep_in,rep_out(group),group,ch=100,num_layers=2)
# nn_v = ehk.MLP(rep_in,T(0),group,ch=100,num_layers=2)

# policy network (output rep: rep_out) and value network (output rep: T(0))
nn_pi = MixedEMLP(rep_in,rep_out(group),group,ch=100,num_layers=2)
nn_v = MixedEMLP(rep_in,T(0),group,ch=100,num_layers=2)


def func_pi(S, is_training):
    """Forward pass of the policy.

    Passes ``S`` through ``nn_pi`` and returns the output in a dict under the
    ``'logits'`` key. ``is_training`` is accepted but not used.
    """
    logits = nn_pi(S)
    return dict(logits=logits)


def func_v(S, is_training):
    """Forward pass of the value function.

    Passes ``S`` through ``nn_v`` and flattens the output to one dimension.
    ``is_training`` is accepted but not used.
    """
    values = nn_v(S)
    flat_values = values.reshape(-1)
    return flat_values



# def func_pi(S, is_training):
#     logits = hk.Sequential((
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(16), jax.nn.relu,
#         hk.Linear(env.action_space.n, w_init=jnp.zeros)
#     ))
#     return {'logits': logits(S)}



# def func_v(S, is_training):
#     value = hk.Sequential((
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(32), jax.nn.relu,
#         hk.Linear(1, w_init=jnp.zeros), jnp.ravel
#     ))
#     return value(S)


In [None]:
# Both optimizers are optax chains of apply_every(k=32) followed by Adam, so
# gradients are collected in batches before an update is applied. Only the
# Adam learning rate differs between the value function and the policy.
optimizer_v = optax.chain(
    optax.apply_every(k=32),
    optax.adam(0.002),
)
optimizer_pi = optax.chain(
    optax.apply_every(k=32),
    optax.adam(0.001),
)


# coax function approximators built from the forward passes defined above
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

In [None]:
# Keep a handle on the value function's parameters as they are at this point in
# the notebook, i.e. before the training cell below runs. `store` is not read
# anywhere else in the visible cells.
# NOTE(review): this is only a pre-training snapshot if coax replaces `v.params`
# on update rather than mutating it in place -- confirm before relying on it.
store = v.params

In [None]:
# experience tracer (1-step transitions, discount factor 0.9)
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

# updaters: policy-gradient update for pi, TD update (MSE loss) for v
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optimizer_pi)
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=optimizer_v)

# per-episode sum of the rewards exactly as returned by env.step()
epoch_rewards = []

# train
for ep in range(1000):
    s = env.reset()
    er = 0
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # Accumulate the raw environment reward *before* it is reshaped below.
        # Previously the shaped value was added, which inflated the logged
        # return of every episode that reached the step limit and made
        # `epoch_rewards` inconsistent with the monitor's `env.avg_G`.
        er += r

        # The episode finished on the last allowed step: hand the learner
        # 1 / (1 - gamma) instead of r, i.e. the sum of the geometric series
        # 1 + gamma + gamma^2 + ...  This only affects what is learned from,
        # not what is logged.
        if done and (t == env.spec.max_episode_steps - 1):
            r = 1 / (1 - tracer.gamma)

        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            # the TD error from the value update is fed to the policy update
            metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
            metrics_pi = vanilla_pg.update(transition_batch, td_error)
            env.record_metrics(metrics_v)
            env.record_metrics(metrics_pi)

        if done:
            break

        s = s_next

    print("Epoch reward",er)
    epoch_rewards.append(er)
    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break


# run env one more time to render
#coax.utils.generate_gif(env, policy=pi, filepath=f"./data/{name}.gif", duration=25)

In [None]:
# Write the current policy and value-function parameters to disk,
# policy first, one file each.
for params, path in (
    (pi.params, "./emlp_pi_params.lz4"),
    (v.params, "./emlp_v_params.lz4"),
):
    coax.utils.dump(params, path)