In [1]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

EPISODES = 1000

Using TensorFlow backend.


In [2]:
class DQNAgent:
    def __init__(self, state_shape, action_shape):
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.memory = deque(maxlen = 2000)
        self.gamma = 0.9
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()
        print("agent input shape: {} action shape: {}".format(self.state_shape, self.action_shape))
    
    def _build_model(self):
        model = Sequential()
        model.add(Dense(256, input_shape=self.state_shape, activation='relu'))
        model.add(Dense(256, activation='relu'))
        model.add(Dense(self.action_shape, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        return model
    
    def avg_q(self):
        samples = random.sample(self.memory, min(200, len(self.memory)))
        return sum([np.amax(self.predict(next_state)) for _, _, _, next_state, _ in samples])
    
    def predict(self, state):
        return self.model.predict(state[None,:])[0]
    
    def fit(self, state, target):
        self.model.fit(state[None,:], target[None,:], epochs=1, verbose=0)

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            action = random.randrange(self.action_shape)
        else:
            action = np.argmax(self.predict(state))
        return action
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
    
    def replay(self, batch_size):
        if len(self.memory) < batch_size: return
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = (reward + self.gamma * np.amax(self.predict(next_state)))
            target_f = self.predict(state)
            target_f[action] = target
            self.fit(state, target_f)
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

In [3]:
env = gym.make('CartPole-v1')
agent = DQNAgent(env.observation_space.shape, env.action_space.n)
batch_size = 32

for en in range(EPISODES):
    state = env.reset()
    time = 0
    done = False
    score = 0
    while not done:
        time += 1
        #env.render()
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        #reward = reward if not done else -10
        score += reward
        agent.remember(state, action, reward, next_state, done)
        state = next_state
        agent.replay(batch_size)
    print("episode: {}/{}, time: {}, score: {}, e: {:.2} avgQ: {:12.0}".format(en, EPISODES, time, score, agent.epsilon, agent.avg_q()))

env.close()

[2017-09-13 21:50:30,716] Making new env: CartPole-v1


agent input shape: (4,) action shape: 2
episode: 0/1000, time: 18, score: 18.0, e: 1.0 avgQ:        1e+00
episode: 1/1000, time: 31, score: 31.0, e: 0.056 avgQ:        7e+05
episode: 2/1000, time: 8, score: 8.0, e: 0.015 avgQ:        8e+05
episode: 3/1000, time: 9, score: 9.0, e: 0.01 avgQ:        8e+05
episode: 4/1000, time: 9, score: 9.0, e: 0.01 avgQ:        8e+05
episode: 5/1000, time: 10, score: 10.0, e: 0.01 avgQ:        7e+05
episode: 6/1000, time: 9, score: 9.0, e: 0.01 avgQ:        6e+05
episode: 7/1000, time: 8, score: 8.0, e: 0.01 avgQ:        5e+05
episode: 8/1000, time: 9, score: 9.0, e: 0.01 avgQ:        4e+05
episode: 9/1000, time: 12, score: 12.0, e: 0.01 avgQ:        3e+05
episode: 10/1000, time: 8, score: 8.0, e: 0.01 avgQ:        3e+05
episode: 11/1000, time: 10, score: 10.0, e: 0.01 avgQ:        2e+05
episode: 12/1000, time: 11, score: 11.0, e: 0.01 avgQ:        1e+05
episode: 13/1000, time: 39, score: 39.0, e: 0.01 avgQ:        2e+04
episode: 14/1000, time: 14, sco

KeyboardInterrupt: 