<a href="https://colab.research.google.com/github/enakai00/rl_book_solutions/blob/master/DQN/walk_game_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab magic: select the TensorFlow 2.x runtime. Run this before the cell
# that imports TensorFlow.
%tensorflow_version 2.x 

TensorFlow 2.x selected.


In [0]:
import numpy as np
from time import sleep
import random
import copy

import tensorflow as tf
from tensorflow.keras import layers, models, initializers

from IPython.display import clear_output

In [0]:
# ASCII template of the empty play field: '#' cells form the surrounding wall
# and ' ' cells are free. get_field() converts this text into a 2-D array of
# single characters.
field_img = '''
##############
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
##############
'''
def get_field(field_img):
  """Convert a multi-line ASCII picture into a 2-D array of characters.

  Empty lines are dropped. Every remaining line is right-padded with spaces
  to the length of the longest line so that all rows have the same width.

  Args:
    field_img: The field as text, one row per line.

  Returns:
    A numpy array holding one single-character string per cell.
  """
  raw_lines = field_img.split('\n')
  width = max(len(raw) for raw in raw_lines)
  # No line is longer than `width`, so ljust() yields exactly `width` cells.
  rows = [list(raw.ljust(width)) for raw in raw_lines if raw != '']
  return np.array(rows)

In [0]:
class Environ:
  """Walk game on a walled grid, bundled with the Q-network that plays it.

  The walker leaves a '+' trail behind it. A move fails when the walker steps
  onto any cell that is not empty (wall '#', obstacle 'x' or its own trail).

  Attributes:
    actions: List of (dx, dy) displacements; an action is an index into it.
    model: Keras model mapping (state, one-hot action) to a scalar Q-value.
    field: 2-D numpy array of single characters describing the board.
    x, y: Column and row of the walker inside `field`.
    length: Step counter; starts at 1 and grows by one on every move().
    experience_memory: List of (pre_state, action, reward, post_state) tuples.
  """

  # Number of obstacle placements per new field (placements may coincide).
  NUM_OBSTACLES = 10
  # Once the experience memory grows beyond this size, one old entry is dropped.
  MEMORY_LIMIT = 10000

  def __init__(self):
    """Builds a fresh Q-network, a random field and an empty memory."""
    self.actions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    self.model = self.build_model()
    self.restart()
    self.experience_memory = []


  def restart(self):
    """Starts a new episode: new field, random obstacles, random start cell."""
    self.field = get_field(field_img)
    rows, cols = self.field.shape
    for _ in range(self.NUM_OBSTACLES):
      y = random.randint(1, rows - 2)
      x = random.randint(1, cols - 2)
      self.field[y][x] = 'x'
    # x indexes columns and y indexes rows (the field is accessed as
    # field[y][x]), so x must be bounded by the width and y by the height.
    self.x = random.randint(1, cols - 2)
    self.y = random.randint(1, rows - 2)
    self.length = 1


  def build_model(self):
    """Builds and compiles the Q-network.

    The network takes a (height, width, 2) state image and a one-hot action
    vector and outputs a single Q-value for that state/action pair.

    Returns:
      The compiled Keras model (Adam optimizer, mean squared error loss).
    """
    # The input size follows the field template so it matches get_state().
    height, width = get_field(field_img).shape
    cnn_input = layers.Input(shape=(height, width, 2), name='cnn_input')
    cnn = layers.Conv2D(8, (5, 5), padding='same',
                        kernel_initializer=initializers.TruncatedNormal(),
                        use_bias=True, activation='relu',
                        name='conv_filter')(cnn_input)
    cnn = layers.Flatten(name='flatten')(cnn)
    action_input = layers.Input(shape=(len(self.actions),), name='action_input')

    combined = layers.concatenate([cnn, action_input], name='combined')
    combined = layers.Dense(2048, activation='relu',
                        kernel_initializer=initializers.TruncatedNormal(),
                        name='hidden1')(combined)
    combined = layers.Dense(1024, activation='relu',
                        kernel_initializer=initializers.TruncatedNormal(),
                        name='hidden2')(combined)
    # NOTE(review): the ReLU output restricts predictions to values >= 0.
    # trial() only hands out rewards of 0.0 or 1.0, so targets are never
    # negative, but confirm that a linear output was not intended.
    q_value = layers.Dense(1, activation='relu', name='q_value')(combined)

    model = models.Model(inputs=[cnn_input, action_input], outputs=q_value)
    # The keyword is `optimizer`; the previous `optimization=` never
    # selected Adam.
    model.compile(optimizer='adam', loss='mse')
    return model


  def _one_hot(self, action):
    """Returns the one-hot vector that encodes the action index `action`."""
    onehot = np.zeros(len(self.actions))
    onehot[action] = 1
    return onehot


  def add_experience(self, pre_state, action, reward, post_state):
    """Stores one transition and prunes the memory when it grows too large.

    Args:
      pre_state: State (as returned by get_state()) before the move.
      action: Index of the action taken.
      reward: Reward received for the move.
      post_state: State after the move, or None for a terminal transition.
    """
    self.experience_memory.append((pre_state, action, reward, post_state))
    if len(self.experience_memory) > self.MEMORY_LIMIT:
      # Drop a random entry from the older half of the memory.
      i = random.randint(0, len(self.experience_memory) // 2)
      self.experience_memory.pop(i)


  def train_model(self, samples=300):
    """Fits the Q-network on recent plus randomly sampled experiences.

    Does nothing while fewer than `samples` experiences are stored. The
    training target of a transition is reward + max_a Q(post_state, a), with
    the second term being 0 for terminal transitions.

    Args:
      samples: Number of most recent experiences to use; the same number of
        experiences is additionally drawn at random from the whole memory.
    """
    if len(self.experience_memory) < samples:
      return
    # The slice is a new list, so extending it leaves the memory untouched.
    examples = self.experience_memory[-samples:]
    examples += random.sample(self.experience_memory, samples)
    random.shuffle(examples)

    states, actions, labels = [], [], []
    for pre_state, action, reward, post_state in examples:
      states.append(pre_state)
      actions.append(self._one_hot(action))
      if post_state is None:  # Terminal state
        next_q = 0
      else:
        _, next_q = self.get_optimal_action(post_state)
      labels.append(reward + next_q)
    # Pass explicit arrays: each model input is one batch-shaped ndarray.
    self.model.fit([np.array(states), np.array(actions)], np.array(labels),
                    batch_size=50, epochs=100)


  def get_optimal_action(self, state=None):
    """Returns the action with the highest predicted Q-value.

    Args:
      state: State to evaluate; the current state is used when None.

    Returns:
      A tuple (action index, predicted Q-value of that action).
    """
    if state is None:
      state = self.get_state()
    num_actions = len(self.actions)
    # One copy of the state per action, paired with the rows of the identity
    # matrix, which are exactly the one-hot encodings of actions 0..n-1.
    states = np.array([state] * num_actions)
    actions = np.eye(num_actions)

    q_values = self.model.predict([states, actions])
    optimal_action = np.argmax(q_values)
    return optimal_action, q_values[optimal_action][0]


  def get_optimal_action_with_search(self):
    """Chooses an action by a one-step lookahead.

    Each action is tried on the real environment, scored with the best
    predicted Q-value of the resulting state (or -10**10 when the move
    fails), and then undone.

    Returns:
      Index of the action with the highest score.
    """
    q_values = []
    for action in range(len(self.actions)):
      saved_field = self.field.copy()
      saved_x, saved_y = self.x, self.y
      saved_length = self.length
      try:
        if self.move(action):
          _, optimal_q_value = self.get_optimal_action()
          q_values.append(optimal_q_value)
        else:
          q_values.append(-10**10)
      finally:
        # Undo the trial move even if the prediction raised.
        self.field = saved_field
        self.x, self.y = saved_x, saved_y
        self.length = saved_length

    optimal_action = np.argmax(q_values)
    return optimal_action


  def get_action(self, epsilon):
    """Epsilon-greedy choice: random with probability `epsilon`, else greedy."""
    if random.random() < epsilon:
      action = random.randint(0, len(self.actions)-1)
    else:
      action, _ = self.get_optimal_action()
    return action


  def get_state(self):
    """Encodes the board as a nested list of shape (height, width, 2).

    Channel 0 is 1.0 on every non-empty cell (walls, obstacles, trail) and
    channel 1 is 1.0 only on the walker's cell.
    """
    walls = [[0.0 if c == ' ' else 1.0 for c in line] for line in self.field]
    height, width = len(walls), len(walls[0])
    state = np.zeros((height, width, 2))
    state[:, :, 0] = walls
    state[self.y, self.x, 1] = 1.0
    return state.tolist()


  def show_environment(self):
    """Redraws the board with the walker shown as '*' and prints the length."""
    display = self.field.copy()
    display[self.y][self.x] = '*'
    sleep(0.2)
    clear_output(wait=True)
    for line in display:
      print(''.join(line))
    print('length: {}'.format(self.length))


  def move(self, action):
    """Moves the walker one step and reports whether it survived.

    The cell being left is marked '+', and `length` is incremented even when
    the move fails.

    Args:
      action: Index into `self.actions`.

    Returns:
      True when the destination cell is empty, False otherwise.
    """
    dx, dy = self.actions[action]
    self.field[self.y][self.x] = '+'
    self.x += dx
    self.y += dy
    self.length += 1
    if self.field[self.y][self.x] != ' ':
      return False
    return True

In [0]:
def trial(env, num=200, epsilon=0.1, eval=False, search=False):
  """Plays up to `num` moves on `env`, either collecting data or evaluating.

  Training mode (eval is false): every move is stored with
  env.add_experience() using reward 1.0 for a successful move and 0.0 for a
  failed one (whose post-state is None). After a failed move the environment
  is restarted and play continues. One '.' is printed per move.

  Evaluation mode (eval is true): epsilon is forced to 0.0, the board is
  rendered after every move, nothing is stored, and the function returns at
  the first failed move.

  Args:
    env: Environment to play on; it is restarted before the first move.
    num: Maximum number of moves.
    epsilon: Exploration rate passed to env.get_action().
    eval: Selects evaluation mode.
    search: When true, actions come from
      env.get_optimal_action_with_search() instead of env.get_action().
  """
  env.restart()
  exploration_rate = 0.0 if eval else epsilon

  for _step in range(num):
    state_before = env.get_state()
    action = (env.get_optimal_action_with_search() if search
              else env.get_action(exploration_rate))
    survived = env.move(action)

    if eval:
      env.show_environment()
      if not survived:
        return
      continue

    if survived:
      reward, state_after = 1.0, env.get_state()
    else:
      reward, state_after = 0.0, None  # Terminal state
    env.add_experience(state_before, action, reward, state_after)
    if not survived:
      env.restart()
    print('.', end='')

In [6]:
# Create the environment (its constructor also builds a new Q-network) and
# print the layer summary of that network.
env = Environ()
env.model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
cnn_input (InputLayer)          [(None, 14, 14, 2)]  0                                            
__________________________________________________________________________________________________
conv_filter (Conv2D)            (None, 14, 14, 8)    408         cnn_input[0][0]                  
__________________________________________________________________________________________________
flatten (Flatten)               (None, 1568)         0           conv_filter[0][0]                
__________________________________________________________________________________________________
action_input (InputLayer)       [(None, 4)]          0                                            
______________________________________________________________________________________________

In [7]:
# Mount Google Drive and replace the freshly built network with the
# checkpoint stored there (the same path the training cell saves to).
# NOTE(review): this presumably fails on a first run, when no checkpoint has
# been saved yet -- confirm, and skip this cell in that case.
from google.colab import drive
drive.mount('/content/gdrive')
env.model = models.load_model('/content/gdrive/My Drive/dqn_cnn.hd5')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [8]:
# Training loop. Each of the 25 iterations:
#   1. collects 200 moves of experience with exploration rate 0.2,
#   2. fits the Q-network via train_model(200),
#   3. plays one rendered greedy episode (at most 200 moves) to show progress.
for i in range(25):
  print('Iteration: {}'.format(i+1))
  trial(env, 200, epsilon=0.2)
  env.train_model(200)
  trial(env, 200, eval=True)

# Persist the trained network to Google Drive in HDF5 format.
env.model.save('/content/gdrive/My Drive/dqn_cnn.hd5', save_format='h5')

##############
#  x         #
#            #
#     x    x #
#        x   #
#    ++      #
#x x+++     x#
#++++ + ++   #
#+    + ++++ #
#     +*+  +x#
#x     ++ ++ #
#       +x+  #
#       +++  #
##############
length: 32


In [9]:
# Rendered evaluation episode using the plain greedy policy (no lookahead).
trial(env, eval=True, search=False)

##############
#            #
#          ++#
#     x +++++#
#      ++x  +#
# x   ++     #
#     ++     #
#      +     #
#      *    x#
#    x     x #
#            #
#x           #
#  xx        #
##############
length: 16


In [10]:
# Rendered evaluation episode using the one-step lookahead search policy.
trial(env, eval=True, search=True)

##############
#+           #
#+        x  #
#+   x       #
#+++   x +++ #
#  +++  ++ ++#
#    +x +   +#
#  x ++ +++++#
#     +++++++#
#      x+++++#
#   x   +*+x #
#        ++  #
#   x    ++  #
##############
length: 46
