In [0]:
!pip install git+https://github.com/deepmind/dm-haiku > /dev/null 2>&1
!pip install git+git://github.com/deepmind/rlax.git > /dev/null 2>&1

In [0]:
import jax
from jax import jit, grad, vmap
import jax.numpy as jnp
from jax.experimental import optix
import haiku as hk
import rlax

import gym
from functools import partial
import numpy as np

import random
from IPython.display import clear_output
from collections import deque
from typing import Callable, Mapping, NamedTuple, Tuple, Sequence

import matplotlib.pyplot as plt
%matplotlib inline
# Colour applied to titles, axis labels and tick labels (presumably chosen for
# a dark notebook theme).
COLOR = 'white'
# Set all four text colours from the one constant; the previous version
# referenced an undefined name for the y ticks and aborted the cell.
plt.rcParams.update({
    'text.color': COLOR,
    'axes.labelcolor': COLOR,
    'xtick.color': COLOR,
    'ytick.color': COLOR,
})

# Hyperparameters

In [0]:
# Run length (presumably consumed by the training loop).
max_episodes = 1_000
max_steps    = 300

# Learning.
BATCH_SIZE   = 128        # transitions per update (assumed; usage not shown here)
GAMMA        = 0.999      # discount applied to the bootstrapped TD target
BUFFER_SIZE  = 1_000_000  # replay capacity; oldest transitions are evicted first

# Exploration / target networks / reproducibility.
NOISE        = 0.1        # exploration-noise scale (presumably the policy's stddev)
POLYAK       = 0.995      # weight kept on the old target params in polyak_average
SEED         = 1729       # seed for the JAX PRNG sequence

# Plotting

In [0]:
def plot(episode, rewards):
    """Redraw the training-reward curve in place.

    Clears the cell's previous output (waiting until the new figure is ready)
    so repeated calls during training replace the plot instead of stacking.

    Args:
      episode: current episode index, shown in the title.
      rewards: sequence of rewards so far (assumed one entry per episode);
        the title reports the mean of the last 10 entries.
    """
    clear_output(wait=True)
    fig = plt.figure(figsize=(20, 5))
    # Left third of a 1x3 grid, as before (presumably leaving room for more
    # panels such as returns or Q estimates).
    ax = fig.add_subplot(131)
    recent = rewards[-10:]
    # np.mean([]) returns nan and emits a RuntimeWarning; show 'n/a' instead.
    mean_recent = f'{np.mean(recent):.2f}' if len(recent) else 'n/a'
    ax.set_title(f'episode {episode}. reward: {mean_recent}')
    ax.set_xlabel('episode')
    ax.set_ylabel('reward')
    ax.plot(rewards)
    plt.show()
    # Release the figure; harmless if the backend already closed it on show().
    plt.close(fig)

# Replay Buffer

In [0]:
class ReplayBuffer(object):
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done).

    Once `capacity` transitions are held, each `push` evicts the oldest one
    (`deque(maxlen=capacity)`). Transitions are kept as host-side NumPy arrays
    and converted to JAX arrays once per sampled batch, instead of issuing a
    JAX operation on every environment step.
    """

    def __init__(self, capacity, seed=None):
        """Create an empty buffer.

        Args:
          capacity: maximum number of stored transitions.
          seed: optional seed for a private `random.Random` used by `sample`,
            making batches reproducible. If None (the default) the global
            `random` module is used, so `random.seed(...)` still controls
            sampling exactly as before.
        """
        self.buffer = deque(maxlen=capacity)
        self._rng = random if seed is None else random.Random(seed)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest if the buffer is full."""
        self.buffer.append((np.asarray(state), np.asarray(action), reward,
                            np.asarray(next_state), done))

    def sample(self, batch_size):
        """Draw `batch_size` distinct stored transitions uniformly at random.

        Returns:
          dict with keys 'state', 'action', 'reward', 'next_state', 'done'.
          Each value is a JAX array stacking the sampled transitions along a
          new leading axis of length `batch_size`.

        Raises:
          ValueError: if `batch_size` is not in [1, len(self)].
        """
        if not 1 <= batch_size <= len(self.buffer):
            raise ValueError(
                f'batch_size must be in [1, {len(self.buffer)}], '
                f'got {batch_size}')
        transitions = self._rng.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return {'state': jnp.asarray(np.stack(state)),
                'action': jnp.asarray(np.stack(action)),
                'reward': jnp.asarray(reward),
                'next_state': jnp.asarray(np.stack(next_state)),
                'done': jnp.asarray(done)}

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# Utils

In [0]:
@jit
def scale_action(lower_bound, upper_bound, action):
  """Map a tanh-range action in [-1, 1] affinely onto [lower_bound, upper_bound].

  The result is clipped to the bounds, so inputs outside [-1, 1] saturate.
  """
  span = upper_bound - lower_bound
  unit = (action + 1.0) * 0.5  # [-1, 1] -> [0, 1]
  rescaled = lower_bound + unit * span
  return jnp.clip(rescaled, lower_bound, upper_bound)

In [0]:
def std(a):
  """Population standard deviation (ddof=0) of `a` as a JAX scalar.

  `a` may be a Python list or any array-like. It is converted to an array
  before reaching JAX numerics. Wrapping this function in `jax.jit` directly,
  as before, treats a list as a pytree: every element becomes a separate traced
  argument, and a new trace/compile is triggered whenever the list's length
  changes (e.g. a growing reward history).
  """
  return jnp.std(jnp.asarray(a))

# Network

In [0]:
@jit
def polyak_average(old, new, tau=None):
  """Exponential moving average of two parameter pytrees (soft target update).

  Each leaf becomes `tau * old + (1 - tau) * new`, so `tau` is the weight kept
  on the existing (target) parameters.

  Args:
    old: pytree of current target parameters.
    new: pytree with the same structure (e.g. the online parameters).
    tau: mixing coefficient; when omitted the notebook's `POLYAK` constant is
      used, matching the previous behaviour.

  Returns:
    A new pytree with the same structure as `old`.
  """
  if tau is None:
    tau = POLYAK
  # `jax.tree_multimap` was deprecated and later removed from JAX;
  # `jax.tree_util.tree_map` accepts several trees and replaces it.
  return jax.tree_util.tree_map(
      lambda p_ema, p: p_ema * tau + p * (1. - tau), old, new)

In [0]:
def build_actor(num_actions: int) -> hk.Transformed:
  """Return a Haiku-transformed deterministic policy network.

  Architecture: Linear(400) -> ReLU -> Linear(300) -> ReLU ->
  Linear(num_actions) -> tanh, so each action component lies in [-1, 1].
  """

  def actor(obs):
    # Linear modules are created in the same order as a literal list would
    # create them, so Haiku assigns identical parameter names.
    layers = []
    for width in (400, 300):
      layers.extend([hk.Linear(width), jax.nn.relu])
    layers.extend([hk.Linear(num_actions), jnp.tanh])
    policy_net = hk.Sequential(layers)
    return policy_net(obs)

  return hk.transform(actor)

In [0]:
def build_critic() -> hk.Transformed:
  """Return a Haiku-transformed Q-network Q(s, a).

  `s` and `a` are concatenated along axis 1 (so, with the usual batch-first
  2-D inputs, along the feature axis) and passed through
  Linear(400) -> ReLU -> Linear(300) -> ReLU -> Linear(1); the output has a
  trailing dimension of size 1.
  """

  def q(s, a):
    state_action = jnp.concatenate([s, a], axis=1)
    # Same module creation order as before, hence the same parameter names.
    layers = []
    for width in (400, 300):
      layers.extend([hk.Linear(width), jax.nn.relu])
    layers.append(hk.Linear(1))
    return hk.Sequential(layers)(state_action)

  return hk.transform(q)

In [0]:
def main_loop():
  # Build the environment and the replay buffer.
  noise = NOISE
  env = gym.make("Pendulum-v0")
  replay_buffer = ReplayBuffer(BUFFER_SIZE)
  rng = hk.PRNGSequence(jax.random.PRNGKey(SEED))
  # Seed the other randomness sources as well: the global `random` module
  # drives replay sampling and the env draws its own initial states.
  random.seed(SEED)
  env.seed(SEED)  # pre-0.26 gym API, consistent with `env.reset()` below
  lower_bound = env.action_space.low
  upper_bound = env.action_space.high

  # Logging.
  rewards = []
  returns = []
  q = []
  g_minus_q = []
  episode_reward = 0
  ep_idx = 0

  # Build and initialize networks.
  action_dim = env.action_space.shape[0]
  state = env.reset()

  # The actor's output width is the environment's action dimensionality.
  actor = build_actor(action_dim)
  actor_params = actor.init(next(rng), state)
  actor_target = actor_params

  # One forward pass yields a correctly shaped example action, which the
  # critic needs (with a leading batch axis) for initialisation.
  action = actor.apply(actor_params, state)
  critic = build_critic()
  critic_params = critic.init(next(rng), jnp.expand_dims(state, axis=0),
                              jnp.expand_dims(action, axis=0))
  critic_target = critic_params

  # Build and initialize optimizers.
  actor_optimizer = optix.adam(1e-4)
  actor_state = actor_optimizer.init(actor_params)

  critic_optimizer = optix.adam(1e-3)
  critic_state = critic_optimizer.init(critic_params)

  @jax.jit
  def policy(net_params, key, obs, stddev=1.):
    """Actor output plus additive Gaussian exploration noise.

    Args:
      net_params: actor parameters.
      key: JAX PRNG key for the noise sample.
      obs: observation(s) fed to the actor.
      stddev: standard deviation of the Gaussian noise. NOTE(review): the
        default 1. differs from the NOISE hyperparameter; confirm callers
        pass the intended value.

    Returns:
      The noisy action. It is not clipped, so it may leave the actor's tanh
      range [-1, 1]; rescale/clip (e.g. with `scale_action`) before stepping
      the environment.
    """
    action = actor.apply(net_params, obs)
    return rlax.add_gaussian_noise(key, action, stddev)

  # Per-example DPG loss, vectorised over the leading (batch) axis.
  batched_dpg_loss = vmap(rlax.dpg_loss)

  @vmap
  def q_loss(reward, done, q_target, q):
    """Per-example squared TD error, vectorised over the batch axis.

    The bootstrap target reward + GAMMA * (1 - done) * q_target is wrapped in
    stop_gradient, so gradients flow only through `q`.
    """
    not_done = 1. - done
    target = reward + GAMMA * not_done * q_target
    error = jax.lax.stop_gradient(target) - q
    return error ** 2


  def actor_step(actor_optimizer, critic_params, actor_target, opt_state, state,
                 params=None):
    """One deterministic-policy-gradient step on the actor.

    Args:
      actor_optimizer: optix optimizer for the actor.
      critic_params: critic parameters providing dQ/da.
      actor_target: unused; kept so existing call sites still match.
      opt_state: the actor optimizer's state.
      state: batch of observations, batch-first.
      params: actor parameters to update. Defaults to the `actor_params`
        created earlier in `main_loop`.

    Returns:
      (new_actor_params, new_opt_state).
    """
    if params is None:
      params = actor_params

    def actor_loss(p):
      a_t = actor.apply(p, state)
      # dQ/da evaluated at the actor's own actions. The critic scores each
      # row independently, so the gradient of the batch sum w.r.t. the action
      # batch is every row's dQ/da (and `grad` needs a scalar output).
      dqda_t = grad(lambda a: jnp.sum(critic.apply(critic_params, state, a)))(a_t)
      return jnp.mean(batched_dpg_loss(a_t, dqda_t))

    actor_grad = grad(actor_loss)(params)
    updates, opt_state = actor_optimizer.update(actor_grad, opt_state)
    return optix.apply_updates(params, updates), opt_state

  def critic_step(critic_optimizer, critic_params, critic_target, actor_target,
                  opt_state, batch):
    """One TD-learning gradient step on the critic.

    Args:
      critic_optimizer: optix optimizer for the critic.
      critic_params: online critic parameters to update.
      critic_target: target-critic parameters used for the bootstrap value.
      actor_target: target-actor parameters that pick the next action.
      opt_state: the critic optimizer's state.
      batch: dict as returned by `ReplayBuffer.sample` (keys 'state',
        'action', 'reward', 'next_state', 'done').

    Returns:
      (new_critic_params, new_opt_state).
    """
    state = batch['state']
    action = batch['action']
    reward = batch['reward']
    next_state = batch['next_state']
    done = batch['done']

    def critic_loss(net_params):
      q_pred = critic.apply(net_params, state, action)
      next_action = actor.apply(actor_target, next_state)
      q_next = critic.apply(critic_target, next_state, next_action)
      return jnp.mean(q_loss(reward, done, q_next, q_pred))

    critic_grad = grad(critic_loss)(critic_params)
    updates, opt_state = critic_optimizer.update(critic_grad, opt_state)
    critic_params = optix.apply_updates(critic_params, updates)
    return critic_params, opt_state
  
  @jax.jit
  def update(critic_params, actor_params, critic_target, actor_target,
             actor_state, critic_state, batch):
    """One full DDPG update on a sampled batch.

    Steps: a TD step on the critic, then a DPG step on the actor against the
    freshly updated critic, then soft (Polyak) updates of both targets.

    Args:
      batch: dict as returned by `ReplayBuffer.sample`.

    Returns:
      (critic_params, actor_params, critic_target, actor_target,
       actor_state, critic_state, critic_loss, actor_loss): the first six
      are the updated arguments in the same order; the last two are scalars.
    """
    state = batch['state']
    action = batch['action']
    reward = batch['reward']
    next_state = batch['next_state']
    done = batch['done']

    def critic_loss_fn(params):
      q_pred = critic.apply(params, state, action)
      next_action = actor.apply(actor_target, next_state)
      q_next = critic.apply(critic_target, next_state, next_action)
      return jnp.mean(q_loss(reward, done, q_next, q_pred))

    critic_loss, critic_grad = jax.value_and_grad(critic_loss_fn)(critic_params)
    updates, critic_state = critic_optimizer.update(critic_grad, critic_state)
    critic_params = optix.apply_updates(critic_params, updates)

    def actor_loss_fn(params):
      a_t = actor.apply(params, state)
      # Per-row dQ/da at the actor's actions; rows are scored independently,
      # so differentiating the batch sum gives each row's gradient.
      dqda_t = grad(lambda a: jnp.sum(critic.apply(critic_params, state, a)))(a_t)
      return jnp.mean(batched_dpg_loss(a_t, dqda_t))

    actor_loss, actor_grad = jax.value_and_grad(actor_loss_fn)(actor_params)
    updates, actor_state = actor_optimizer.update(actor_grad, actor_state)
    actor_params = optix.apply_updates(actor_params, updates)

    # Soft target updates: targets track the online networks slowly.
    critic_target = polyak_average(critic_target, critic_params)
    actor_target = polyak_average(actor_target, actor_params)

    return (critic_params, actor_params, critic_target, actor_target,
            actor_state, critic_state, critic_loss, actor_loss)