In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import deque
import random
import sys

sys.path.append('../../')
sys.path.append('game/')

import numpy as np

from base_network import BaseNetwork
from keras.models import load_model
from snake_env import SnakeEnvironment

EP = 100000

class Config:
    height = 20
    width = 30
    action_num = 5
    lr = 3e-5
    eps = 1.
    eps_min = 0.1
    eps_decay = 0.999
    gamma = 0.95

class DQAgent():
    def __init__(self, config):
        self.config = config
        self.memory = deque(maxlen=20000)

        # model
        self.q = BaseNetwork(config.height, config.width, config.action_num, config) 
        # target model
        self.qt = BaseNetwork(config.height, config.width, config.action_num, config)


    def act(self, state):
        if random.uniform(0, 1) < self.config.eps:
            return random.randrange(self.config.action_num)
        else:
            print('predict')
            return np.argmax(self.q.model.predict(self.rescale_color(state))[0])

    def replay(self, batch_size):
        print('do_replay')
        if batch_size > len(self.memory):
            return
        minibatch = random.sample(self.memory, batch_size)
        states = []
        targets = []
        for state, action, reward, next_state, done in minibatch:
            #print(state, action, reward, next_state, done)
            inpu = self.rescale_color(state)
            target = self.q.model.predict(inpu)
            if done:
                target[0][action] = reward
            else:
                inp = self.rescale_color(next_state)
                a = self.q.model.predict(inp)[0]
                t = self.qt.model.predict(inp)[0]
                target[0][action] = reward + self.config.gamma * t[np.argmax(a)]
            self.q.model.fit(inpu, target, epochs=1, verbose=0)
#             states.append(inpu)
#             targets.append(targets)
        print('set target')
#         self.q.model.fit(np.array(states), np.array(targets), epochs=1, verbose=0)
        print('finish replay')
        if self.config.epsilon > self.config.epsilon_min:
            self.config.epsilon *= self.config.epsilon_decay
        

    def rescale_color(self, state):
        return state / 3.
        

    def copy_param(self):
        self.qt.model.set_weights(self.q.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def save_model(self, path):
        self.q.model.save(path)
    
    def load_model(self, path):
        self.q.model = load_model(path)

def train():
    config = Config()
    env = SnakeEnvironment(config.width, config.height)
    agent = DQAgent(config)

    done = False
    batch_size = 32

    agent.copy_param()

    for i in range(EP):
        state = env.reset()
        
        state = np.expand_dims(state, axis=0)
        state = np.stack([state for _ in range(2)], axis=-1)
        #print(state.shape)
        done = False
        point = 0
        for t in range(5000):
            action = agent.act(state)
            tmp_state, reward, done = env.act(action)

            next_state = np.roll(state, -1, axis=-1)
            next_state[:,:,:,-1] = tmp_state

            point += reward

            agent.remember(state, action, reward, next_state, done)
            state = next_state.copy()
            if done:
                print('episode {}/{}, score: {}'.format(i, EP, point))

                # Print counted frames
                print(t)
                agent.copy_param()
                break
            
        agent.replay(batch_size)
        if i % 100 == 0:
            agent.save_model('model/snake.h5')

In [None]:
train()

episode 0/100000, score: -5
29
do_replay
episode 1/100000, score: -5
28
do_replay


In [6]:
config = Config()
env = SnakeEnvironment(config.width, config.height)
agent = DQAgent(config)

In [7]:
agent.q.model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_5 (Conv2D)            (None, 16, 26, 16)        816       
_________________________________________________________________
activation_7 (Activation)    (None, 16, 26, 16)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 24, 32)        4640      
_________________________________________________________________
activation_8 (Activation)    (None, 14, 24, 32)        0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 32)                0         
_________________________________________________________________
dense_5 (Dense)              (None, 128)               4224      
_________________________________________________________________
activation_9 (Activation)    (None, 128)               0         
__________

In [None]:
done = False
batch_size = 32

agent.copy_param()

for i in range(EP):
    state = env.reset()

    state = np.expand_dims(state, axis=0)
    state = np.stack([state for _ in range(2)], axis=-1)
    #print(state.shape)
    done = False
    point = 0
    for t in range(5000):
        action = agent.act(state)
        tmp_state, reward, done = env.act(action)

        next_state = np.roll(state, -1, axis=-1)
        next_state[:,:,:,-1] = tmp_state

        point += reward

        agent.remember(state, action, reward, next_state, done)
        state = next_state.copy()
        if done:
            print('episode {}/{}, score: {}'.format(i, EP, point))

            # Print counted frames
            print(t)
            agent.copy_param()
            break

    agent.replay(batch_size)
    if i % 100 == 0:
        agent.save_model('model/snake.h5')

episode 0/100000, score: -5
14
do_replay
episode 1/100000, score: -5
22
do_replay
