<a href="https://colab.research.google.com/github/enakai00/colab_rlbook/blob/master/Chapter05/02_Neural_Network_Policy_Estimation_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%tensorflow_version 2.x 

TensorFlow 2.x selected.


In [0]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

In [0]:
class Gridworld:
  """Square grid world with traps and a single goal in the far corner.

  States are (x, y) tuples with 0 <= x, y < size. The start state is
  (0, 0) and the only goal is (size-1, size-1).

  Args:
    size: side length of the grid (default 6).
    traps: iterable of cells that send the agent back to the start.
      Defaults to (4, 0), (4, 1), (4, 2), (4, 3). The value is copied into a
      fresh list, so instances never share (or mutate) a common default.
  """
  def __init__(self, size=6, traps=None):
    self.size = size
    if traps is None:
      traps = [(4, y) for y in range(4)]
    # Always store a private list: a shared mutable default would leak
    # changes between instances, and callers concatenate `traps + goals`,
    # which requires a list.
    self.traps = list(traps)
    self.start = (0, 0)
    self.goals = [(size-1, size-1)]
    self.states = [(x, y) for x in range(size) for y in range(size)]

  def move(self, s, a):
    """Apply action `a` = (dx, dy) in state `s`.

    Returns:
      (reward, next_state):
        * `s` is a goal: (0, s) -- goals are absorbing.
        * the move leaves the grid: (0, s) -- the agent stays put.
        * the move lands on a trap: (-1, start).
        * otherwise: (-1, new state).
    """
    if s in self.goals:
      return 0, s             # Reward, Next state

    s_new = (s[0] + a[0], s[1] + a[1])

    if s_new not in self.states:
      return 0, s             # Reward, Next state

    if s_new in self.traps:
      return -1, self.start   # Reward, Next state

    return -1, s_new          # Reward, Next state

In [0]:
class StateValue:
  """Neural-network estimate of the state-value function V(s).

  A small fully connected Keras network maps a 2-element state (x, y) to a
  single scalar. States listed in `goals` are never fed to the network; their
  value is fixed at 0.
  """

  def __init__(self, goals):
    # Terminal states whose value is pinned to 0 instead of predicted.
    self.goals = goals
    self.model = self.build_model()

  def build_model(self):
    """Build and compile the 2 -> 16 (relu) -> 8 (relu) -> 1 network with MSE loss."""
    state = layers.Input(shape=(2,))
    features = state
    for width in (16, 8):
      features = layers.Dense(width, activation='relu')(features)
    value = layers.Dense(1)(features)
    model = models.Model(inputs=[state], outputs=[value])
    model.compile(loss='mse')
    return model

  def get_value(self, s):
    """Return the estimated value of state `s` (exactly 0 for goal states)."""
    return 0 if s in self.goals else self.model.predict([[np.array(s)]])[0][0]

In [0]:
def show_values(world, state_value):
  """Print the estimated value of every cell as a size x size table.

  Each printed row corresponds to one y, with x increasing left to right.
  Values are formatted with '{:5.1f}'; trap cells are shown blank. A blank
  line follows the table.
  """
  for row in range(world.size):
    cells = []
    for col in range(world.size):
      cell = (col, row)
      if cell in world.traps:
        cells.append('     ')
      else:
        cells.append('{:5.1f}'.format(state_value.get_value(cell)))
    print('[ ' + ' '.join(cells) + ' ]')
  print()

In [0]:
def get_episode(world):
  """Sample one episode under a uniformly random two-action policy.

  The start state is drawn uniformly from the grid, redrawing until it is
  neither a trap nor a goal. At every step the action (1, 0) or (0, 1) is
  chosen with probability 0.5 each, until a goal state is reached.

  Returns:
    list of (state, reward, next_state) transitions; only the last one ends
    in a goal state.
  """
  excluded = world.traps + world.goals
  state = (np.random.randint(world.size), np.random.randint(world.size))
  while state in excluded:
    state = (np.random.randint(world.size), np.random.randint(world.size))

  transitions = []
  while True:
    action = (1, 0) if np.random.random() < 0.5 else (0, 1)
    reward, next_state = world.move(state, action)
    transitions.append((state, reward, next_state))
    if next_state in world.goals:
      return transitions
    state = next_state

In [0]:
def train(world, state_value, num, num_episodes=100, batch_size=50, epochs=100):
  """Fit `state_value` to one-step targets r + V(s') for `num` iterations.

  Each iteration samples `num_episodes` random episodes. For every transition
  it builds the regression target r + V(s'), with V(s') taken from the current
  network and fixed at 0 for goal states. It then fits the network to those
  targets and prints the resulting value table.

  Args:
    world: environment passed to `get_episode` and `show_values`.
    state_value: object exposing a compiled Keras `model` and a `goals` list.
    num: number of training iterations.
    num_episodes: episodes sampled per iteration (default 100).
    batch_size: mini-batch size forwarded to `model.fit` (default 50).
    epochs: epochs per iteration forwarded to `model.fit` (default 100).
  """
  for c in range(num):
    print('Iteration {:2d}: '.format(c+1))

    examples = []
    for _ in range(num_episodes):
      # Generate each episode exactly once and keep all of its transitions.
      examples += get_episode(world)
    np.random.shuffle(examples)

    states = np.array([s for s, _, _ in examples])
    rewards = np.array([r for _, r, _ in examples], dtype=np.float32)
    next_states = np.array([s_new for _, _, s_new in examples])

    # Predict V(s') for all transitions in one batched call rather than one
    # model.predict per transition. The network is not updated until fit()
    # below, so the targets are the same (up to float rounding).
    next_values = state_value.model.predict(next_states, verbose=0)[:, 0]
    # Goal states are terminal: their value is fixed at 0.
    is_goal = np.array([s_new in state_value.goals for _, _, s_new in examples])
    next_values = np.where(is_goal, 0.0, next_values)
    labels = rewards + next_values

    state_value.model.fit(states, labels,
                          batch_size=batch_size, epochs=epochs, verbose=0)
    show_values(world, state_value)

In [8]:
# Fix random seeds so episode sampling, shuffling and weight initialisation
# are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

world = Gridworld()
state_value = StateValue(world.goals)
state_value.model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense (Dense)                (None, 16)                48        
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 136       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 9         
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________


In [9]:
# Run 50 rounds of episode sampling + network fitting, printing the value table each round.
train(world=world, state_value=state_value, num=50)

Iteration  1: 
[  -0.9  -0.9  -1.0  -0.9        -0.8 ]
[  -0.7  -0.8  -0.9  -0.8        -0.7 ]
[  -0.6  -0.7  -0.8  -0.7        -0.6 ]
[  -0.3  -0.4  -0.5  -0.6        -0.5 ]
[  -0.0  -0.1  -0.1  -0.3  -0.4  -0.4 ]
[   0.3   0.3   0.3   0.2  -0.1   0.0 ]

Iteration  2: 
[  -1.8  -1.9  -1.9  -1.9        -1.4 ]
[  -1.7  -1.8  -1.8  -1.8        -1.2 ]
[  -1.5  -1.7  -1.8  -1.6        -1.1 ]
[  -1.3  -1.3  -1.4  -1.5        -1.0 ]
[  -0.8  -0.9  -1.0  -1.1  -1.0  -0.8 ]
[  -0.4  -0.4  -0.5  -0.5  -0.6   0.0 ]

Iteration  3: 
[  -2.8  -2.9  -2.8  -2.8        -1.8 ]
[  -2.6  -2.7  -2.7  -2.6        -1.6 ]
[  -2.4  -2.5  -2.6  -2.5        -1.3 ]
[  -2.0  -2.1  -2.3  -2.3        -1.1 ]
[  -1.4  -1.5  -1.6  -1.6  -1.4  -0.8 ]
[  -0.7  -0.8  -0.9  -0.8  -0.7   0.0 ]

Iteration  4: 
[  -3.8  -3.8  -3.8  -3.8        -2.6 ]
[  -3.6  -3.6  -3.6  -3.7        -2.2 ]
[  -3.3  -3.4  -3.5  -3.5        -1.8 ]
[  -2.7  -2.8  -3.1  -3.1        -1.5 ]
[  -2.1  -2.2  -2.3  -2.3  -1.9  -1.1 ]
[  -1.4  -1.5  -1