<a href="https://colab.research.google.com/github/enakai00/colab_rlbook/blob/master/Chapter05/04_Walk_Game_with_Search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Colab-only magic: select the TensorFlow 2.x runtime before TensorFlow is imported below.
%tensorflow_version 2.x 

TensorFlow 2.x selected.


In [0]:
import numpy as np
import copy, random, time
from tensorflow.keras import layers, models
from IPython.display import clear_output

In [0]:
def get_field():
  """Build a fresh playing field as a mutable grid of characters.

  '#' marks the boundary walls and ' ' marks open floor. Every call returns
  a brand-new list of lists, so callers may mutate it freely.
  """
  field_img = '''
##############
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
#            #
##############
'''
  # The leading/trailing newlines of the literal produce empty strings; skip them.
  return [list(row) for row in field_img.split('\n') if row]

In [0]:
class Environ:
  """Grid-walk environment: a walled field with randomly placed obstacles.

  Field cells are single characters: '#' boundary wall, 'x' obstacle,
  '+' a cell the walker has already left, ' ' open floor. Positions are
  (x, y) tuples that index `self.field[y][x]`.
  """
  def __init__(self):
    # Action index -> (dx, dy) offset.
    self.action_map = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    self.restart()

  def restart(self, num_obstacles=10):
    """Reset to a fresh field and scatter `num_obstacles` random 'x' cells.

    Obstacles are drawn independently over the interior (the boundary row
    and column are excluded), so two draws may land on the same cell. The
    interior range is derived from the field so it follows `get_field()`.
    """
    self.field = get_field()
    height, width = len(self.field), len(self.field[0])
    for _ in range(num_obstacles):
      y = np.random.randint(1, height - 1)
      x = np.random.randint(1, width - 1)
      self.field[y][x] = 'x'

  def move(self, s, a):
    """Take action `a` from position `s`.

    The cell being left is marked '+', so re-entering it ends the game.
    Moving into any non-empty cell (wall, obstacle or own trail) is game over.

    Returns:
      (reward, next_state, game_over): reward is 1 for a legal step and 0
      on game over.
    """
    x, y = s
    dx, dy = self.action_map[a]
    self.field[y][x] = '+'
    x += dx
    y += dy
    s_new = (x, y)
    if self.field[y][x] != ' ':
      return 0, s_new, True   # Reward, Next state, Game over
    return 1, s_new, False    # Reward, Next state, Game over

  def get_state(self, s):
    """Encode the field and walker position as a (height, width, 2) grid.

    Channel 0 is 1.0 for every non-empty cell (walls, obstacles, trail);
    channel 1 is 1.0 only at the walker position `s`. The result is
    returned as nested Python lists. The shape follows the field rather
    than being hard-coded.
    """
    x, y = s
    walls = [[0.0 if c == ' ' else 1.0 for c in line] for line in self.field]
    height, width = len(walls), len(walls[0])
    state = np.zeros((height, width, 2))
    state[:, :, 0] = walls
    state[y, x, 1] = 1.0
    return state.tolist()

In [0]:
class QValue:
  """Thin wrapper around a two-input model that scores (state, action) pairs.

  `model` is assigned after construction (e.g. loaded from disk). It is fed
  `[states, actions]` and is assumed to return one Q-value per row, shaped
  (batch, 1) — inferred from the `[...][0]` indexing below; confirm against
  the trained model.
  """
  def __init__(self):
    # Set externally once a trained model is available.
    self.model = None

  def get_action(self, state):
    """Return `(greedy_action, q_value)` for `state`.

    All four actions are scored in a single batched `predict` call.
    """
    # Stack into real ndarrays: one copy of the state per candidate action,
    # paired row-by-row with the one-hot action encodings (row a == action a).
    # Passing plain Python lists of arrays lets Keras treat each list element
    # as a separate input.
    states = np.array([state] * 4)
    actions = np.eye(4)
    q_values = self.model.predict([states, actions])
    optimal_action = np.argmax(q_values)
    return optimal_action, q_values[optimal_action][0]

In [0]:
def get_action_with_search(environ, q_value, s):
  """One-step lookahead: pick the action maximising r + max_a' Q(s', a').

  Each candidate move is simulated on the live environment and its field is
  restored afterwards, so `environ.field` holds the same content on return.
  Moves that end the game are scored by their immediate reward alone.
  """
  def lookahead_value(a):
    saved_field = copy.deepcopy(environ.field)
    reward, s_next, done = environ.move(s, a)
    if done:
      value = reward + 0
    else:
      _, q_next = q_value.get_action(environ.get_state(s_next))
      value = reward + q_next
    # Undo the simulated move before trying the next action.
    environ.field = saved_field
    return value

  return np.argmax([lookahead_value(a) for a in range(4)])

In [0]:
def get_episode(environ, q_value):
  """Play one episode with one-step search and return the visited path.

  The environment is reset, a random start cell on open floor is drawn, and
  the walker moves until `environ.move` reports game over.

  Returns:
    List of (x, y) positions visited, starting with the start cell. The
    final, game-ending position is not included.

  Raises:
    ValueError: if the reset field has no open interior cell to start from.
  """
  trace = []
  environ.restart()
  field = environ.field
  height, width = len(field), len(field[0])
  if not any(c == ' ' for row in field[1:-1] for c in row[1:-1]):
    raise ValueError('no empty cell to start from')
  # restart() scatters obstacles at random, so a draw can land on one;
  # redraw until the start cell is open floor.
  while True:
    s = (np.random.randint(1, width - 1), np.random.randint(1, height - 1))
    if field[s[1]][s[0]] == ' ':
      break

  while True:
    trace.append(s)
    a = get_action_with_search(environ, q_value, s)
    _, s_new, game_over = environ.move(s, a)
    if game_over:
      break
    s = s_new

  return trace

In [0]:
  def show_sample(environ, q_value, delay=0.5):
    """Play one episode and animate the walker's path in the cell output.

    Args:
      environ: environment to play in. The rendering reads `environ.field`
        but draws on its own copy.
      q_value: Q-value wrapper passed through to `get_episode`.
      delay: seconds to pause between frames (0.5 by default).
    """
    trace = get_episode(environ, q_value)
    # Draw on a copy with the episode's '+' trail erased so the path can be
    # replayed step by step. The comprehension already builds new row lists,
    # so no deepcopy is needed. The name `canvas` avoids shadowing IPython's
    # `display`.
    canvas = [[' ' if c == '+' else c for c in line] for line in environ.field]
    for x, y in trace:
      canvas[y][x] = '*'
      time.sleep(delay)
      clear_output(wait=True)
      for line in canvas:
        print(''.join(line))
      canvas[y][x] = '+'

    print('Length: {}'.format(len(trace)))

In [0]:
# Create the Q-value wrapper (no model attached yet; it is loaded from Drive
# further down) and a freshly randomised environment.
q_value = QValue()
environ = Environ()

In [0]:
# Mount Google Drive (Colab only) so the saved model file can be read below.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
# Confirm the trained model file is present on the mounted Drive.
!ls -l '/content/gdrive/My Drive/walk_game_model.hd5'

-rw------- 1 root root 42608640 Mar  3 00:46 '/content/gdrive/My Drive/walk_game_model.hd5'


In [0]:
# Load the trained Q-network from Drive and attach it to the wrapper.
# NOTE(review): '.hd5' is not the conventional '.h5' suffix. TF 2.x's
# load_model appears to detect HDF5 by file content, but newer Keras
# releases may require a '.h5'/'.keras' extension — confirm before upgrading.
q_value.model = models.load_model('/content/gdrive/My Drive/walk_game_model.hd5')
q_value.model.summary()

In [0]:
# Animate one episode played by the trained model with one-step search.
show_sample(environ, q_value)

##############
#       +++*+#
#x  x   + +++#
#     +++ ++ #
#     +      #
#    x+++++++#
#+ x      +++#
#+ x    +++++#
#+++    +++++#
#  ++++ +++++#
#     +++++  #
# xx x++ ++  #
#x       ++x #
##############
Length: 60
