In [6]:
import gym
import haiku as hk
import jax
import jax.numpy as jnp
import jax.nn as nn
import jax.random as random
import numpy as np
import optax
import matplotlib.pyplot as plt
from jax import tree_map
from jax import grad, vjp, value_and_grad
from jax import lax
from jax import jit
from tqdm import tqdm
%matplotlib inline 

In [7]:
def policy_fn(observation):
  """Actor network: turn an observation into two unnormalised action logits.

  The function instantiates Haiku modules, so it is meant to be wrapped with
  `hk.transform` rather than called directly.

  Args:
    observation: input array fed to the first linear layer.

  Returns:
    The output of the final 2-unit linear layer.
  """
  # Modules are created in the same order as before (hidden layer first, then
  # the output head) so Haiku assigns them the same parameter names.
  hidden_layer = hk.Linear(20)
  logit_layer = hk.Linear(2)

  hidden = nn.relu(hidden_layer(observation))
  return logit_layer(hidden)

In [8]:

def val_fn(observation):
  """Critic network: turn an observation into a single value estimate.

  The function instantiates Haiku modules, so it is meant to be wrapped with
  `hk.transform` rather than called directly.

  Args:
    observation: input array fed to the first linear layer.

  Returns:
    The output of the final 1-unit linear layer.
  """
  # Modules are created in the same order as before (hidden layer first, then
  # the output head) so Haiku assigns them the same parameter names.
  hidden_layer = hk.Linear(64)
  value_layer = hk.Linear(1)

  hidden = nn.relu(hidden_layer(observation))
  return value_layer(hidden)

In [9]:
def main(num_episodes=2000, render=True):
  """Train a one-step actor-critic agent on CartPole-v0 and plot its progress.

  The actor (`policy_fn`) and critic (`val_fn`) are Haiku networks optimised
  with separate Adam optimisers. The critic is updated after every step from
  the squared TD error; the actor gradient is accumulated over `BATCH` steps
  and then applied.

  Uses the gym API the notebook was written against (`env.seed`, `env.reset()`
  returning only the observation, `env.step` returning a 4-tuple).

  Args:
    num_episodes: number of training episodes to run.
    render: call `env.render()` on every step. Disabling it avoids the
      per-step rendering cost.

  Returns:
    None. Two matplotlib figures are created: episode length and mean absolute
    TD error, both sampled every `RECORD_INTERVAL` episodes.
  """
  GAMMA = 0.95
  BATCH = 2
  RECORD_INTERVAL = 20

  rng = hk.PRNGSequence(0)

  actor = hk.without_apply_rng(hk.transform(policy_fn))
  critic = hk.without_apply_rng(hk.transform(val_fn))

  dummy_obs = jnp.array([0, 0, 0, 0], dtype=jnp.float32)

  actor_params = actor.init(rng.next(), dummy_obs)
  critic_params = critic.init(rng.next(), dummy_obs)

  actor_opt = optax.adam(1e-4)
  critic_opt = optax.adam(1e-3)

  actor_opt_state = actor_opt.init(actor_params)
  critic_opt_state = critic_opt.init(critic_params)

  def actor_obj(params, obs, rand):
    """Sample an action and return (log-probability of that action, action)."""
    logits = actor.apply(params, obs)
    action = lax.stop_gradient(random.categorical(rand, logits))
    return nn.log_softmax(logits)[action], action

  # Build the gradient transform once instead of on every environment step.
  actor_grad_fn = jit(grad(actor_obj, has_aux=True))

  @jit
  def update_actor(params, opt_state, gradient):
    updates, new_opt_state = actor_opt.update(gradient, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

  @jit
  def accumulate_actor_grad(cum_grad, log_grad, tde, discount):
    # The optimiser takes descent steps, while the policy gradient
    # discount * tde * grad(log pi) is an ascent direction. Accumulate its
    # negation, i.e. the gradient of the loss -discount * tde * log pi.
    return tree_map(lambda cg, lg: cg - discount * tde * lg, cum_grad, log_grad)

  @jit
  def zero_tree(tree):
    return tree_map(jnp.zeros_like, tree)

  def critic_obj(params, obs, reward, next_obs, gamma, terminal):
    """Return (squared TD error, TD error) for a single transition.

    `terminal` is 1.0 when the episode truly ended at `next_obs`, in which
    case the target must not bootstrap from the critic's value of `next_obs`.
    """
    v_t = critic.apply(params, obs)
    bootstrap = (1.0 - terminal) * critic.apply(params, next_obs)
    target = reward + gamma * bootstrap
    td_error = lax.stop_gradient(target) - v_t
    return (td_error**2)[0], td_error

  critic_grad_fn = jit(grad(critic_obj, has_aux=True))

  @jit
  def update_critic(params, opt_state, gradient):
    updates, new_opt_state = critic_opt.update(gradient, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

  env = gym.make('CartPole-v0')
  env.seed(1)
  recorded_episodes = []
  episode_lengths = []
  avg_tde = []

  try:
    for eps in tqdm(range(num_episodes)):
      # initialization for each episode
      o_t = jnp.array(env.reset(), dtype=jnp.float32)
      done = False
      I = 1.0  # running discount gamma**t applied to the actor gradient
      cumulative_tde = 0.0
      num_step = 0
      cum_actor_grad = zero_tree(actor_params)

      while not done:
        likelihood_grad, action = actor_grad_fn(actor_params, o_t, rng.next())
        if render:
          env.render()

        num_step += 1

        o_tp1, reward, done, info = env.step(action.item())
        o_tp1 = jnp.array(o_tp1, dtype=jnp.float32)

        # A time-limit cut-off is not a real terminal state, so keep
        # bootstrapping in that case; only mask on true termination.
        terminal = float(done and not info.get('TimeLimit.truncated', False))

        critic_grad, tde = critic_grad_fn(
            critic_params, o_t, reward, o_tp1, GAMMA, terminal)
        delta = tde.item()
        cumulative_tde += abs(delta)
        critic_params, critic_opt_state = update_critic(
            critic_params, critic_opt_state, critic_grad)

        cum_actor_grad = accumulate_actor_grad(
            cum_actor_grad, likelihood_grad, delta, I)
        if num_step % BATCH == 0:
          actor_params, actor_opt_state = update_actor(
              actor_params, actor_opt_state, cum_actor_grad)
          cum_actor_grad = zero_tree(cum_actor_grad)

        I = GAMMA * I
        o_t = o_tp1

      # Flush a partially filled batch. Skip it when the batch was just
      # applied, otherwise Adam would take an extra step on a zero gradient.
      if num_step % BATCH != 0:
        actor_params, actor_opt_state = update_actor(
            actor_params, actor_opt_state, cum_actor_grad)

      if eps % RECORD_INTERVAL == 0:
        recorded_episodes.append(eps)
        episode_lengths.append(num_step)
        avg_tde.append(cumulative_tde / num_step)
  finally:
    # Release the environment (and any render window) even if training fails.
    env.close()

  # Plot against the real episode index; samples are RECORD_INTERVAL apart.
  _, length_ax = plt.subplots()
  length_ax.plot(recorded_episodes, episode_lengths, 'b')
  length_ax.set_xlabel('episode')
  length_ax.set_ylabel('length')

  _, tde_ax = plt.subplots()
  tde_ax.plot(recorded_episodes, avg_tde, 'r')
  tde_ax.set_xlabel('episode')
  tde_ax.set_ylabel('TD-error')

In [None]:
# Run the full training loop and draw the episode-length and TD-error plots.
main()

  3%|█▎                                       | 65/2000 [00:27<15:19,  2.10it/s]