# Proximal Policy Optimization

In [1]:
import tensorflow as tf
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def collect_data(sess, batch_size, gamma=0.99, debug=False):
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths
    batch_new_obs = []
    batch_rews = []
    batch_terminal = []
    
    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # collect experience by acting in the environment with current policy
    while True:
        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1,-1)})[0]
        obs, rew, done, info = env.step(act)
        batch_new_obs.append(obs.copy())
        # save action, reward
        batch_terminal.append(float(done))
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            if len(ep_rews) == 200:
                batch_terminal[-1] = 0
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_rews.extend(ep_rews)
            batch_lens.append(ep_len)
            
            # reset episode-specific variables
            obs, done, ep_rews = env.reset(), False, []

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break
    return batch_obs, batch_acts, batch_new_obs, batch_rews, batch_rets, batch_lens, batch_terminal

In [10]:
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
n_acts = 2
gamma = 0.90

# placeholder
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
new_obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
rew_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
terminal_ph = tf.placeholder(shape=(None,), dtype=tf.float32)

# make core of policy network
mlp_policy = tf.keras.models.Sequential()
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(n_acts))

# make core of policy network
mlp_policy_old = tf.keras.models.Sequential()
mlp_policy_old.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy_old.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy_old.add(tf.keras.layers.Dense(n_acts))

# make state-value function network
mlp_state_value = tf.keras.models.Sequential()
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(1))
state_value = mlp_state_value(obs_ph)
new_state_value = mlp_state_value(new_obs_ph)
td_target = rew_ph + gamma * new_state_value * (1-terminal_ph)

# make loss function whose gradient, for the right data, is policy gradient
obs_logits = mlp_policy(obs_ph)
old_obs_logits = mlp_policy_old(obs_ph)
actions = tf.squeeze(tf.multinomial(logits=obs_logits,num_samples=1), axis=1)
action_masks = tf.one_hot(act_ph, n_acts)
selected_action_probs = tf.reduce_sum(action_masks * tf.nn.softmax(obs_logits), axis=1)
old_selected_action_probs = tf.reduce_sum(action_masks * tf.nn.softmax(old_obs_logits), axis=1)
r = selected_action_probs / tf.stop_gradient(old_selected_action_probs)
advantages = td_target - state_value
factor = 1 + 0.2 * tf.math.sign(advantages)
policy_loss = -tf.reduce_mean(tf.math.minimum(r, factor) * advantages)

# state value loss function
state_value_loss = tf.reduce_mean((tf.stop_gradient(td_target) - state_value)**2)

In [12]:
def update_old_network(old_policy_network, policy_network):
    sess.run([v_t.assign(v) for v_t, v in zip(old_policy_network.trainable_weights, policy_network.trainable_weights)])

In [13]:
policy_optimizer = tf.train.AdamOptimizer(0.0003)
state_value_optimizer = tf.train.AdamOptimizer(0.001)
train_policy = policy_optimizer.minimize(policy_loss)
train_state_value = state_value_optimizer.minimize(state_value_loss)

sess = tf.Session()
n_epochs = 1000
K = 5
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(n_epochs):
    update_old_network(mlp_policy_old, mlp_policy)
    obs, acts, new_obs, rews, rets, lens, terminal = collect_data(sess, 1000)
    print(np.mean(rets), np.std(rets), np.min(rets), np.max(rets))
    for j in range(K):
        sess.run([train_policy, train_state_value],feed_dict={
                                    obs_ph: np.array(obs).reshape(-1, obs_dim),
                                    act_ph: np.array(acts),
                                    new_obs_ph: np.array(new_obs).reshape(-1, obs_dim),
                                    rew_ph: np.array(rews),
                                    terminal_ph: np.array(terminal)
                                 })
#         print(sess.run(state_value, feed_dict ={
#                 obs_ph: np.zeros(4).reshape(1, -1)
#         }))

28.36111111111111 16.44545323482346 10.0 92.0
30.147058823529413 14.802885952609092 11.0 58.0
26.526315789473685 16.27192746015966 11.0 95.0
28.22222222222222 14.128598088650454 10.0 68.0
27.833333333333332 16.192419351179257 11.0 92.0
32.125 16.78866805318397 12.0 82.0
28.166666666666668 14.76952568259688 9.0 78.0
27.43243243243243 12.65026554692022 13.0 71.0
25.82051282051282 13.666113425084381 10.0 75.0
25.897435897435898 14.73665388697015 9.0 93.0
23.953488372093023 13.758639549383682 9.0 71.0
25.871794871794872 13.023016089235492 9.0 71.0
23.41860465116279 9.511883774348934 10.0 51.0
22.42222222222222 10.489124291143742 9.0 47.0
24.285714285714285 13.015166453915572 11.0 60.0
23.976190476190474 10.317548839669087 8.0 53.0
26.657894736842106 16.422507526641716 11.0 93.0
21.319148936170212 11.087862605030367 9.0 69.0
23.41860465116279 10.296155430919484 10.0 63.0
21.956521739130434 10.6932659055963 9.0 62.0
21.340425531914892 11.359271395095428 9.0 51.0
19.26923076923077 10.11304447

10.5 1.8874586088176875 8.0 18.0
9.941176470588236 1.3197112682716738 8.0 14.0
10.31958762886598 1.418191034145785 8.0 15.0
10.13 1.6652627420320194 8.0 18.0
10.141414141414142 2.0449993258258803 8.0 21.0
10.07 1.5827507700203465 8.0 16.0
10.121212121212121 1.8437594842637888 8.0 17.0
10.181818181818182 1.4520180061467804 8.0 15.0
10.02 1.3782597723216041 8.0 15.0
10.23469387755102 1.5306462581254807 8.0 17.0


KeyboardInterrupt: 