In [1]:
import numpy as np
import tensorflow as tf
import random
import dqn
from collections import deque
import gym

In [2]:
env = gym.make('CartPole-v0')
env._max_episode_steps = 10001

In [3]:
#Constants defining our neural network
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

dis = 0.9
REPLAY_MEMORY = 50000

In [4]:
def replay_train(mainDQN, targetDQN, train_batch):
    x_stack = np.empty(0).reshape(0, input_size)
    y_stack = np.empty(0).reshape(0, output_size)
    
    #Get stored information from the buffer
    for state, action, reward, next_state, done in train_batch:
        Q = mainDQN.predict(state)
        
        #terminal?
        if done:
            Q[0, action] = reward
        else:
            #Obtain the Q' values by feeding the new state through our network
            Q[0, action] = reward + dis * np.max(targetDQN.predict(next_state))
        
        y_stack = np.vstack([y_stack, Q])
        x_stack = np.vstack([x_stack, state])
        
        #Train our network using target and predicted Q values on each episode
        return mainDQN.update(x_stack, y_stack)

In [5]:
def get_copy_var_ops(dest_scope_name = "target", src_scope_name = "main"):
    #Copy variables src_scope to dest_scope
    op_holder = []
    
    src_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope_name)
    dest_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=dest_scope_name)
    
    for src_var, dest_var in zip(src_vars, dest_vars):
        op_holder.append(dest_var.assign(src_var.value()))
    
    return op_holder

In [6]:
def bot_play(mainDQN):
    #See our trained network in action
    s = env.reset()
    reward_sum = 0
    while True:
        env.render()
        a = np.argmax(mainDQN.predict(s))
        s, reward, done, _ = env.step(a)
        reward_sum += reward
        if done:
            print("Total score: {}".format(reward_sum))
            break

In [7]:
def main():
    max_episodes = 5000
    
    #store the previous observations in replay memory
    replay_buffer = deque()
    
    with tf.Session() as sess:
        mainDQN = dqn.DQN(sess, input_size, output_size, name="main")
        targetDQN = dqn.DQN(sess, input_size, output_size, name="target")
        tf.global_variables_initializer().run()
        
        #initial copy q_net -> target_net
        copy_ops = get_copy_var_ops(dest_scope_name="target", src_scope_name="main")
        
        sess.run(copy_ops)
        
        for episode in range(max_episodes):
            e = 1. / ((episode / 10) + 1)
            done = False
            step_count = 0
            
            state = env.reset()
            
            while not done:
                if np.random.rand(1) < e:
                    action = env.action_space.sample()
                else:
                    #Choose an action by greedily form the Q-network
                    action = np.argmax(mainDQN.predict(state))
                    
                #Get new state and reward from environment
                next_state, reward, done, _ = env.step(action)
                if done:
                    reward = -100
                
                #Save the experience to our buffer
                replay_buffer.append((state, action, reward, next_state, done))
                if len(replay_buffer) > REPLAY_MEMORY:
                    replay_buffer.popleft()
                
                state = next_state
                step_count += 1
                if step_count > 10000:
                    break
                    
            print("Episode: {}        steps: {}".format(episode, step_count))
            if step_count > 10000:
                pass
                #break
            
            if episode % 10 == 1:
                #Get  a random batch of experiences.
                for _ in range(50):
                    #Minibatch works better
                    minibatch = random.sample(replay_buffer, 10)
                    loss, _ = replay_train(mainDQN, targetDQN, minibatch)
                    print("Loss: ",loss)
        
        bot_play(mainDQN)                    

In [8]:
if __name__ == "__main__":
    main()

Episode: 0        steps: 13
Episode: 1        steps: 17
('Loss: ', 0.46528021)
('Loss: ', 0.53925246)
('Loss: ', 4960.4341)
('Loss: ', 0.47446024)
('Loss: ', 0.030336509)
('Loss: ', 0.29856053)
('Loss: ', 0.73937756)
('Loss: ', 0.93227375)
('Loss: ', 0.38749978)
('Loss: ', 0.10569821)
('Loss: ', 0.047556873)
('Loss: ', 0.0012950462)
('Loss: ', 1.7177192)
('Loss: ', 0.51039028)
('Loss: ', 0.61169785)
('Loss: ', 0.52845019)
('Loss: ', 0.49990529)
('Loss: ', 0.11276781)
('Loss: ', 2.0584664)
('Loss: ', 0.36510217)
('Loss: ', 3.9188857)
('Loss: ', 0.3579866)
('Loss: ', 4.0086384)
('Loss: ', 0.78816509)
('Loss: ', 1.1431792)
('Loss: ', 0.80865306)
('Loss: ', 0.90092742)
('Loss: ', 1.0618119)
('Loss: ', 0.27519664)
('Loss: ', 0.12837166)
('Loss: ', 0.45403719)
('Loss: ', 1.6000298)
('Loss: ', 5.3345704)
('Loss: ', 3.9150434)
('Loss: ', 2.7510431)
('Loss: ', 1.2731987)
('Loss: ', 3.2978559)
('Loss: ', 0.44144937)
('Loss: ', 0.53747255)
('Loss: ', 1.8740286)
('Loss: ', 0.56461972)
('Loss: ', 1

('Loss: ', 4.4024968)
('Loss: ', 0.67893332)
('Loss: ', 0.18687795)
('Loss: ', 0.036447067)
Episode: 52        steps: 12
Episode: 53        steps: 10
Episode: 54        steps: 9
Episode: 55        steps: 10
Episode: 56        steps: 9
Episode: 57        steps: 11
Episode: 58        steps: 9
Episode: 59        steps: 9
Episode: 60        steps: 13
Episode: 61        steps: 11
('Loss: ', 2.6734004)
('Loss: ', 1.7237403)
('Loss: ', 5.7833195)
('Loss: ', 11.347511)
('Loss: ', 6.1293449)
('Loss: ', 11.812668)
('Loss: ', 3.775383)
('Loss: ', 4729.6665)
('Loss: ', 6.8353987)
('Loss: ', 0.31275403)
('Loss: ', 0.5223248)
('Loss: ', 5.3892059)
('Loss: ', 44.165401)
('Loss: ', 7.5694776)
('Loss: ', 6036.4019)
('Loss: ', 1.8661548)
('Loss: ', 14.165524)
('Loss: ', 2.3460581)
('Loss: ', 0.80989456)
('Loss: ', 9.5442238)
('Loss: ', 16.634129)
('Loss: ', 0.33267385)
('Loss: ', 21.590094)
('Loss: ', 20.9498)
('Loss: ', 11.070854)
('Loss: ', 0.37086278)
('Loss: ', 11.181103)
('Loss: ', 3.3471236)
('Los

Episode: 112        steps: 1743
Episode: 113        steps: 139
Episode: 114        steps: 660
Episode: 115        steps: 185
Episode: 116        steps: 337
Episode: 117        steps: 565
Episode: 118        steps: 166
Episode: 119        steps: 619
Episode: 120        steps: 878
Episode: 121        steps: 517
('Loss: ', 28.737015)
('Loss: ', 0.8855347)
('Loss: ', 4.0300679)
('Loss: ', 1.1501297)
('Loss: ', 0.26003352)
('Loss: ', 4.1514449)
('Loss: ', 1.0508929)
('Loss: ', 0.5785026)
('Loss: ', 0.76336622)
('Loss: ', 9.0170288)
('Loss: ', 1.2347795)
('Loss: ', 8.9130481e-05)
('Loss: ', 0.39562297)
('Loss: ', 2.3569379)
('Loss: ', 8.9074125)
('Loss: ', 7.3648248)
('Loss: ', 14.145162)
('Loss: ', 0.51464444)
('Loss: ', 4.0491667)
('Loss: ', 0.96985507)
('Loss: ', 24.36919)
('Loss: ', 10.586269)
('Loss: ', 0.40263414)
('Loss: ', 1.4459136)
('Loss: ', 1.5059735)
('Loss: ', 0.047611997)
('Loss: ', 8.7120447)
('Loss: ', 4556.4009)
('Loss: ', 1.0691772)
('Loss: ', 3.7378244)
('Loss: ', 2.70831

Episode: 172        steps: 10001
Episode: 173        steps: 10001
Episode: 174        steps: 10001
Episode: 175        steps: 10001
Episode: 176        steps: 10001
Episode: 177        steps: 10001
Episode: 178        steps: 10001
Episode: 179        steps: 10001
Episode: 180        steps: 10001
Episode: 181        steps: 10001
('Loss: ', 1.304948e-05)
('Loss: ', 0.44731635)
('Loss: ', 0.011473848)
('Loss: ', 0.080390997)
('Loss: ', 0.010300132)
('Loss: ', 0.42682651)
('Loss: ', 0.20294614)
('Loss: ', 0.032735091)
('Loss: ', 0.58278066)
('Loss: ', 0.42391148)
('Loss: ', 0.54257703)
('Loss: ', 0.38001436)
('Loss: ', 0.0020667431)
('Loss: ', 0.60845864)
('Loss: ', 0.4858717)
('Loss: ', 0.10989796)
('Loss: ', 0.10019439)
('Loss: ', 0.027901338)
('Loss: ', 0.47173315)
('Loss: ', 0.046854556)
('Loss: ', 0.45268095)
('Loss: ', 1.1828974)
('Loss: ', 0.21039571)
('Loss: ', 0.049755428)
('Loss: ', 0.0027254792)
('Loss: ', 0.0027767301)
('Loss: ', 0.1480256)
('Loss: ', 0.010719644)
('Loss: ', 0.

Episode: 232        steps: 10001
Episode: 233        steps: 10001
Episode: 234        steps: 10001
Episode: 235        steps: 10001
Episode: 236        steps: 10001
Episode: 237        steps: 10001
Episode: 238        steps: 10001
Episode: 239        steps: 10001
Episode: 240        steps: 10001
Episode: 241        steps: 10001
('Loss: ', 0.092727654)
('Loss: ', 0.051106848)
('Loss: ', 0.03151308)
('Loss: ', 0.045702908)
('Loss: ', 0.076490298)
('Loss: ', 0.073340468)
('Loss: ', 0.087186582)
('Loss: ', 3.0404992e-05)
('Loss: ', 0.034925763)
('Loss: ', 0.16527161)
('Loss: ', 0.086258143)
('Loss: ', 0.0095638325)
('Loss: ', 0.037622496)
('Loss: ', 0.0021539822)
('Loss: ', 0.017554395)
('Loss: ', 0.001634002)
('Loss: ', 0.0011450971)
('Loss: ', 0.0011415457)
('Loss: ', 0.12462266)
('Loss: ', 0.0016645234)
('Loss: ', 0.0027668867)
('Loss: ', 0.008896376)
('Loss: ', 0.0042992649)
('Loss: ', 0.010759067)
('Loss: ', 0.036460809)
('Loss: ', 0.069316663)
('Loss: ', 0.56367809)
('Loss: ', 0.3082

Episode: 292        steps: 10001
Episode: 293        steps: 10001
Episode: 294        steps: 10001
Episode: 295        steps: 10001
Episode: 296        steps: 10001
Episode: 297        steps: 10001
Episode: 298        steps: 10001
Episode: 299        steps: 10001
Episode: 300        steps: 10001
Episode: 301        steps: 10001
('Loss: ', 0.37895861)
('Loss: ', 0.026622374)
('Loss: ', 0.0090832161)
('Loss: ', 0.024725089)
('Loss: ', 0.0063479836)
('Loss: ', 0.6838944)
('Loss: ', 0.86149555)
('Loss: ', 0.0015280035)
('Loss: ', 0.36270791)
('Loss: ', 0.0019783741)
('Loss: ', 0.0033697479)
('Loss: ', 0.27872157)
('Loss: ', 0.29621083)
('Loss: ', 2.9465562e-05)
('Loss: ', 0.012810487)
('Loss: ', 0.0014971417)
('Loss: ', 0.42357716)
('Loss: ', 0.23491457)
('Loss: ', 0.0085368818)
('Loss: ', 0.0091546476)
('Loss: ', 0.0044867191)
('Loss: ', 0.11777085)
('Loss: ', 0.003332078)
('Loss: ', 0.11310062)
('Loss: ', 0.045163356)
('Loss: ', 0.14160049)
('Loss: ', 0.02946987)
('Loss: ', 0.025017628)


Episode: 352        steps: 10001
Episode: 353        steps: 10001
Episode: 354        steps: 10001
Episode: 355        steps: 10001
Episode: 356        steps: 10001
Episode: 357        steps: 10001
Episode: 358        steps: 10001
Episode: 359        steps: 10001
Episode: 360        steps: 10001
Episode: 361        steps: 10001
('Loss: ', 0.001220058)
('Loss: ', 0.054769717)
('Loss: ', 1.0343516)
('Loss: ', 1.0909022)
('Loss: ', 0.18828717)
('Loss: ', 0.039212558)
('Loss: ', 0.14758785)
('Loss: ', 0.96835959)
('Loss: ', 0.93018305)
('Loss: ', 3.98095e-08)
('Loss: ', 0.05815573)
('Loss: ', 0.039621104)
('Loss: ', 0.019804889)
('Loss: ', 0.0068607149)
('Loss: ', 0.032268584)
('Loss: ', 0.022143433)
('Loss: ', 0.79023993)
('Loss: ', 0.20344596)
('Loss: ', 0.0194344)
('Loss: ', 0.44297647)
('Loss: ', 0.037398737)
('Loss: ', 0.75357431)
('Loss: ', 0.58823442)
('Loss: ', 0.37165126)
('Loss: ', 0.59200114)
('Loss: ', 0.0056137745)
('Loss: ', 0.52746487)
('Loss: ', 0.034330238)
('Loss: ', 0.05

Episode: 412        steps: 10001
Episode: 413        steps: 10001
Episode: 414        steps: 10001
Episode: 415        steps: 10001
Episode: 416        steps: 10001
Episode: 417        steps: 10001
Episode: 418        steps: 10001
Episode: 419        steps: 10001
Episode: 420        steps: 10001
Episode: 421        steps: 10001
('Loss: ', 0.052623522)
('Loss: ', 0.50184166)
('Loss: ', 0.58234984)
('Loss: ', 0.04004389)
('Loss: ', 0.49305081)
('Loss: ', 0.48594892)
('Loss: ', 0.67742485)
('Loss: ', 0.029187175)
('Loss: ', 0.57735169)
('Loss: ', 0.027060192)
('Loss: ', 0.55630958)
('Loss: ', 0.0065815616)
('Loss: ', 0.53130364)
('Loss: ', 0.00056723418)
('Loss: ', 0.47811514)
('Loss: ', 0.0023397792)
('Loss: ', 0.0038416949)
('Loss: ', 0.41117018)
('Loss: ', 0.62835515)
('Loss: ', 0.22248076)
('Loss: ', 0.61020207)
('Loss: ', 0.0050750561)
('Loss: ', 0.6851849)
('Loss: ', 0.031456482)
('Loss: ', 0.58218098)
('Loss: ', 0.45993835)
('Loss: ', 0.21960591)
('Loss: ', 0.418093)
('Loss: ', 0.3

Episode: 472        steps: 766
Episode: 473        steps: 1791
Episode: 474        steps: 764
Episode: 475        steps: 786
Episode: 476        steps: 710
Episode: 477        steps: 505
Episode: 478        steps: 518
Episode: 479        steps: 634
Episode: 480        steps: 1392
Episode: 481        steps: 498
('Loss: ', 0.03415798)
('Loss: ', 0.49913254)
('Loss: ', 0.05037285)
('Loss: ', 0.24945259)
('Loss: ', 0.019741533)
('Loss: ', 0.32094127)
('Loss: ', 0.074836172)
('Loss: ', 0.028335884)
('Loss: ', 0.32032421)
('Loss: ', 0.76931685)
('Loss: ', 0.0086978236)
('Loss: ', 0.019885601)
('Loss: ', 0.24585766)
('Loss: ', 0.27319729)
('Loss: ', 0.6756143)
('Loss: ', 0.0097838165)
('Loss: ', 0.1499231)
('Loss: ', 0.19415052)
('Loss: ', 0.060616337)
('Loss: ', 0.20653629)
('Loss: ', 0.26908624)
('Loss: ', 0.016601702)
('Loss: ', 0.20652588)
('Loss: ', 0.84311652)
('Loss: ', 0.23240973)
('Loss: ', 0.065854572)
('Loss: ', 0.56436181)
('Loss: ', 0.39351025)
('Loss: ', 0.030042678)
('Loss: ', 

Episode: 532        steps: 968
Episode: 533        steps: 982
Episode: 534        steps: 807
Episode: 535        steps: 833
Episode: 536        steps: 410
Episode: 537        steps: 1566
Episode: 538        steps: 946
Episode: 539        steps: 798
Episode: 540        steps: 308
Episode: 541        steps: 806
('Loss: ', 0.00012253829)
('Loss: ', 0.47348553)
('Loss: ', 0.5550229)
('Loss: ', 0.3239916)
('Loss: ', 0.71093899)
('Loss: ', 0.13926774)
('Loss: ', 0.29929519)
('Loss: ', 0.073108055)
('Loss: ', 0.57554471)
('Loss: ', 0.016242526)
('Loss: ', 0.38782275)
('Loss: ', 0.085111998)
('Loss: ', 0.79201478)
('Loss: ', 0.44820991)
('Loss: ', 0.0083311843)
('Loss: ', 0.00075114594)
('Loss: ', 0.13961647)
('Loss: ', 0.83451164)
('Loss: ', 0.09603665)
('Loss: ', 0.0099802967)
('Loss: ', 0.071357474)
('Loss: ', 0.019626215)
('Loss: ', 0.11070352)
('Loss: ', 0.0094102025)
('Loss: ', 0.6638633)
('Loss: ', 0.10045123)
('Loss: ', 0.00051909511)
('Loss: ', 0.056965247)
('Loss: ', 0.094642088)
('L

Episode: 592        steps: 6949
Episode: 593        steps: 10001
Episode: 594        steps: 4595
Episode: 595        steps: 2616
Episode: 596        steps: 10001
Episode: 597        steps: 1109
Episode: 598        steps: 795
Episode: 599        steps: 10001
Episode: 600        steps: 10001
Episode: 601        steps: 1381
('Loss: ', 0.35155085)
('Loss: ', 0.33036092)
('Loss: ', 0.011624894)
('Loss: ', 0.18352494)
('Loss: ', 0.1091932)
('Loss: ', 0.98731464)
('Loss: ', 0.0045122877)
('Loss: ', 0.032326955)
('Loss: ', 0.01373784)
('Loss: ', 0.30039674)
('Loss: ', 0.0038034522)
('Loss: ', 0.65816903)
('Loss: ', 0.00047294734)
('Loss: ', 0.12555207)
('Loss: ', 0.20031492)
('Loss: ', 0.88321048)
('Loss: ', 0.70316511)
('Loss: ', 0.11278941)
('Loss: ', 1.5356293)
('Loss: ', 0.88623971)
('Loss: ', 0.15707332)
('Loss: ', 0.76093012)
('Loss: ', 0.00027083547)
('Loss: ', 0.53081471)
('Loss: ', 0.12299097)
('Loss: ', 0.083135962)
('Loss: ', 0.581011)
('Loss: ', 0.26867336)
('Loss: ', 0.021328153)


Episode: 652        steps: 10001
Episode: 653        steps: 10001
Episode: 654        steps: 10001
Episode: 655        steps: 10001
Episode: 656        steps: 10001
Episode: 657        steps: 10001
Episode: 658        steps: 10001
Episode: 659        steps: 10001
Episode: 660        steps: 10001
Episode: 661        steps: 10001
('Loss: ', 0.29274464)
('Loss: ', 0.26995867)
('Loss: ', 0.59215754)
('Loss: ', 0.77400076)
('Loss: ', 0.051303159)
('Loss: ', 0.068150662)
('Loss: ', 0.98157793)
('Loss: ', 0.58497101)
('Loss: ', 0.0077967425)
('Loss: ', 0.034558751)
('Loss: ', 0.34754255)
('Loss: ', 0.80622303)
('Loss: ', 0.36750171)
('Loss: ', 0.61340654)
('Loss: ', 0.037052192)
('Loss: ', 3.1344531e-05)
('Loss: ', 0.14605494)
('Loss: ', 0.4611508)
('Loss: ', 0.22464789)
('Loss: ', 0.22092305)
('Loss: ', 0.092038535)
('Loss: ', 0.11349434)
('Loss: ', 0.10969207)
('Loss: ', 0.3126069)
('Loss: ', 0.92673731)
('Loss: ', 0.024303723)
('Loss: ', 0.098141544)
('Loss: ', 0.078009337)
('Loss: ', 0.50

KeyboardInterrupt: 