# Proximal Policy Optimization

In [None]:
import tensorflow as tf
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def collect_data_single_actor(sess, batch_size, max_ep_len=200):
    """Roll out the current policy in the global ``env`` and return a batch.

    Episodes are always played to completion. Collection stops at the first
    episode boundary where more than ``batch_size`` transitions have been
    gathered, so the batch usually holds more than ``batch_size`` steps.

    Relies on the module-level ``env``, ``actions`` and ``obs_ph`` defined in
    another cell of this notebook.

    Args:
        sess: session used to evaluate ``actions`` (one sampled action per
            observation row).
        batch_size: the loop stops once ``len(batch_obs) > batch_size`` at an
            episode boundary.
        max_ep_len: episode length at which the environment's time limit cuts
            the episode off (200 for CartPole-v0). This is only used as a
            fallback when the env does not report
            ``info['TimeLimit.truncated']``.

    Returns:
        Tuple ``(batch_obs, batch_acts, batch_new_obs, batch_rews, batch_rets,
        batch_lens, batch_terminal)``.
        - Per step: observation before the step, action, observation after the
          step, reward, and a terminal flag. The flag is 1.0 only when the
          episode really ended; a time-limit cut-off gets 0.0.
        - Per finished episode: total return and length.
    """
    batch_obs = []       # observation before each step
    batch_acts = []      # action taken at each step
    batch_new_obs = []   # observation after each step
    batch_rews = []      # reward of each step (filled per finished episode)
    batch_terminal = []  # 1.0 if the step reached a true terminal state
    batch_rets = []      # return of each finished episode
    batch_lens = []      # length of each finished episode

    obs = env.reset()
    ep_rews = []  # rewards of the episode in progress

    while True:
        batch_obs.append(obs.copy())

        act = sess.run(actions, {obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, info = env.step(act)

        batch_new_obs.append(obs.copy())
        batch_acts.append(act)
        ep_rews.append(rew)

        # A time-limit cut-off is not a real terminal state. Keep its flag at
        # 0.0 so the TD target still bootstraps from V(new_obs). Prefer the
        # flag reported by gym's TimeLimit wrapper; fall back to the length
        # check when the env does not provide it.
        truncated = done and info.get('TimeLimit.truncated', len(ep_rews) == max_ep_len)
        batch_terminal.append(float(done and not truncated))

        if done:
            batch_rets.append(sum(ep_rews))
            batch_lens.append(len(ep_rews))
            batch_rews.extend(ep_rews)

            obs, ep_rews = env.reset(), []

            # Only stop at an episode boundary, once enough steps are collected.
            if len(batch_obs) > batch_size:
                break
    return batch_obs, batch_acts, batch_new_obs, batch_rews, batch_rets, batch_lens, batch_terminal

In [None]:
# class DataCollector():
    
#     def __init__(self, env_name, n_actors, n_samples):
#         self._envs = [gym.make(env_name) for _ in n_actors]
#         self._states = [env.reset() for env in self._envs]
#         self._n_samples = n_samples
        
#     def collect_data(self):
#         batch_obs = []          # for observations
#         batch_acts = []         # for actions
#         batch_rets = []         # for measuring episode returns
#         batch_lens = []         # for measuring episode lengths
#         batch_new_obs = []
#         batch_rews = []
#         batch_terminal = []
#         for i in range(len(self._envs)):
#             tmp_obs, tmp_acts, tmp_new_obs, tmp_rews, tmp_rets, tmp_lens, tmp_terminal = self._collect_data_single_actor

In [None]:
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n  # number of discrete actions (assumes a Discrete action space)
gamma = 0.9                  # discount factor of the one-step TD target
clip_ratio = 0.2             # PPO clipping epsilon

# Placeholders for one batch of transitions (s, a, s', r, terminal flag).
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)      # s
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)                # a, as an action index
new_obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)  # s'
rew_ph = tf.placeholder(shape=(None, 1), dtype=tf.float32)            # r
terminal_ph = tf.placeholder(shape=(None, 1), dtype=tf.float32)       # 1.0 when s' is terminal

# Current policy network: observation -> action logits.
mlp_policy = tf.keras.models.Sequential()
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy.add(tf.keras.layers.Dense(n_acts))

# "Old" policy network with the same architecture. It holds a frozen copy of
# the policy that collected the current batch, used for the probability ratio.
mlp_policy_old = tf.keras.models.Sequential()
mlp_policy_old.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy_old.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_policy_old.add(tf.keras.layers.Dense(n_acts))

# State-value network: observation -> V(s).
mlp_state_value = tf.keras.models.Sequential()
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(50, activation='relu'))
mlp_state_value.add(tf.keras.layers.Dense(1))
state_value = mlp_state_value(obs_ph)          # V(s),  shape (None, 1)
new_state_value = mlp_state_value(new_obs_ph)  # V(s'), shape (None, 1)
# One-step TD target r + gamma * V(s'). The (1 - terminal) factor drops the
# bootstrap term when s' is a true terminal state.
td_target = rew_ph + gamma * new_state_value * (1 - terminal_ph)

# Policy outputs.
obs_logits = mlp_policy(obs_ph)
old_obs_logits = mlp_policy_old(obs_ph)
# Sample one action per observation row from the current policy.
actions = tf.squeeze(tf.multinomial(logits=obs_logits, num_samples=1), axis=1)
action_masks = tf.one_hot(act_ph, n_acts)
# Probability of the action actually taken, under the current and old policies.
selected_action_probs = tf.reduce_sum(action_masks * tf.nn.softmax(obs_logits), axis=1)
old_selected_action_probs = tf.reduce_sum(action_masks * tf.nn.softmax(old_obs_logits), axis=1)

# Probability ratio pi(a|s) / pi_old(a|s); pi_old is held constant.
r = selected_action_probs / tf.stop_gradient(old_selected_action_probs)
# One-step TD advantage. It must be a constant for the policy update: without
# stop_gradient, minimizing policy_loss also back-propagates into the
# state-value network and pushes V(s) to inflate the advantages.
advantages = tf.stop_gradient(tf.squeeze(td_target - state_value, axis=1))
# Clipped surrogate objective, equivalent to
# min(r * A, clip(r, 1 - eps, 1 + eps) * A):
#   A > 0 -> A * min(r, 1 + eps)
#   A < 0 -> A * max(r, 1 - eps)
factor = 1 + clip_ratio * tf.math.sign(advantages)
x = tf.math.minimum(advantages * r, advantages * factor)
policy_loss = -tf.reduce_mean(x)

# State-value loss: squared error against the (fixed) TD target.
y = (tf.stop_gradient(td_target) - state_value) ** 2
state_value_loss = tf.reduce_mean(y)


# Debug scalars printed during training.
average_state_value = tf.reduce_mean(state_value)
max_r = tf.reduce_max(r)
max_advantages = tf.reduce_max(advantages)

In [None]:
def update_old_network(old_policy_network, policy_network):
    """Copy every trainable weight of ``policy_network`` into ``old_policy_network``.

    The copy runs in the module-level ``sess``. In TF1 graph mode each
    ``Variable.assign`` call adds new ops to the graph, and this function is
    called once per training epoch. To avoid growing the graph on every call,
    the assign ops for a given (old, new) pair are built once and then reused.

    Args:
        old_policy_network: Keras model whose weights are overwritten.
        policy_network: Keras model whose weights are copied. Weights are paired
            positionally with ``zip``, so the two models presumably share an
            architecture.
    """
    cache = update_old_network.__dict__.setdefault('_assign_ops', {})
    key = (id(old_policy_network), id(policy_network))
    # Rebuild when the pair is new, or when it was cached before the models
    # had any weights (an empty op list).
    if key not in cache or not cache[key][2]:
        assign_ops = [v_t.assign(v) for v_t, v in zip(old_policy_network.trainable_weights,
                                                       policy_network.trainable_weights)]
        # Store the models too, so their ids cannot be reused by new objects.
        cache[key] = (old_policy_network, policy_network, assign_ops)
    sess.run(cache[key][2])

In [None]:
# Separate Adam optimizers for the policy and the state-value network. Each is
# restricted to its own network's weights: `minimize` without `var_list` would
# update every trainable variable that policy_loss reaches, including the
# value network it depends on through the advantages.
policy_optimizer = tf.train.AdamOptimizer(0.0003)
state_value_optimizer = tf.train.AdamOptimizer(0.0003)
train_policy = policy_optimizer.minimize(policy_loss, var_list=mlp_policy.trainable_weights)
train_state_value = state_value_optimizer.minimize(state_value_loss,
                                                   var_list=mlp_state_value.trainable_weights)

n_epochs = 50       # collect-then-update iterations
K = 5               # policy gradient steps per collected batch
n_value_steps = 30  # state-value gradient steps per collected batch
batch_size = 1000   # collection stops after more than this many steps

# A single session; the original created a second one and leaked the first.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(n_epochs):
    # Freeze the current policy as pi_old for this epoch's probability ratio.
    update_old_network(mlp_policy_old, mlp_policy)
    obs, acts, new_obs, rews, rets, lens, terminal = collect_data_single_actor(sess, batch_size)
    print(np.mean(rets), np.std(rets), np.min(rets), np.max(rets))

    # The batch is fixed for the whole epoch, so build the feed once.
    feed = {
        obs_ph: np.array(obs).reshape(-1, obs_dim),
        act_ph: np.array(acts),
        new_obs_ph: np.array(new_obs).reshape(-1, obs_dim),
        rew_ph: np.array(rews).reshape(-1, 1),
        terminal_ph: np.array(terminal).reshape(-1, 1),
    }
    for j in range(K):
        # Debug: largest probability ratio and advantage before this step.
        print(sess.run([max_r, max_advantages], feed_dict=feed))
        sess.run([train_policy], feed_dict=feed)
    for j in range(n_value_steps):
        sess.run([train_state_value], feed_dict=feed)
    print('State value loss is:')
    print(sess.run(state_value_loss, feed_dict=feed))
    print()
        