In [1]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import gym
from gym.spaces import Discrete, Box

# Short alias for the TFP distributions module (used below for tfd.Categorical).
tfd = tfp.distributions

# Seed TensorFlow's global RNG so TF-side randomness repeats across runs.
# NOTE(review): the gym environment itself is not seeded in this notebook.
SEED = 10
tf.random.set_seed(SEED)

In [2]:
def make_mlp_model(sizes, activation='tanh', output_activation=None):
    """Build a feedforward (MLP) Keras model.

    Args:
        sizes: Sequence of layer widths. Every entry except the last becomes a
            hidden ``Dense`` layer using ``activation``; the last entry is the
            width of the output ``Dense`` layer.
        activation: Activation applied by each hidden layer.
        output_activation: Activation of the output layer (``None`` = linear).

    Returns:
        A ``tf.keras.Sequential`` model. No input shape is given, so its weights
        are created on the first call.
    """
    hidden_widths, output_width = sizes[:-1], sizes[-1]
    layers = [tf.keras.layers.Dense(width, activation=activation)
              for width in hidden_widths]
    layers.append(tf.keras.layers.Dense(output_width, activation=output_activation))
    return tf.keras.Sequential(layers)

In [3]:
def reward_to_go(rews, gamma=1.0):
    """Compute the (optionally discounted) reward-to-go for one episode.

    rtg[i] = rews[i] + gamma * rtg[i + 1], with rtg[n] = 0. With the default
    ``gamma=1.0`` this is the plain sum of rewards from step ``i`` to the end.

    Args:
        rews: Sequence of per-step rewards for a single episode (may be empty).
        gamma: Discount factor; 1.0 (default) means undiscounted.

    Returns:
        A float64 NumPy array with the same length as ``rews``. A float dtype is
        forced even for integer rewards: the result is used as policy-gradient
        weights and multiplied with floating-point log-probabilities, and an
        integer array there would cause a dtype mismatch.
    """
    rews = np.asarray(rews, dtype=np.float64)
    rtgs = np.zeros_like(rews)
    running = 0.0
    # Walk backwards so each entry reuses the already-computed tail sum (O(n)).
    for i in reversed(range(len(rews))):
        running = rews[i] + gamma * running
        rtgs[i] = running
    return rtgs

Selecting a good learning rate is crucial. TensorFlow's Adam optimizer defaults to a learning rate of 0.001, which performs very poorly on this task, so `train` below defaults to `lr=1e-2` instead.

In [4]:
def _reset_env(env):
    """Reset ``env`` and return only the initial observation (old and new gym APIs)."""
    result = env.reset()
    # gym>=0.26 / gymnasium return (obs, info); the classic API returns obs only.
    if isinstance(result, tuple):
        result = result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(obs, rew, done)`` for either gym step API."""
    result = env.step(action)
    # gym>=0.26 / gymnasium: (obs, rew, terminated, truncated, info).
    if len(result) == 5:
        obs, rew, terminated, truncated, _ = result
        return obs, rew, terminated or truncated
    # Classic gym: (obs, rew, done, info).
    obs, rew, done, _ = result
    return obs, rew, done


def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2, epochs=50, batch_size=5000, render=False):
    """Train a categorical MLP policy with a vanilla policy gradient.

    Each log-probability is weighted by its reward-to-go. One gradient step is
    taken per epoch, and a progress line is printed after every epoch.

    Args:
        env_name: Gym environment id. It must have a ``Box`` observation space and
            a ``Discrete`` action space.
        hidden_sizes: Widths of the policy's hidden layers (any sequence; not mutated).
        lr: Adam learning rate.
        epochs: Number of epochs, i.e. number of policy updates.
        batch_size: Minimum number of environment steps per epoch. Collection
            stops at the end of the first episode that pushes the count past it.
        render: If True, render the first episode of every epoch.

    Raises:
        AssertionError: If the environment's spaces are not Box / Discrete.
    """
    env = gym.make(env_name)
    try:
        assert isinstance(env.observation_space, Box), "This example only works for envs with continuous state spaces."
        assert isinstance(env.action_space, Discrete), "This example only works for envs with discrete action spaces."

        n_acts = env.action_space.n

        # Policy network: observation batch -> one logit per action.
        mlp = make_mlp_model(sizes=list(hidden_sizes) + [n_acts])
        optimizer = tf.optimizers.Adam(learning_rate=lr)

        @tf.function
        def select_action(observation):
            """Sample one action per row of ``observation`` from the current policy."""
            logits = mlp(observation)
            return tfd.Categorical(logits=logits).sample()

        # NOTE(review): the batch length changes every epoch, so this tf.function
        # may retrace per call; consider an input_signature if that becomes costly.
        @tf.function
        def update(observation, act, weights):
            """Apply one Adam step to the policy-gradient surrogate loss; return the loss."""
            with tf.GradientTape() as tape:
                logits = mlp(observation)
                # Match mask and weights to the logits' dtype. The numpy batches
                # arrive as float64/int64, and TF does not cast implicitly in `*`.
                action_mask = tf.one_hot(act, n_acts, dtype=logits.dtype)
                weights = tf.cast(weights, logits.dtype)
                # log pi(a_t | s_t) for the action actually taken at each step.
                log_probs = tf.reduce_sum(action_mask * tf.nn.log_softmax(logits), axis=1)
                # Its gradient, for on-policy data, is the policy gradient.
                loss = -tf.reduce_mean(weights * log_probs)
            grads = tape.gradient(loss, mlp.trainable_variables)
            optimizer.apply_gradients(zip(grads, mlp.trainable_variables))
            return loss

        def train_one_epoch():
            """Collect experience with the current policy, then update it once.

            Returns:
                (batch_loss, batch_rets, batch_lens): the surrogate loss of the
                update, plus the return and length of each finished episode.
            """
            batch_obs = []      # observations
            batch_acts = []     # actions taken
            batch_weights = []  # reward-to-go weight for each log pi(a|s)
            batch_rets = []     # episode returns
            batch_lens = []     # episode lengths

            obs = _reset_env(env)   # first obs comes from the starting distribution
            ep_rews = []            # rewards accrued during the current episode

            # Only the first episode of each epoch is rendered.
            finished_rendering_this_epoch = False

            while True:
                if render and not finished_rendering_this_epoch:
                    env.render()

                batch_obs.append(obs.copy())

                # The policy expects a batch, so add a leading axis of size 1.
                act = select_action(obs.reshape(1, -1)).numpy()[0]
                obs, rew, done = _step_env(env, act)

                batch_acts.append(act)
                ep_rews.append(rew)

                if done:
                    batch_rets.append(sum(ep_rews))
                    batch_lens.append(len(ep_rews))
                    # Weight each log-prob by the reward collected from that step on.
                    batch_weights.extend(reward_to_go(ep_rews))

                    obs, ep_rews = _reset_env(env), []
                    finished_rendering_this_epoch = True

                    # Stop only on an episode boundary so every weight is complete.
                    if len(batch_obs) > batch_size:
                        break

            # Single policy-gradient update step on the whole batch.
            batch_loss = update(np.array(batch_obs), np.array(batch_acts), np.array(batch_weights))
            return batch_loss, batch_rets, batch_lens

        for i in range(epochs):
            batch_loss, batch_rets, batch_lens = train_one_epoch()
            print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                    (i, float(batch_loss), np.mean(batch_rets), np.mean(batch_lens)))
    finally:
        # Release the environment (and any render window) even if training fails.
        env.close()

Wrapping `select_action` and `update` in `@tf.function` reduced the wall time from about 6 minutes to about 1 minute.

In [5]:
%%time
# Train the policy with train()'s default hyperparameters; %%time reports this cell's wall time.
# NOTE(review): train() weights log-probs with reward_to_go, so this runs the
# reward-to-go variant rather than the "simplest" R(tau)-weighted formulation the message names.
print('\nUsing simplest formulation of policy gradient.\n')
train()


Using simplest formulation of policy gradient.

epoch:   0 	 loss: 7.799 	 return: 17.436 	 ep_len: 17.436
epoch:   1 	 loss: 8.918 	 return: 19.551 	 ep_len: 19.551
epoch:   2 	 loss: 10.405 	 return: 22.745 	 ep_len: 22.745
epoch:   3 	 loss: 11.819 	 return: 26.511 	 ep_len: 26.511
epoch:   4 	 loss: 13.786 	 return: 30.018 	 ep_len: 30.018
epoch:   5 	 loss: 14.733 	 return: 33.905 	 ep_len: 33.905
epoch:   6 	 loss: 15.170 	 return: 35.695 	 ep_len: 35.695
epoch:   7 	 loss: 19.029 	 return: 41.717 	 ep_len: 41.717
epoch:   8 	 loss: 17.538 	 return: 45.613 	 ep_len: 45.613
epoch:   9 	 loss: 19.602 	 return: 51.293 	 ep_len: 51.293
epoch:  10 	 loss: 16.864 	 return: 46.741 	 ep_len: 46.741
epoch:  11 	 loss: 19.612 	 return: 53.660 	 ep_len: 53.660
epoch:  12 	 loss: 21.472 	 return: 57.148 	 ep_len: 57.148
epoch:  13 	 loss: 19.593 	 return: 54.269 	 ep_len: 54.269
epoch:  14 	 loss: 23.100 	 return: 63.937 	 ep_len: 63.937
epoch:  15 	 loss: 21.711 	 return: 62.741 	 ep_len: 