In [0]:
!pip install git+https://github.com/deepmind/dm-haiku > /dev/null 2>&1
!pip install git+git://github.com/deepmind/rlax.git > /dev/null 2>&1

In [0]:
import jax
from jax import jit, grad, vmap
import jax.numpy as jnp
from jax.experimental import optix
import haiku as hk
import rlax

import gym
from functools import partial
import numpy as np

import random
from IPython.display import clear_output
from collections import deque
from typing import Callable, Mapping, NamedTuple, Tuple, Sequence

import matplotlib.pyplot as plt
%matplotlib inline
# Draw all figure text, axis labels and tick labels in a single colour
# (presumably chosen for a dark notebook theme -- confirm).
COLOR = 'white'
plt.rcParams['text.color'] = COLOR
plt.rcParams['axes.labelcolor'] = COLOR
plt.rcParams['xtick.color'] = COLOR
# Must be the same COLOR constant as the lines above; any other name is
# undefined here and would abort the cell with a NameError.
plt.rcParams['ytick.color'] = COLOR

# Hyperparameters

In [0]:
# Run-length limits (presumably consumed by the training loop in a later cell).
max_episodes = 1_000
max_steps = 300

# Replay / batching sizes.
BATCH_SIZE = 128
BUFFER_SIZE = 1_000_000

# Algorithm coefficients.
GAMMA = 0.999   # presumably the reward discount factor -- confirm at use site
POLYAK = 0.995  # weight kept on the old parameters in polyak_average
NOISE = 0.1     # presumably the exploration-noise scale -- confirm at use site

# Seed value (presumably for the PRNGs used later -- confirm at use site).
SEED = 1729

# Plotting

In [0]:
def plot(episode, rewards):
    """Redraw the reward curve in the notebook output area.

    Clears the current cell output, then draws ``rewards`` in the first of
    three side-by-side panels of a 20x5 figure. The title reports ``episode``
    and the mean of (up to) the last ten entries of ``rewards``.

    Args:
        episode: value interpolated into the figure title.
        rewards: sequence of per-episode rewards to plot.
    """
    # wait=True: keep the old output until the new figure is ready to replace it.
    clear_output(True)
    fig = plt.figure(figsize=(20, 5))
    axis = fig.add_subplot(1, 3, 1)
    axis.set_title(f'episode {episode}. reward: {np.mean(rewards[-10:])}')
    axis.plot(rewards)
    plt.show()
    # plt.savefig(fname=f"~/eps_{episode}")

# Replay Buffer

In [0]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Transitions are kept in a ``deque`` bounded by ``capacity``; once full,
    appending a new transition drops the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition.

        ``state`` and ``next_state`` get a leading axis of size 1 so that
        ``sample`` can concatenate them along axis 0 into a batch.
        """
        transition = (
            jnp.expand_dims(state, 0),
            action,
            reward,
            jnp.expand_dims(next_state, 0),
            done,
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as arrays.

        Returns:
            A dict with keys 'state', 'action', 'reward', 'next_state' and
            'done'; each value is a JAX array whose leading axis has length
            ``batch_size``.

        Raises:
            ValueError: from ``random.sample`` when ``batch_size`` exceeds the
                number of stored transitions.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of 5-tuples into five per-field tuples.
        states, actions, rewards, next_states, dones = zip(*picked)
        return {
            'state': jnp.concatenate(states),
            'action': jnp.asarray(actions),
            'reward': jnp.asarray(rewards),
            'next_state': jnp.concatenate(next_states),
            'done': jnp.asarray(dones),
        }

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# Utils

In [0]:
@jit
def scale_action(lower_bound, upper_bound, action):
  """Linearly map ``action`` from [-1, 1] onto [lower_bound, upper_bound].

  This is min-max rescaling with the source range fixed to [-1, 1] (the range
  of tanh): -1 maps to ``lower_bound`` and +1 maps to ``upper_bound``. The
  result is then clipped to the target bounds, so inputs outside [-1, 1]
  saturate instead of overshooting.
  """
  unit = (action + 1.0) * 0.5        # position of the action within [0, 1]
  span = upper_bound - lower_bound   # width of the target interval
  return jnp.clip(lower_bound + unit * span, lower_bound, upper_bound)

In [0]:
@jit
def std(a):
  """Return the standard deviation over all elements of ``a``.

  ``a`` is first converted with ``jnp.asarray``, so array-likes are accepted.
  """
  values = jnp.asarray(a)
  return jnp.std(values)

# Network

In [0]:
@jit
def polyak_average(old, new, polyak=None):
  """Blend two parameter pytrees leaf by leaf (exponential moving average).

  Each output leaf is ``old_leaf * w + new_leaf * (1 - w)``.

  Args:
    old: pytree of the running-average parameters.
    new: pytree with the same structure as ``old``.
    polyak: weight ``w`` kept on ``old``. Defaults to the module-level
      ``POLYAK``. Note that when the default is used the constant is baked in
      at trace time, so later changes to ``POLYAK`` are not picked up by an
      already-compiled call; pass ``polyak`` explicitly to vary it.

  Returns:
    A pytree with the same structure as ``old``.
  """
  weight = POLYAK if polyak is None else polyak
  # `jax.tree_multimap` has been removed from JAX. Prefer
  # `tree_util.tree_multimap` where it still exists (older releases, whose
  # `tree_map` took a single tree only) and otherwise fall back to
  # `tree_util.tree_map`, which accepts several trees in current releases.
  tree_util = jax.tree_util
  multi_tree_map = getattr(tree_util, 'tree_multimap', tree_util.tree_map)
  return multi_tree_map(
      lambda p_ema, p: p_ema * weight + p * (1. - weight), old, new)

In [0]:
def build_actor(num_actions: int) -> hk.Transformed:
  """Build the actor network as a Haiku-transformed function.

  The network is an MLP: Linear(128) -> ReLU -> Linear(128) -> ReLU ->
  Linear(num_actions). The final layer is linear; no squashing (e.g. tanh)
  is applied inside this function.

  Args:
    num_actions: width of the output layer.

  Returns:
    The result of ``hk.transform`` applied to the forward function.
  """

  def forward(obs):
    # Layers are instantiated in forward order, as in a plain Sequential list.
    layers = [
        hk.Linear(128), jax.nn.relu,
        hk.Linear(128), jax.nn.relu,
        hk.Linear(num_actions),
    ]
    return hk.Sequential(layers)(obs)

  return hk.transform(forward)

In [0]:
def build_critic(num_actions: int) -> hk.Transformed:
  """Build the critic network as a Haiku-transformed function.

  Same architecture as ``build_actor``: two hidden Linear(128) layers with
  ReLU, followed by a Linear(num_actions) output layer. The network receives a
  single input ``obs``.
  NOTE(review): whether callers fold the action into ``obs`` is not visible
  from this function -- confirm at the call site.

  Args:
    num_actions: width of the output layer.

  Returns:
    The result of ``hk.transform`` applied to the forward function.
  """

  def q_values(obs):
    # Hidden stack first, then the output head, so modules are created in
    # forward order.
    hidden = [hk.Linear(128), jax.nn.relu, hk.Linear(128), jax.nn.relu]
    head = hk.Linear(num_actions)
    mlp = hk.Sequential(hidden + [head])
    return mlp(obs)

  return hk.transform(q_values)