# Proximal Policy Optimization

In [None]:
import tensorflow as tf
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def collect_data(sess, batch_size, gamma=0.99, debug=False):
    """Roll out the current policy in the global ``env`` and return transitions.

    Whole episodes are played until more than ``batch_size`` steps have been
    gathered. The size check only happens when an episode finishes, so the
    batch always ends on an episode boundary.

    Args:
        sess: TensorFlow session used to evaluate the ``actions`` sampling op.
        batch_size: Step count that must be exceeded before collection stops.
        gamma: Unused; kept so existing call sites keep working.
        debug: Unused; kept so existing call sites keep working.

    Returns:
        A tuple ``(batch_obs, batch_acts, batch_new_obs, batch_rews,
        batch_lens)``. The first four lists hold one entry per environment
        step (observation, action taken, resulting observation, reward) and
        are index-aligned. ``batch_lens`` holds one length per finished
        episode.
    """
    # NOTE(review): episode-termination flags are not returned, so a caller
    # that bootstraps a TD target from batch_new_obs cannot zero the bootstrap
    # at terminal states. Returning them would change the tuple arity.
    batch_obs = []          # observation the agent acted on
    batch_acts = []         # action sampled from the policy
    batch_new_obs = []      # observation that resulted from the action
    batch_rews = []         # per-step reward, aligned with batch_obs
    batch_lens = []         # length of every finished episode

    obs = env.reset()       # first obs comes from the starting distribution
    ep_len = 0              # number of steps taken in the current episode

    while True:
        batch_obs.append(obs.copy())

        # Sample one action for this single observation (batch of size 1).
        act = sess.run(actions, {obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, _ = env.step(act)

        batch_acts.append(act)
        batch_new_obs.append(obs.copy())
        # Per-step rewards are what the caller feeds into the TD target; the
        # previous per-episode return sums did not line up with batch_obs.
        batch_rews.append(rew)
        ep_len += 1

        if done:
            batch_lens.append(ep_len)
            obs, ep_len = env.reset(), 0

            # Only stop on an episode boundary once enough steps exist.
            if len(batch_obs) > batch_size:
                break
    return batch_obs, batch_acts, batch_new_obs, batch_rews, batch_lens

In [None]:
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n   # read from the env instead of hard-coding 2
gamma = 0.9                   # discount used in the one-step TD target
clip_ratio = 0.2              # clipping range of the surrogate objective

# placeholders: one row / entry per collected transition
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
new_obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
rew_ph = tf.placeholder(shape=(None,), dtype=tf.float32)

# make core of policy network (outputs unnormalised action logits)
mlp_policy = tf.keras.models.Sequential()
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(n_acts))

# make state-value function network
mlp_state_value = tf.keras.models.Sequential()
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(1))
# Dense(1) yields shape (batch, 1) while rew_ph is (batch,). Without the
# squeeze, `rew_ph + gamma * new_state_value` broadcasts to (batch, batch).
state_value = tf.squeeze(mlp_state_value(obs_ph), axis=1)
new_state_value = tf.squeeze(mlp_state_value(new_obs_ph), axis=1)
td_target = rew_ph + gamma * new_state_value

# sampling op and probability of the action that was actually taken
obs_logits = mlp_policy(obs_ph)
actions = tf.squeeze(tf.multinomial(logits=obs_logits, num_samples=1), axis=1)
action_masks = tf.one_hot(act_ph, n_acts)
selected_action_probs = tf.reduce_sum(action_masks * tf.nn.softmax(obs_logits), axis=1)

# NOTE(review): the denominator is the *current* policy with gradients blocked,
# so r is numerically 1 on every update and the clip never becomes active.
# A faithful PPO ratio needs the probabilities recorded at collection time.
r = selected_action_probs / tf.stop_gradient(selected_action_probs)

# The advantage is a fixed weight for the policy update; blocking the gradient
# keeps the policy loss from also pushing on the value network.
advantages = tf.stop_gradient(td_target - state_value)
factor = 1 + clip_ratio * tf.math.sign(advantages)
# Clipped surrogate: min(r * A, (1 +/- clip_ratio) * A). The minimum has to be
# taken after multiplying by A so that negative advantages still carry gradient.
# The leading minus turns the objective to maximise into a loss to minimise.
policy_loss = -tf.reduce_mean(tf.math.minimum(r * advantages, factor * advantages))

# state value loss function: squared one-step TD error against a fixed target
state_value_loss = tf.reduce_mean((tf.stop_gradient(td_target) - state_value)**2)

In [None]:
policy_optimizer = tf.train.AdamOptimizer(0.003)
state_value_optimizer = tf.train.AdamOptimizer(0.001)
# Restrict each optimizer to its own network so the policy step cannot move
# the value network's weights (and vice versa).
train_policy = policy_optimizer.minimize(
    policy_loss, var_list=mlp_policy.trainable_weights)
train_state_value = state_value_optimizer.minimize(
    state_value_loss, var_list=mlp_state_value.trainable_weights)

n_epochs = 100   # number of collect-then-update rounds
K = 5            # gradient steps taken on each collected batch
# Create the session once; the earlier duplicate tf.Session() was never closed.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(n_epochs):
    obs, acts, new_obs, rews, lens = collect_data(sess, 1000)
    # The batch does not change across the K inner steps, so build the feed once.
    feed = {
        obs_ph: np.array(obs).reshape(-1, obs_dim),
        act_ph: np.array(acts),
        new_obs_ph: np.array(new_obs).reshape(-1, obs_dim),
        rew_ph: np.array(rews),
    }
    for j in range(K):
        # Update the critic together with the policy; previously the value
        # network was never trained, so advantages came from random weights.
        sess.run([train_policy, train_state_value], feed_dict=feed)