<a href="https://colab.research.google.com/github/enakai00/colab_rlbook/blob/master/Chapter05/01_Neural_Network_Policy_Estimation_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%tensorflow_version 2.x 

TensorFlow 2.x selected.


In [0]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, initializers

In [0]:
class Gridworld:
  """One-dimensional grid world with terminal goal states.

  States are the integers ``0 .. size-1``. An action is an integer offset
  that is added to the current state by ``move``.
  """

  def __init__(self, size=8, goals=(7,)):
    """Create the world.

    Args:
      size: Number of states; valid states are ``range(size)``.
      goals: Iterable of terminal (goal) states.
    """
    self.size = size
    # Copy into a fresh list so instances never share (or alias) a mutable
    # default / caller-owned list.
    self.goals = list(goals)
    self.states = range(size)

  def move(self, s, a):
    """Apply action ``a`` in state ``s``.

    Args:
      s: Current state.
      a: Offset added to ``s`` to obtain the candidate next state.

    Returns:
      A ``(reward, next_state)`` tuple:
        * ``(0, s)`` if ``s`` is already a goal (terminal, nothing happens),
        * ``(1, s + a)`` if the move reaches a goal,
        * ``(-1, s)`` if the move would leave the grid (agent stays put),
        * ``(-1, s + a)`` otherwise.
    """
    if s in self.goals:
      return 0, s       # Reward, Next state

    s_new = s + a

    if s_new in self.goals:
      return 1, s_new   # Reward, Next state

    if s_new not in self.states:
      return -1, s      # Reward, Next state

    return -1, s_new    # Reward, Next state

In [0]:
class StateValue:
  """State-value function approximated by a single linear Keras layer."""

  def __init__(self):
    self.model = self.build_model()

  def build_model(self):
    """Build and compile the model: one scalar input -> Dense(1) output.

    Returns:
      A compiled Keras model using the Adam optimizer and MSE loss.
    """
    state = layers.Input(shape=(1,), name='state_input')
    value = layers.Dense(1, kernel_initializer=initializers.TruncatedNormal(),
                         name='linear_function')(state)
    model = models.Model(inputs=state, outputs=value)
    model.compile(optimizer='adam', loss='mse')
    return model

  def get_value(self, world, s):
    """Return the estimated value of state ``s``.

    Args:
      world: Environment object exposing a ``goals`` collection.
      s: State to evaluate.

    Returns:
      ``0`` for goal states (no model call is made), otherwise the model's
      scalar prediction for ``s``.
    """
    if s in world.goals:
      return 0
    # Feed an explicit (batch=1, features=1) array instead of a nested
    # Python list, so the input shape does not rely on implicit flattening.
    input_states = np.array([[s]])
    # verbose=0 explicitly silences any per-call progress output.
    output_values = self.model.predict(input_states, verbose=0)
    value = output_values[0][0]
    return value

In [0]:
def show_values(world, state_value):
  """Print the estimated value of every state on one bracketed line.

  Each value is rendered with width 5 and two decimals, followed by a
  single space; the line is wrapped in ``[ `` and ``]``.
  """
  # Lazy generator: each value is looked up right before it is printed.
  cells = (format(state_value.get_value(world, state), '5.2f')
           for state in world.states)
  print('[', end=' ')
  for cell in cells:
    print(cell, end=' ')
  print(']')

In [6]:
# Instantiate the value model and print its layer/parameter summary.
state_value = StateValue()
state_value.model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
state_input (InputLayer)     [(None, 1)]               0         
_________________________________________________________________
linear_function (Dense)      (None, 1)                 2         
Total params: 2
Trainable params: 2
Non-trainable params: 0
_________________________________________________________________


In [0]:
def get_episode(world):
  """Generate one episode that always steps right until a goal is reached.

  The start state is drawn uniformly from ``0 .. world.size-2``.

  Returns:
    A list of ``(state, reward, next_state)`` transitions, ending with the
    transition whose ``next_state`` is a goal.
  """
  step = 1 # Always move to right
  current = np.random.randint(world.size-1)
  transitions = []
  reached_goal = False
  while not reached_goal:
    reward, following = world.move(current, step)
    transitions.append((current, reward, following))
    reached_goal = following in world.goals
    current = following

  return transitions

In [0]:
def train(world, state_value, num, episodes=200, batch_size=50, epochs=100):
  """Fit the value model with bootstrapped one-step targets.

  Each iteration collects fresh episodes, builds the target
  ``r + V(s_new)`` for every transition using the current model, fits the
  model on those targets and prints the resulting value table.

  Args:
    world: Environment passed to ``get_episode`` and ``get_value``.
    state_value: Object exposing ``get_value(world, s)`` and a Keras
      ``model`` with ``fit``.
    num: Number of collect-and-fit iterations.
    episodes: Episodes collected per iteration.
    batch_size: Batch size forwarded to ``model.fit``.
    epochs: Epochs forwarded to ``model.fit``.
  """
  for c in range(num):
    print('Iteration {} : '.format(c+1), end='')
    examples = []
    for _ in range(episodes):
      # Every generated episode is used (none is generated and discarded).
      examples += get_episode(world)
    np.random.shuffle(examples)

    # The model is not updated until fit() below, so V(s_new) is constant
    # within this iteration: evaluate each distinct next state only once.
    value_cache = {}
    states, labels = [], []
    for s, r, s_new in examples:
      if s_new not in value_cache:
        value_cache[s_new] = state_value.get_value(world, s_new)
      states.append(np.array([s]))
      labels.append(np.array(r + value_cache[s_new]))

    state_value.model.fit(np.array(states), np.array(labels),
                          batch_size=batch_size, epochs=epochs, verbose=0)
    show_values(world, state_value)

In [0]:
# Set up the default environment and a newly constructed value model for training.
world = Gridworld()
state_value = StateValue()

In [11]:
# Run 20 training iterations; train() prints the value table after each one.
train(world, state_value, num=20)

Iteration 1 : [ -1.11 -0.92 -0.74 -0.56 -0.38 -0.20 -0.01  0.00 ]
Iteration 2 : [ -2.25 -1.85 -1.46 -1.07 -0.68 -0.29  0.10  0.00 ]
Iteration 3 : [ -3.22 -2.63 -2.05 -1.46 -0.88 -0.29  0.30  0.00 ]
Iteration 4 : [ -4.01 -3.26 -2.51 -1.76 -1.00 -0.25  0.50  0.00 ]
Iteration 5 : [ -4.55 -3.68 -2.81 -1.94 -1.07 -0.20  0.67  0.00 ]
Iteration 6 : [ -4.87 -3.92 -2.98 -2.03 -1.08 -0.13  0.82  0.00 ]
Iteration 7 : [ -5.03 -4.04 -3.05 -2.06 -1.07 -0.08  0.91  0.00 ]
Iteration 8 : [ -5.08 -4.07 -3.07 -2.06 -1.05 -0.05  0.96  0.00 ]
Iteration 9 : [ -5.09 -4.07 -3.06 -2.05 -1.04 -0.02  0.99  0.00 ]
Iteration 10 : [ -5.07 -4.06 -3.05 -2.04 -1.02 -0.01  1.00  0.00 ]
Iteration 11 : [ -5.06 -4.05 -3.04 -2.03 -1.02 -0.01  1.00  0.00 ]
Iteration 12 : [ -5.04 -4.03 -3.02 -2.02 -1.01 -0.00  1.01  0.00 ]
Iteration 13 : [ -5.03 -4.02 -3.01 -2.01 -1.00  0.00  1.01  0.00 ]
Iteration 14 : [ -5.01 -4.01 -3.01 -2.00 -1.00  0.00  1.01  0.00 ]
Iteration 15 : [ -5.01 -4.00 -3.00 -2.00 -1.00  0.00  1.00  0.00 ]
Iter