In [1]:
class DQN:
    def __init__(self, session, input_size, output_size, name="main"):
        self.session = session
        self.input_size = input_size
        self.output_size = output_size
        self.net_name = name
        
        self._build_network()
    def _build_network(self, h_size=10, l_rate=1e-1):
        with tf.variable_scope(self.net_name):
            self._X = tf.placeholder(tf.float32, [None, self.input_size], name="input_x")
            W1 = tf.get_variable("W1", shape=[self.input_size, h_size],
                                initializer=tf.contrib.layers.xavier_initializer())
            layer1 = tf.nn.tanh(tf.matmul(self._X, W1))
            W2 = tf.get_variable("W2", shape=[h_size, self.output_size],
                                initializer=tf.contrib.layers.xavier_initializer())
            self._Qpred = tf.matmul(layer1, W2)
            
        self._Y = tf.placeholder(
            shape=[None, self.output_size], dtype=tf.float32)
        self._loss = tf.reduce_mean(tf.square(self._Y - self._Qpred))
        self._train = tf.train.AdamOptimizer(
            learning_rate=l_rate).minimize(self._loss)

    def predict(self, state):
        X = np.reshape(state, [1, self.input_size])
        return self.session.run(self._Qpred, feed_dict={self._X: X})
    
    def update(self, x_stack, y_stack):
        return self.session.run([self._loss, self._train], feed_dict={
            self._X: x_stack, self._Y: y_stack
        })

In [2]:
import numpy as np
import tensorflow as tf
import random
from collections import deque
from gym.envs.registration import register

In [3]:
import gym
env = gym.make('CartPole-v0')
register(
    id='CartPole-v2',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    tags={'wrapper_config.TimeLimit.max_episode_steps': 10000},
    reward_threshold=10000.0,
)

[2017-07-28 17:33:35,375] Making new env: CartPole-v0


In [4]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

In [5]:
dis = 0.9
REPLAY_MEMORY = 50000

In [6]:
def simple_replay_train(DQN, train_batch):
    x_stack = np.empty(0).reshape(0, DQN.input_size)
    y_stack = np.empty(0).reshape(0, DQN.output_size)
    
    for state, action, reward, next_state, done in train_batch:
        Q = DQN.predict(state)
        if done:
            Q[0, action] = reward
        else:
            Q[0, action] = reward + dis * np.max(DQN.predict(next_state))
        y_stack = np.vstack([y_stack, Q])
        x_stack = np.vstack([x_stack, state])
    return DQN.update(x_stack, y_stack)

In [7]:
def bot_play(mainDQN):
    s = env.reset()
    reward_sum = 0
    while True:
        env.render()
        a = np.argmax(mainDQN.predict(s))
        s, reward, done, _ = env.step(a)
        reward_sum += reward
        if done:
            print("Total score: {}".format(reward_sum))
            break

In [8]:
def main():
    max_episodes = 5000
    replay_buffer = deque()
    with tf.Session() as sess:
        mainDQN = DQN(sess, input_size, output_size)
        tf.global_variables_initializer().run()
        
        for episode in range(max_episodes):
            e = 1. / ((episode / 10) + 1)
            done = False
            step_count = 0
            
            state = env.reset()
            while not done:
                if np.random.rand(1) < e:
                    action = env.action_space.sample()
                else:
                    action = np.argmax(mainDQN.predict(state))
                next_state, reward, done, _ = env.step(action)
                if done:
                    reward = -100
                
                replay_buffer.append((state, action, reward, next_state, done))
                if len(replay_buffer) > REPLAY_MEMORY:
                    replay_buffer.popleft()
                
                state = next_state
                step_count += 1
                if step_count > 10000:
                    break
            print("Episode: {} steps: {}".format(episode, step_count))
            if step_count > 10000:
                pass
            if episode % 10 ==1:
                for _ in range(50):
                    minibatch = random.sample(replay_buffer, 10)
                    loss, _ = simple_replay_train(mainDQN, minibatch)
        
        bot_play(mainDQN)

In [9]:
main()

Episode: 0 steps: 15
Episode: 1 steps: 21
Episode: 2 steps: 15
Episode: 3 steps: 16
Episode: 4 steps: 12
Episode: 5 steps: 12
Episode: 6 steps: 10
Episode: 7 steps: 13
Episode: 8 steps: 15
Episode: 9 steps: 14
Episode: 10 steps: 10
Episode: 11 steps: 9
Episode: 12 steps: 16
Episode: 13 steps: 22
Episode: 14 steps: 19
Episode: 15 steps: 18
Episode: 16 steps: 18
Episode: 17 steps: 12
Episode: 18 steps: 43
Episode: 19 steps: 12
Episode: 20 steps: 12
Episode: 21 steps: 13
Episode: 22 steps: 19
Episode: 23 steps: 21
Episode: 24 steps: 23
Episode: 25 steps: 14
Episode: 26 steps: 23
Episode: 27 steps: 13
Episode: 28 steps: 28
Episode: 29 steps: 14
Episode: 30 steps: 34
Episode: 31 steps: 31
Episode: 32 steps: 25
Episode: 33 steps: 25
Episode: 34 steps: 36
Episode: 35 steps: 20
Episode: 36 steps: 30
Episode: 37 steps: 25
Episode: 38 steps: 26
Episode: 39 steps: 22
Episode: 40 steps: 20
Episode: 41 steps: 20
Episode: 42 steps: 32
Episode: 43 steps: 40
Episode: 44 steps: 30
Episode: 45 steps: 37

Episode: 359 steps: 200
Episode: 360 steps: 200
Episode: 361 steps: 128
Episode: 362 steps: 10
Episode: 363 steps: 11
Episode: 364 steps: 10
Episode: 365 steps: 9
Episode: 366 steps: 10
Episode: 367 steps: 10
Episode: 368 steps: 9
Episode: 369 steps: 9
Episode: 370 steps: 10
Episode: 371 steps: 10
Episode: 372 steps: 200
Episode: 373 steps: 200
Episode: 374 steps: 195
Episode: 375 steps: 87
Episode: 376 steps: 200
Episode: 377 steps: 200
Episode: 378 steps: 187
Episode: 379 steps: 200
Episode: 380 steps: 200
Episode: 381 steps: 200
Episode: 382 steps: 200
Episode: 383 steps: 189
Episode: 384 steps: 177
Episode: 385 steps: 130
Episode: 386 steps: 101
Episode: 387 steps: 200
Episode: 388 steps: 200
Episode: 389 steps: 200
Episode: 390 steps: 105
Episode: 391 steps: 139
Episode: 392 steps: 13
Episode: 393 steps: 92
Episode: 394 steps: 31
Episode: 395 steps: 24
Episode: 396 steps: 117
Episode: 397 steps: 12
Episode: 398 steps: 10
Episode: 399 steps: 50
Episode: 400 steps: 200
Episode: 401 

Episode: 722 steps: 172
Episode: 723 steps: 174
Episode: 724 steps: 167
Episode: 725 steps: 200
Episode: 726 steps: 200
Episode: 727 steps: 90
Episode: 728 steps: 197
Episode: 729 steps: 189
Episode: 730 steps: 200
Episode: 731 steps: 199
Episode: 732 steps: 200
Episode: 733 steps: 130
Episode: 734 steps: 44
Episode: 735 steps: 95
Episode: 736 steps: 52
Episode: 737 steps: 102
Episode: 738 steps: 45
Episode: 739 steps: 46
Episode: 740 steps: 70
Episode: 741 steps: 134
Episode: 742 steps: 50
Episode: 743 steps: 80
Episode: 744 steps: 84
Episode: 745 steps: 77
Episode: 746 steps: 163
Episode: 747 steps: 50
Episode: 748 steps: 40
Episode: 749 steps: 121
Episode: 750 steps: 183
Episode: 751 steps: 93
Episode: 752 steps: 45
Episode: 753 steps: 79
Episode: 754 steps: 98
Episode: 755 steps: 71
Episode: 756 steps: 40
Episode: 757 steps: 40
Episode: 758 steps: 51
Episode: 759 steps: 49
Episode: 760 steps: 53
Episode: 761 steps: 64
Episode: 762 steps: 23
Episode: 763 steps: 35
Episode: 764 steps

Episode: 1076 steps: 200
Episode: 1077 steps: 200
Episode: 1078 steps: 74
Episode: 1079 steps: 200
Episode: 1080 steps: 78
Episode: 1081 steps: 134
Episode: 1082 steps: 32
Episode: 1083 steps: 38
Episode: 1084 steps: 36
Episode: 1085 steps: 40
Episode: 1086 steps: 81
Episode: 1087 steps: 103
Episode: 1088 steps: 31
Episode: 1089 steps: 40
Episode: 1090 steps: 34
Episode: 1091 steps: 37
Episode: 1092 steps: 76
Episode: 1093 steps: 46
Episode: 1094 steps: 50
Episode: 1095 steps: 47
Episode: 1096 steps: 83
Episode: 1097 steps: 31
Episode: 1098 steps: 47
Episode: 1099 steps: 33
Episode: 1100 steps: 54
Episode: 1101 steps: 36
Episode: 1102 steps: 20
Episode: 1103 steps: 24
Episode: 1104 steps: 42
Episode: 1105 steps: 36
Episode: 1106 steps: 60
Episode: 1107 steps: 40
Episode: 1108 steps: 23
Episode: 1109 steps: 22
Episode: 1110 steps: 36
Episode: 1111 steps: 52
Episode: 1112 steps: 9
Episode: 1113 steps: 8
Episode: 1114 steps: 8
Episode: 1115 steps: 15
Episode: 1116 steps: 9
Episode: 1117 s

KeyboardInterrupt: 