In [1]:
!pip install gym



In [1]:
import tensorflow as tf
import gym
import numpy as np

In [3]:
def collect_data(sess, batch_size, debug=False):
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # collect experience by acting in the environment with current policy
    while True:
        # save obs
        batch_obs.append(obs.copy())

        # act in the environment
        act = sess.run(actions, {obs_ph: obs.reshape(1,-1)})[0]
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)
            
            # the weight for each logprob(a_t|s_t) is reward-to-go from t
            batch_weights += list(np.cumsum(ep_rews[::-1])[::-1])

            # reset episode-specific variables
            obs, done, ep_rews = env.reset(), False, []

            # end experience loop if we have enough of it
            if len(batch_rets) > batch_size:
                break
    return batch_obs, batch_acts, batch_weights, batch_rets, batch_lens

In [4]:
obs_dim = 4
n_acts = 2

# make core of policy network
obs_ph = tf.placeholder(shape=(None, obs_dim), dtype=tf.float32)
mlp = tf.keras.layers.Dense(n_acts)
logits = mlp(obs_ph)

# make core of state-action-value function network
mlp_action_val = tf.keras.models.Sequential()
mlp_action_val.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_action_val.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_action_val.add(tf.keras.layers.Dense(n_acts))
state_action_values = mlp_action_val(obs_ph)

mlp_val = tf.keras.models.Sequential()
mlp_val.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_val.add(tf.keras.layers.Dense(50, activation='tanh'))
mlp_val.add(tf.keras.layers.Dense(1))
state_value = mlp_val(obs_ph)

# make action selection op (outputs int actions, sampled from policy)
actions = tf.squeeze(tf.multinomial(logits=logits,num_samples=1), axis=1)

# make loss function whose gradient, for the right data, is policy gradient
act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
weights_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
action_masks = tf.one_hot(act_ph, n_acts)
log_probs = tf.reduce_sum(action_masks * tf.nn.log_softmax(logits), axis=1)
action_values = tf.reduce_sum(action_masks * state_action_values, axis=1)
loss = -tf.reduce_mean((action_values-state_value) * log_probs)

# state value loss function
loss_action_value = tf.reduce_mean((action_values - weights_ph)**2)
loss_state_value = tf.reduce_mean((state_value - weights_ph)**2)

In [None]:
# main
env = gym.make('CartPole-v1')
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
optimizer = tf.train.GradientDescentOptimizer(0.01)
optimizer_action_value = tf.train.GradientDescentOptimizer(0.01)
optimizer_state_value = tf.train.GradientDescentOptimizer(0.01)
for i in range(10000):
    tmp1, tmp2, tmp3, batch_rets, batch_len = collect_data(sess, 100, debug=False)
    print(i, np.mean(batch_len), np.min(batch_len), np.max(batch_len))
    train = optimizer.minimize(loss, var_list=[mlp.kernel, mlp.bias])
    train_action_value = optimizer_action_value.minimize(loss_action_value)
    train_state_value = optimizer_state_value.minimize(loss_state_value)

    sess.run([train, train_action_value, train_state_value],feed_dict={
                                    obs_ph: np.array(tmp1),
                                    act_ph: np.array(tmp2),
                                    weights_ph: np.array(tmp3)
                                 })
print('Evaluation')
tmp1, tmp2, tmp3, batch_rets, batch_len = collect_data(sess, 5, debug=True)
print(np.mean(batch_len), np.min(batch_len), np.max(batch_len))

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
0 15.693069306930694 8 45
1 14.554455445544555 8 32
2 14.613861386138614 8 35
3 16.752475247524753 8 60
4 15.534653465346535 8 32
5 15.386138613861386 9 51
6 14.851485148514852 8 31
7 15.653465346534654 8 45
8 15.594059405940595 8 35
9 14.603960396039604 9 37
10 15.267326732673267 8 30
11 15.643564356435643 9 36
12 14.900990099009901 8 37
13 15.603960396039604 8 42
14 15.603960396039604 8 49
15 15.504950495049505 9 36
16 15.693069306930694 8 53
17 14.742574257425742 8 32
18 17.534653465346533 8 43
19 15.445544554455445 8 39
20 14.900990099009901 8 40
21 14.821782178217822 9 28
22 15.861386138613861 9 47
23 15.841584158415841 8 34
24 16.653465346534652 9 61
25 16.663366336633665 9 44
26 16.386138613861387 8 45
27 17.287128712871286 8 53
28 16.702970297029704 8 39
29 16.95049504950495 9 44
30 15.03960396039604 8 40
31 16.97029702970297 8 38
32 15.871287128712872 8 63
33 15.86138613

289 42.67326732673267 13 109
290 44.495049504950494 13 105
291 40.95049504950495 11 167
292 45.51485148514851 18 131
293 45.31683168316832 15 117
294 42.94059405940594 14 109
295 44.01980198019802 17 89
296 42.2970297029703 16 107
297 41.83168316831683 15 107
298 43.722772277227726 16 104
299 43.524752475247524 11 101
300 44.07920792079208 15 104
301 41.76237623762376 17 90
302 41.37623762376238 14 105
303 44.45544554455446 16 139
304 43.07920792079208 15 77
305 42.415841584158414 14 104
306 45.89108910891089 15 106
307 44.86138613861386 14 100
308 44.475247524752476 15 110
309 50.97029702970297 16 178
310 45.42574257425743 18 119
311 45.32673267326733 13 94
312 43.0990099009901 18 117
313 45.79207920792079 15 101
314 48.9009900990099 10 137
315 46.45544554455446 19 161
316 46.148514851485146 14 117
317 44.722772277227726 14 113
318 41.79207920792079 15 197
319 45.31683168316832 16 123
320 46.92079207920792 18 132
321 44.84158415841584 13 137
322 46.04950495049505 19 100
323 47.0198019

571 60.960396039603964 22 136
572 58.78217821782178 20 163
573 67.51485148514851 18 197
574 56.277227722772274 20 135
575 58.65346534653465 13 164
576 58.24752475247525 21 140
577 58.51485148514851 13 140
578 65.22772277227723 22 229
579 53.0990099009901 20 147
580 57.98019801980198 21 154
581 57.742574257425744 13 175
582 57.495049504950494 21 178
583 59.366336633663366 17 153
584 61.613861386138616 22 227
585 65.97029702970298 19 186
586 62.76237623762376 22 201
587 62.554455445544555 24 235
588 59.54455445544554 23 154
589 55.42574257425743 19 146
590 61.93069306930693 26 225
591 62.31683168316832 23 175
592 59.97029702970297 23 155
593 58.584158415841586 20 151
594 55.46534653465346 19 126
595 56.31683168316832 16 139
596 59.06930693069307 16 138
597 61.83168316831683 21 235
598 64.87128712871286 18 143
599 61.504950495049506 17 188
600 62.82178217821782 19 195
601 66.0 22 174
602 64.17821782178218 29 162
603 63.772277227722775 18 179
604 58.64356435643565 22 146
605 63.07920792079