In [1]:
import numpy as np
import tensorflow as tf
import gym
from collections import namedtuple
from gym.wrappers import Monitor
import scipy.signal
import os

  from ._conv import register_converters as _register_converters


In [18]:
# Environment instance used here only to read the observation / action sizes
# (train() builds its own environment).
env = gym.make('CartPole-v0')
obs_dim, num_actions = env.observation_space.shape[0], env.action_space.n

# Widths of the hidden dense layers; actor_critic passes this list to both the
# policy and the value network.
hidden_layers = [64, 64]

# Training schedule.
epoch = 20              # number of collect-then-update epochs
steps_per_epoch = 4000  # environment steps collected per epoch
train_v = 60            # value-function optimizer steps per epoch (see update_policy)

# Return / advantage estimation.
gamma = 0.99            # discount factor
gae_lambda = 0.97       # GAE lambda

# Bookkeeping.
video_freq = 10         # not referenced by any active code in this notebook
save_freq = 5           # checkpoint every `save_freq` epochs (and after the last one)
checkpointDir = "checkpoint"
monitorDir = "monitor"  # not referenced by any active code in this notebook

obs_dim, num_actions

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


(4, 2)

In [3]:
def policy_estimator(x,action,hidden_layers,num_actions,output_activation,activation):
    """Build a categorical policy head over `x`.

    The network is `hidden_layers` dense layers using `activation`, followed by a
    `num_actions`-unit output layer using `output_activation` (the logits).

    Returns:
        pi_a:    action sampled from the policy (one per row of `x`).
        prob_pi: log-probability of `pi_a` under the policy.
        prob_a:  log-probability of the supplied `action` tensor.

    Note: despite the `prob_` prefix, both values are *log*-probabilities
    (they are taken from `tf.nn.log_softmax`).
    """
    hidden = x
    for width in hidden_layers:
        hidden = tf.layers.dense(hidden, units=width, activation=activation)
    logits = tf.layers.dense(hidden, units=num_actions, activation=output_activation)
    log_probs = tf.nn.log_softmax(logits)

    def select(a):
        # Mask out every column except action `a`, leaving its log-probability per row.
        return tf.reduce_sum(tf.one_hot(a, depth=num_actions) * log_probs, axis=1)

    pi_a = tf.squeeze(tf.multinomial(logits, 1), axis=1)
    prob_a = select(action)
    prob_pi = select(pi_a)
    return pi_a, prob_pi, prob_a

In [4]:
def value_estimator(x,hidden_layers,output_activation=None,activation=tf.tanh):
    """Build a state-value head: an MLP over `x` ending in a single output unit.

    The trailing unit axis is squeezed away, so each row of `x` maps to one scalar.
    """
    net = x
    for width in hidden_layers:
        net = tf.layers.dense(net, units=width, activation=activation)
    v_col = tf.layers.dense(net, units=1, activation=output_activation)  # (N, 1) for a rank-2 input
    return tf.squeeze(v_col, axis=1)

In [5]:
def actor_critic(x,act,hidden_layers,num_actions,output_activation=None,activation=tf.tanh):
    """Build the policy network and the value network on the same input `x`.

    Args:
        x: observation tensor fed to both networks.
        act: action tensor whose log-probability the policy reports (`prob_a`).
        hidden_layers: widths of the hidden dense layers (used by both networks).
        num_actions: number of discrete actions.
        output_activation: activation of the policy's logits layer.
        activation: activation of every hidden layer in both networks.

    Returns:
        (pi_a, prob_pi, prob_a, v): the outputs of `policy_estimator` followed by
        the value estimate from `value_estimator`.
    """
    # Previously both keyword arguments were ignored (hard-coded None / tf.tanh);
    # they are now forwarded. Defaults are unchanged, so existing calls behave the same.
    # The policy is built before the value net so auto-generated variable names
    # (and therefore existing checkpoints) stay the same.
    pi_a,prob_pi,prob_a = policy_estimator(x,act,hidden_layers,num_actions,
                                           output_activation=output_activation,
                                           activation=activation)
    # The value head must stay unbounded, so its output layer is always linear.
    v = value_estimator(x,hidden_layers,output_activation=None,activation=activation)
    return pi_a,prob_pi,prob_a,v

In [6]:
# Graph inputs: a batch of observations, the actions taken, reward-to-go
# targets for the value head, and advantages for the policy gradient.
x = tf.placeholder(dtype=tf.float32, shape=(None, obs_dim), name="observations")
actions = tf.placeholder(dtype=tf.int32, shape=(None,), name="actions")
ret = tf.placeholder(dtype=tf.float32, shape=(None,), name="ret")
advantages = tf.placeholder(dtype=tf.float32, shape=(None,), name="advs")

pi_a, prob_pi, prob_a, value = actor_critic(x, actions, hidden_layers, num_actions)

# Policy-gradient surrogate loss. prob_a is log pi(a|s) (log_softmax in
# policy_estimator), so minimizing this ascends E[log pi * advantage].
policy_loss = -tf.reduce_mean(prob_a * advantages)
# Mean squared error between the value head and the reward-to-go targets.
value_loss = tf.reduce_mean((ret - value)**2)

In [7]:
# Two independent Adam optimizers. update_policy runs one policy step and
# `train_v` value steps per epoch.
optimize_policy = tf.train.AdamOptimizer(
    learning_rate=3e-4,
).minimize(policy_loss)
optimize_value = tf.train.AdamOptimizer(
    learning_rate=1e-3,
).minimize(value_loss)

In [8]:
def calculate_advantage(rews,values,final_value,discount=None,lam=None):
    """Generalized Advantage Estimation (GAE) for one trajectory segment.

    Args:
        rews: rewards r_0..r_{T-1}, where r_t is the reward for the action taken at step t.
        values: value estimates V(s_0)..V(s_{T-1}), same length as `rews`.
        final_value: bootstrap value for the state after the last step
            (0 for a terminal state, V(s_T) for a segment cut off early).
        discount: discount factor; defaults to the notebook-level `gamma`.
        lam: GAE lambda; defaults to the notebook-level `gae_lambda`.

    Returns:
        1-D array of length T with A_t = sum_k (discount*lam)^k * delta_{t+k},
        where delta_t = r_t + discount * V(s_{t+1}) - V(s_t).
    """
    # Looked up at call time so later edits to the config cell are honoured.
    discount = gamma if discount is None else discount
    lam = gae_lambda if lam is None else lam

    rews = np.asarray(rews, dtype=np.float64).ravel()
    values_ext = np.append(values, final_value)  # V(s_0) .. V(s_T)
    deltas = rews + discount * values_ext[1:] - values_ext[:-1]
    if deltas.size == 0:
        return deltas
    # Reverse-time IIR filter == discounted cumulative sum of deltas.
    return scipy.signal.lfilter([1], [1, float(-discount * lam)], deltas[::-1], axis=0)[::-1]
    

In [9]:
def rewards_to_go(rews,final_value,discount=None):
    """Discounted reward-to-go for one trajectory segment.

    Args:
        rews: rewards r_0..r_{T-1}, where r_t is the reward for the action taken at step t.
        final_value: bootstrap value appended after the last reward
            (0 for a terminal state, V(s_T) for a segment cut off early).
        discount: discount factor; defaults to the notebook-level `gamma`.

    Returns:
        1-D array of length T with
        R_t = sum_{k=t}^{T-1} discount^(k-t) * r_k + discount^(T-t) * final_value.
    """
    # Looked up at call time so later edits to the config cell are honoured.
    discount = gamma if discount is None else discount
    rews_ext = np.append(rews, final_value)
    # Discounted cumulative sum in reverse time; drop the bootstrap entry itself.
    return scipy.signal.lfilter([1], [1, float(-discount)], rews_ext[::-1], axis=0)[::-1][:-1]

In [13]:
def update_policy(sess,obs_memory,action_memory,rtgs_memory,adv_memory):
    """Run one policy-gradient step and `train_v` value steps on a full epoch batch.

    Returns:
        (policy_loss, value_loss) evaluated on the batch *before* any update.
    """
    batch = {
        x: obs_memory,
        actions: action_memory,
        ret: rtgs_memory,
        advantages: adv_memory,
    }
    losses_before = sess.run([policy_loss, value_loss], feed_dict=batch)

    sess.run(optimize_policy, feed_dict=batch)
    for _ in range(train_v):
        sess.run(optimize_value, feed_dict=batch)

    policy_loss_e, value_loss_e = losses_before
    return policy_loss_e, value_loss_e
    

In [14]:
def train():
    """Collect experience with the current policy and train the actor-critic.

    For each of `epoch` epochs:
      1. run the policy for `steps_per_epoch` environment steps, storing
         (observation, action, reward, value estimate) for every step;
      2. whenever an episode ends (or the epoch's step budget runs out), compute
         GAE advantages and reward-to-go targets for that episode's slice of the
         buffers;
      3. take one policy step and `train_v` value steps via `update_policy`;
      4. checkpoint every `save_freq` epochs and after the final epoch;
      5. print per-episode returns/lengths and the pre-update losses.

    If `checkpointDir` already holds a checkpoint it is restored first, so repeated
    calls continue from the last saved weights. The session and environment are
    closed on exit.
    """
    env = gym.make('CartPole-v0')
    os.makedirs(checkpointDir, exist_ok=True)
    checkpoint = os.path.join(checkpointDir, "model")

    # Optional video recording (currently disabled):
    # os.makedirs(monitorDir, exist_ok=True)
    # env = Monitor(env, directory=os.path.join(monitorDir, "game"),
    #               video_callable=lambda ep: ep % video_freq == 0, resume=True)

    # Per-epoch rollout buffers; index t holds the t-th environment step of the epoch.
    obs_memory = np.zeros((steps_per_epoch, obs_dim), dtype=np.float32)
    action_memory = np.zeros(steps_per_epoch, dtype=np.int32)
    rew_memory = np.zeros(steps_per_epoch, dtype=np.float32)
    value_memory = np.zeros(steps_per_epoch, dtype=np.float32)
    adv_memory = np.zeros(steps_per_epoch, dtype=np.float32)
    rtgs_memory = np.zeros(steps_per_epoch, dtype=np.float32)

    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(checkpointDir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("Existing checkpoint {} restored...".format(ckpt))

        obs, total_rew, episode_length = env.reset(), 0, 0
        episode_stats = []

        for e in range(epoch):
            path_start = 0  # buffer index where the current episode began
            for t in range(steps_per_epoch):
                pi_a_t, value_t = sess.run([pi_a, value], feed_dict={x: obs.reshape(1, -1)})
                act = int(pi_a_t[0])
                next_obs, rew, done, _ = env.step(act)

                # Store (s_t, a_t, r_t, V(s_t)), where r_t is the reward produced by
                # a_t itself. calculate_advantage / rewards_to_go expect this layout.
                # The previous code stored the *previous* step's reward here.
                obs_memory[t] = obs
                action_memory[t] = act
                rew_memory[t] = rew
                value_memory[t] = value_t[0]

                obs = next_obs
                total_rew += rew
                episode_length += 1

                if done or t == steps_per_epoch - 1:
                    if done:
                        # Terminal state: no return follows, bootstrap with 0.
                        last_val = 0.0
                    else:
                        # Episode cut off by the epoch budget: bootstrap with V(s_T).
                        print("Alert:Final episode terminated without completion...")
                        last_val = float(sess.run(value, feed_dict={x: obs.reshape(1, -1)})[0])

                    # Inclusive of step t. The old [start:t] slice dropped each
                    # episode's last step and leaked it into the next episode.
                    path = slice(path_start, t + 1)
                    adv_memory[path] = calculate_advantage(rew_memory[path], value_memory[path], last_val)
                    rtgs_memory[path] = rewards_to_go(rew_memory[path], last_val)
                    path_start = t + 1

                    episode_stats.append((total_rew, episode_length))
                    obs, total_rew, episode_length = env.reset(), 0, 0

            policy_loss_e, value_loss_e = update_policy(sess, obs_memory, action_memory, rtgs_memory, adv_memory)

            # Save *after* the update so the final epoch's update is not lost.
            if e % save_freq == 0 or e == epoch - 1:
                saver.save(sess, checkpoint)

            print("Epoch Episode No. Total rewards     episode Length")
            for i, (ep_rew, ep_len) in enumerate(episode_stats):
                print(f"{e}       {i}           {ep_rew}                 {ep_len}")
            print(f"Policy Loss : {policy_loss_e} Value Loss : {value_loss_e}")
            episode_stats = []
    finally:
        sess.close()
        env.close()

In [19]:
# Restores the latest checkpoint from `checkpointDir` (if one exists) and trains for `epoch` epochs.
train()

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
INFO:tensorflow:Restoring parameters from checkpoint/model
Existing checkpoint checkpoint/model restored...
Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
0       0           33.0                 33
0       1           74.0                 74
0       2           22.0                 22
0       3           28.0                 28
0       4           18.0                 18
0       5           24.0                 24
0       6           52.0                 52
0       7           10.0                 10
0       8           29.0                 29
0       9           28.0                 28
0       10           37.0                 37
0       11           24.0                 24
0       12           13.0                 13
0       13           59.0                 59
0       14           24.0                 24
0       15     

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
2       0           56.0                 56
2       1           38.0                 38
2       2           48.0                 48
2       3           48.0                 48
2       4           13.0                 13
2       5           31.0                 31
2       6           15.0                 15
2       7           27.0                 27
2       8           29.0                 29
2       9           23.0                 23
2       10           30.0                 30
2       11           20.0                 20
2       12           33.0                 33
2       13           36.0                 36
2       14           16.0                 16
2       15           16.0                 16
2       16           29.0                 29
2       17           23.0                 23
2       18           67.0                 67
2       19           22.0                 22
2     

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
4       0           15.0                 15
4       1           23.0                 23
4       2           37.0                 37
4       3           27.0                 27
4       4           14.0                 14
4       5           22.0                 22
4       6           37.0                 37
4       7           18.0                 18
4       8           40.0                 40
4       9           16.0                 16
4       10           14.0                 14
4       11           20.0                 20
4       12           50.0                 50
4       13           32.0                 32
4       14           44.0                 44
4       15           33.0                 33
4       16           64.0                 64
4       17           31.0                 31
4       18           23.0                 23
4       19           49.0                 49
4     

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
6       0           18.0                 18
6       1           79.0                 79
6       2           18.0                 18
6       3           79.0                 79
6       4           21.0                 21
6       5           87.0                 87
6       6           17.0                 17
6       7           26.0                 26
6       8           33.0                 33
6       9           17.0                 17
6       10           18.0                 18
6       11           36.0                 36
6       12           25.0                 25
6       13           21.0                 21
6       14           36.0                 36
6       15           27.0                 27
6       16           52.0                 52
6       17           31.0                 31
6       18           49.0                 49
6       19           22.0                 22
6     

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
8       0           33.0                 33
8       1           15.0                 15
8       2           40.0                 40
8       3           47.0                 47
8       4           19.0                 19
8       5           31.0                 31
8       6           14.0                 14
8       7           19.0                 19
8       8           16.0                 16
8       9           13.0                 13
8       10           47.0                 47
8       11           32.0                 32
8       12           49.0                 49
8       13           21.0                 21
8       14           79.0                 79
8       15           20.0                 20
8       16           16.0                 16
8       17           19.0                 19
8       18           34.0                 34
8       19           34.0                 34
8     

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
10       0           45.0                 45
10       1           62.0                 62
10       2           32.0                 32
10       3           18.0                 18
10       4           27.0                 27
10       5           52.0                 52
10       6           64.0                 64
10       7           21.0                 21
10       8           18.0                 18
10       9           26.0                 26
10       10           59.0                 59
10       11           25.0                 25
10       12           53.0                 53
10       13           20.0                 20
10       14           20.0                 20
10       15           54.0                 54
10       16           21.0                 21
10       17           44.0                 44
10       18           31.0                 31
10       19           18.0      

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
12       0           25.0                 25
12       1           55.0                 55
12       2           16.0                 16
12       3           34.0                 34
12       4           67.0                 67
12       5           27.0                 27
12       6           60.0                 60
12       7           19.0                 19
12       8           63.0                 63
12       9           21.0                 21
12       10           32.0                 32
12       11           29.0                 29
12       12           26.0                 26
12       13           53.0                 53
12       14           54.0                 54
12       15           31.0                 31
12       16           69.0                 69
12       17           12.0                 12
12       18           12.0                 12
12       19           75.0      

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
14       0           48.0                 48
14       1           39.0                 39
14       2           64.0                 64
14       3           36.0                 36
14       4           29.0                 29
14       5           50.0                 50
14       6           58.0                 58
14       7           38.0                 38
14       8           61.0                 61
14       9           70.0                 70
14       10           28.0                 28
14       11           37.0                 37
14       12           19.0                 19
14       13           34.0                 34
14       14           70.0                 70
14       15           49.0                 49
14       16           40.0                 40
14       17           82.0                 82
14       18           29.0                 29
14       19           41.0      

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
16       0           82.0                 82
16       1           24.0                 24
16       2           30.0                 30
16       3           39.0                 39
16       4           14.0                 14
16       5           36.0                 36
16       6           106.0                 106
16       7           25.0                 25
16       8           75.0                 75
16       9           63.0                 63
16       10           72.0                 72
16       11           24.0                 24
16       12           35.0                 35
16       13           24.0                 24
16       14           33.0                 33
16       15           37.0                 37
16       16           29.0                 29
16       17           47.0                 47
16       18           23.0                 23
16       19           46.0    

Alert:Final episode terminated without completion...
Epoch Episode No. Total rewards     episode Length
18       0           11.0                 11
18       1           55.0                 55
18       2           47.0                 47
18       3           81.0                 81
18       4           64.0                 64
18       5           51.0                 51
18       6           81.0                 81
18       7           41.0                 41
18       8           26.0                 26
18       9           103.0                 103
18       10           30.0                 30
18       11           28.0                 28
18       12           69.0                 69
18       13           37.0                 37
18       14           87.0                 87
18       15           23.0                 23
18       16           36.0                 36
18       17           17.0                 17
18       18           102.0                 102
18       19           35.0  