In [1]:
'''
based on :
https://github.com/lbbc1117/ClassicControlDQN
https://github.com/EderSantana/KerasPlaysCatch
'''
%matplotlib inline
import tensorflow as tf
import gym
import numpy as np
import random
import os
import sys
from collections import deque
import matplotlib.pyplot as plt
from skimage import color, transform, exposure

In [2]:
from keras.models import Sequential
from keras.layers import Dense, Convolution2D, Activation
from keras.optimizers import *

Using TensorFlow backend.


In [3]:
# Gym environment id. QAgent reads this module-level constant in its
# constructor to create the environment and to name the weights file.
GAME = 'CartPole-v0'

In [6]:
class QAgent(object):
    """Q-learning agent with experience replay for a Gym environment.

    A small fully connected Keras network maps an observation vector to one
    Q-value per discrete action. Transitions are stored in a bounded replay
    memory and the network is fitted on random minibatches drawn from it.
    """

    def __init__(self):
        """Create the environment, the hyper-parameters and the Q-network."""
        self.game = GAME
        self.out_dir = 'saved-models/' + self.game + '.h5'
        self.env = gym.make(self.game)

        self.max_episodes = 3000
        self.action_dim = self.env.action_space.n
        # Length of the state vector fed to the network (4 for CartPole).
        self.observation_dim = 4
        self.gamma = 0.9            # discount factor of the Bellman target
        self.init_eps = 1.0         # initial exploration rate
        self.final_eps = 1e-5       # epsilon stops decaying below this value
        self.eps_decay = 0.95       # multiplicative decay applied to epsilon
        self.eps_annel_steps = 10   # decay epsilon every N steps of an episode
        # NOTE(review): learning_rate is not passed to the optimizer;
        # build_model() compiles with the default 'adam' settings.
        self.learning_rate = 1e-4
        self.global_cnt = 0         # step counter, reset at every episode

        self.max_experience = 2000  # replay memory capacity
        self.batch_size = 256
        self.test_size = 100        # number of evaluation episodes

        self.memory = deque()
        self.eps = self.init_eps
        self.model = self.build_model()
        # NOTE(review): these two are never read in this class; train_model()
        # stores the latest batch in self.Inputs / self.Targets instead.
        self.INPUTS = []
        self.TARGETS = []

    def build_model(self):
        """Build and compile the Q-network (observation -> Q-value per action)."""
        model = Sequential()
        model.add(Dense(input_dim=self.observation_dim, output_dim=32, init='glorot_uniform'))
        model.add(Activation('relu'))
        model.add(Dense(output_dim=16))
        model.add(Activation('relu'))
        model.add(Dense(output_dim=self.action_dim))
        model.compile(loss='mse', optimizer='adam')
        print("Finished the build model")
        return model

    def remember(self, states):
        """Append one transition to the replay memory, dropping the oldest
        entry once the memory exceeds ``max_experience``.

        ``states`` is the transition ``[state, action, reward, new_state, done]``.
        """
        self.memory.append(states)
        if len(self.memory) > self.max_experience:
            self.memory.popleft()

    def get_minibatch(self, minibatch):
        """Turn sampled transitions into network inputs and Q-learning targets.

        Args:
            minibatch: sequence of ``[state, action, reward, new_state, done]``.

        Returns:
            ``(inputs, targets)``: ``inputs`` stacks the states, one row per
            transition; ``targets`` holds the network's current Q-values for
            each state with the entry of the taken action replaced by
            ``reward`` (terminal) or ``reward + gamma * max_a Q(new_state, a)``.
        """
        if len(minibatch) == 0:
            return (np.empty((0, self.observation_dim)),
                    np.empty((0, self.action_dim)))

        states, actions, rewards, new_states, dones = zip(*minibatch)
        inputs = np.array(states)
        actions = np.asarray(actions, dtype=int)
        rewards = np.asarray(rewards, dtype=float)
        terminal = np.asarray(dones, dtype=bool)

        # Two batched forward passes instead of two per transition.
        # np.array() copies so the prediction buffer is never modified in place.
        targets = np.array(self.model.predict(inputs))
        next_q = np.asarray(self.model.predict(np.array(new_states)))
        max_next_q = np.max(next_q, axis=1)

        # Terminal transitions have no future return to bootstrap from.
        updates = np.where(terminal, rewards, rewards + self.gamma * max_next_q)
        # Only the Q-value of the action actually taken is moved; the other
        # outputs keep their predicted value and so contribute zero error.
        targets[np.arange(len(actions)), actions] = updates
        return inputs, targets

    def do_action(self, state):
        """Pick an epsilon-greedy action, step the environment and decay epsilon.

        Returns:
            ``(new_state, action, reward, done)`` where ``reward`` is replaced
            by -1.0 on the step that ends the episode and 1.0 otherwise.
        """
        # e-greedy policy
        if np.random.rand() <= self.eps:
            action = np.random.randint(0, self.action_dim)
        else:
            qval = self.model.predict(np.array([state]), batch_size=self.batch_size)
            action = np.argmax(qval[0])

        # Decay epsilon value
        if self.eps > self.final_eps and self.global_cnt % self.eps_annel_steps == 0:
            self.eps *= self.eps_decay

        new_state, reward, done, _ = self.env.step(action)
        reward = -1.0 if done else 1.0  # clip reward

        return new_state, action, reward, done

    def _save_checkpoint(self):
        """Write the current weights to ``out_dir``, creating its folder first."""
        out_parent = os.path.dirname(self.out_dir)
        if out_parent and not os.path.isdir(out_parent):
            os.makedirs(out_parent)
        self.model.save_weights(self.out_dir, overwrite=True)

    def train_model(self):
        """Run ``max_episodes`` episodes, learning from replayed minibatches.

        Ctrl-C stops training quietly; any other exception is printed and
        training stops, leaving the model in its current state.
        """
        try:
            for epoch in range(self.max_episodes):
                loss = 0.0
                state = self.env.reset()
                done = False
                self.global_cnt = 0

                while not done:
                    self.global_cnt += 1

                    new_state, action, reward, done = self.do_action(state)
                    # Experience Replay
                    self.remember([state, action, reward, new_state, done])
                    # NOTE(review): the second condition compares the episode
                    # index with the replay capacity, so no fitting happens
                    # during the first max_experience episodes. Kept as in the
                    # original; confirm this warm-up is intended.
                    if len(self.memory) >= self.batch_size and epoch >= self.max_experience:
                        minibatch = random.sample(self.memory, self.batch_size)
                        inputs, targets = self.get_minibatch(minibatch)

                        # Keep the latest batch on the instance for inspection.
                        self.Inputs = inputs
                        self.Targets = targets

                        loss += self.model.train_on_batch(self.Inputs, self.Targets)
                    state = new_state

                print("{} epoch | cnt: {} | Loss: {} | eps {}"
                      .format(epoch + 1, self.global_cnt, loss, self.eps))
                if epoch % 100 == 0 and epoch >= self.max_experience:
                    # Save checkpoint
                    self._save_checkpoint()

        except KeyboardInterrupt:
            # Deliberate: lets the user interrupt a long run from the notebook.
            pass
        except Exception as e:
            print(e)

    def test_model(self):
        """Evaluate the greedy policy over ``test_size`` episodes.

        Loads weights from ``out_dir`` when that file exists, otherwise trains
        first. Prints the average of the environment's own (unclipped) reward
        per episode once all episodes have finished.
        """
        if os.path.exists(self.out_dir):
            self.model.load_weights(self.out_dir)
        else:
            self.train_model()

        total_reward = 0

        for _ in range(self.test_size):
            state = self.env.reset()
            done = False
            self.global_cnt = 0

            while not done:
                self.global_cnt += 1

                qval = self.model.predict(np.array([state]))
                action = np.argmax(qval[0])
                new_state, reward, done, _ = self.env.step(action)
                total_reward += reward
                state = new_state

        # Report once, after every evaluation episode has been played.
        avg_reward = total_reward / float(self.test_size)
        print("Avg. Reward : {}".format(avg_reward))
            

In [8]:
# Build the agent: creates the Gym environment and the Q-network.
game = QAgent()
# Bare expression: evaluated but not displayed, because it is not the last
# line of the cell.
game.action_dim, game.env.observation_space.shape
# Loads saved weights when the weights file exists, otherwise trains first,
# then runs the greedy evaluation episodes.
game.test_model()

[2016-12-27 09:22:24,424] Making new env: CartPole-v0


Finished the build model
1 epoch | cnt: 16 | Loss: 0.0 | eps 0.95
2 epoch | cnt: 32 | Loss: 0.0 | eps 0.81450625
3 epoch | cnt: 46 | Loss: 0.0 | eps 0.663420431289
4 epoch | cnt: 12 | Loss: 0.0 | eps 0.630249409725
5 epoch | cnt: 11 | Loss: 0.0 | eps 0.598736939238
6 epoch | cnt: 12 | Loss: 0.0 | eps 0.568800092276
7 epoch | cnt: 10 | Loss: 0.0 | eps 0.540360087663
8 epoch | cnt: 12 | Loss: 0.0 | eps 0.51334208328
9 epoch | cnt: 9 | Loss: 0.0 | eps 0.51334208328
10 epoch | cnt: 10 | Loss: 0.0 | eps 0.487674979116
11 epoch | cnt: 8 | Loss: 0.0 | eps 0.487674979116
12 epoch | cnt: 12 | Loss: 0.0 | eps 0.46329123016
13 epoch | cnt: 16 | Loss: 0.0 | eps 0.440126668652
14 epoch | cnt: 10 | Loss: 0.0 | eps 0.418120335219
15 epoch | cnt: 10 | Loss: 0.0 | eps 0.397214318458
16 epoch | cnt: 12 | Loss: 0.0 | eps 0.377353602535
17 epoch | cnt: 12 | Loss: 0.0 | eps 0.358485922409
18 epoch | cnt: 12 | Loss: 0.0 | eps 0.340561626288
19 epoch | cnt: 9 | Loss: 0.0 | eps 0.340561626288
20 epoch | cnt: 