<a href="https://colab.research.google.com/github/enakai00/colab_rlbook/blob/master/Chapter03/Tic_tac_toe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from enum import Enum
from copy import deepcopy
import random

In [0]:
class StateInfo(Enum):
  """Classification of a tic-tac-toe board, as computed by State.get_state_info."""
  PLAYER1 = 1  # Game in progress, player 1 (stone value 1) moves next.
  PLAYER2 = 2  # Game in progress, player 2 (stone value 2) moves next.
  WIN1 = 3     # Player 1 has completed a line.
  WIN2 = 4     # Player 2 has completed a line.
  TIE = 5      # No empty cell is left and neither player has a line.

def flatten(lst):
  """Concatenate the rows of a nested sequence into one new flat list.

  Only one level of nesting is removed. The result is always a fresh list,
  so mutating it never affects ``lst`` or its rows.

  Args:
    lst: An iterable of iterables (e.g. the list of rows of a board).

  Returns:
    A list holding the items of every row, in order.
  """
  # A single pass is linear in the number of items; ``sum(lst, [])`` would
  # copy the accumulated list once per row and only accepts list rows.
  return [item for row in lst for item in row]

In [0]:
class State:
  """A tic-tac-toe board together with its classification.

  Attributes:
    board: 3x3 list of rows; a cell holds 0 (empty), 1 (player 1, shown as
      'o') or 2 (player 2, shown as 'x'). Cells are addressed board[y][x].
    info: The StateInfo member describing the position, or None when the
      board cannot occur in a legal game.
  """

  def __init__(self, board):
    self.board = board
    self.info = self.get_state_info(board)

  def __hash__(self):
    # Hash the cell contents so equal boards hash equally (see __eq__).
    return hash(tuple(flatten(self.board)))

  def __eq__(self, other):
    # Let Python fall back to its default handling for foreign types
    # instead of failing with AttributeError on ``other.board``.
    if not isinstance(other, State):
      return NotImplemented
    return self.board == other.board

  def __repr__(self):
    # One text line per board row, each terminated by a newline.
    return ''.join(
        ''.join('o' if c == 1 else 'x' if c == 2 else '.' for c in row) + '\n'
        for row in self.board)

  def get_state_info(self, board):
    """Classify ``board``.

    Args:
      board: 3x3 list of rows with cell values 0, 1 or 2.

    Returns:
      None if the board cannot be reached in a legal game; otherwise
      StateInfo.WIN1 / WIN2 when that player has a line, StateInfo.TIE when
      the board is full without a line, and StateInfo.PLAYER1 / PLAYER2 to
      indicate whose turn it is.
    """
    num1 = sum(row.count(1) for row in board)
    num2 = sum(row.count(2) for row in board)
    # Player 1 moves first, so it has as many stones as player 2 or one more.
    if num1 != num2 and num1 != num2 + 1:
      return None

    # All eight winning lines as lists of (x, y) coordinates.
    lines = [[(x, y) for x in range(3)] for y in range(3)]   # rows
    lines += [[(x, y) for y in range(3)] for x in range(3)]  # columns
    lines += [[(x, x) for x in range(3)], [(x, 2 - x) for x in range(3)]]
    win1 = any(all(board[y][x] == 1 for x, y in line) for line in lines)
    win2 = any(all(board[y][x] == 2 for x, y in line) for line in lines)

    # The game stops at the first completed line, so both cannot have one.
    if win1 and win2:
      return None

    if win1:
      # The winning stone was player 1's last move: player 1 is one ahead.
      if num1 != num2 + 1:
        return None
      return StateInfo.WIN1

    if win2:
      # The winning stone was player 2's last move: counts are equal.
      if num1 != num2:
        return None
      return StateInfo.WIN2

    if sum(row.count(0) for row in board) == 0:
      return StateInfo.TIE

    if num1 == num2:
      return StateInfo.PLAYER1
    return StateInfo.PLAYER2

In [0]:
class Agent:
  """Tabular tic-tac-toe player with a value table and a deterministic policy.

  Attributes:
    player: The side this agent plays (1 or 2).
    states: Every valid State except those in which the opponent is to move.
    value: Dict mapping each tracked State to its value estimate (starts at 0).
    policy: Dict mapping each tracked State to the (x, y) move to play, or to
      None when the agent is not to move in that state.
  """

  def __init__(self, player=1):
    self.player = player

    # State infos this agent does not track: invalid boards and boards in
    # which the opponent is the one to move.
    if self.player == 1:
      skipped = (None, StateInfo.PLAYER2)
    elif self.player == 2:
      skipped = (None, StateInfo.PLAYER1)
    else:
      skipped = None

    # Enumerate all 3**9 cell assignments, row by row.
    cell_values = range(3)
    rows = [[a, b, c] for a in cell_values for b in cell_values for c in cell_values]
    self.states = []
    for top in rows:
      for middle in rows:
        for bottom in rows:
          # Copy the rows so that every board owns its own lists.
          candidate = State([list(top), list(middle), list(bottom)])
          if skipped is not None and candidate.info not in skipped:
            self.states.append(candidate)

    self.value = {tracked: 0 for tracked in self.states}

    # Start from a random legal move wherever the agent is to move.
    self.policy = {}
    for tracked in self.states:
      choice = None
      if self.is_myturn(tracked):
        choice = random.choice(self.get_actions(tracked))
      self.policy[tracked] = choice

  def is_myturn(self, state):
    """Return True if this agent is the one to move in ``state``."""
    if self.player == 1:
      return state.info == StateInfo.PLAYER1
    if self.player == 2:
      return state.info == StateInfo.PLAYER2
    return False

  def is_win(self, state):
    """Return True if ``state`` is a win for this agent."""
    if self.player == 1:
      return state.info == StateInfo.WIN1
    if self.player == 2:
      return state.info == StateInfo.WIN2
    return False

  def is_lost(self, state):
    """Return True if ``state`` is a win for the other side."""
    if self.player == 1:
      return state.info == StateInfo.WIN2
    if self.player == 2:
      return state.info == StateInfo.WIN1
    return False

  def get_actions(self, state):
    """Return the (x, y) coordinates of all empty cells, scanned row by row."""
    empty_cells = []
    for y, row in enumerate(state.board[:3]):
      for x in range(3):
        if row[x] == 0:
          empty_cells.append((x, y))
    return empty_cells

  def put(self, state, pos, player):
    """Return a new State with ``player``'s stone at ``pos``; ``state`` is untouched."""
    column, line = pos
    grid = deepcopy(state.board)
    grid[line][column] = player
    return State(grid)

  def move(self, state, pos, opponent):
    """Play ``pos`` and then let ``opponent`` answer according to its policy.

    Returns:
      A (reward, state) pair: reward is 1 if the agent's move wins, -1 if the
      opponent's answer wins, and 0 otherwise. The state is the board after
      the last stone placed, or ``state`` itself when the agent is not to move.
    """
    # Anything that is not the agent's turn is treated as terminal.
    if not self.is_myturn(state):
      return 0, state

    own_result = self.put(state, pos, self.player)
    if self.is_win(own_result):
      return 1, own_result
    if own_result.info == StateInfo.TIE:
      return 0, own_result

    answer = opponent.policy[own_result]
    reply_result = self.put(own_result, answer, opponent.player)
    reward = -1 if self.is_lost(reply_result) else 0
    return reward, reply_result

In [0]:
def policy_eval(agent, opponent, gamma=1.0, delta=0.01):
  """Evaluate the agent's current policy in place.

  Sweeps over agent.states, replacing each value with the reward of the
  policy's move plus ``gamma`` times the value of the resulting state, and
  repeats until the largest change within one sweep is below ``delta``.
  Values updated earlier in a sweep are used by later states of that sweep.

  Args:
    agent: Agent whose ``value`` table is updated.
    opponent: Agent whose policy supplies the answering moves.
    gamma: Discount factor applied to the successor value.
    delta: Convergence threshold on the largest per-sweep change.
  """
  converged = False
  while not converged:
    largest_change = 0
    for current in agent.states:
      reward, successor = agent.move(current, agent.policy[current], opponent)
      estimate = reward + gamma * agent.value[successor]
      change = abs(agent.value[current] - estimate)
      if change > largest_change:
        largest_change = change
      agent.value[current] = estimate
    converged = largest_change < delta

In [0]:
def policy_update(agent, opponent, gamma=1.0):
  """Make the agent's policy greedy with respect to its current values.

  For every state in which the agent is to move, each available action is
  scored as its reward plus ``gamma`` times the value of the resulting
  state, and the policy is set to the best-scoring action. Among equally
  good actions the first one returned by ``agent.get_actions`` is kept.

  Args:
    agent: Agent whose ``policy`` table is updated in place.
    opponent: Agent whose policy supplies the answering moves.
    gamma: Discount factor applied to the successor value.

  Returns:
    True if the action of at least one state changed, otherwise False.
  """
  update = False

  for state in agent.states:
    if not agent.is_myturn(state):
      continue

    # Start below every possible score so that an action is always chosen,
    # however negative the values are.
    q_max = float('-inf')
    pos_best = None
    for pos in agent.get_actions(state):
      r, state_new = agent.move(state, pos, opponent)
      q = r + gamma * agent.value[state_new]
      if q > q_max:
        q_max, pos_best = q, pos

    if agent.policy[state] != pos_best:
      update = True
    agent.policy[state] = pos_best

  return update

In [0]:
# One tabular agent per side; each constructor enumerates its states and
# draws a random initial policy (player 1 first, then player 2).
agent1, agent2 = Agent(player=1), Agent(player=2)

In [8]:
# Alternating policy iteration: in each of the six rounds, first agent1 and
# then agent2 is improved against the other's current (fixed) policy until
# its own policy stops changing. One dot is printed per evaluation pass.
for _ in range(6):
  for learner, rival in ((agent1, agent2), (agent2, agent1)):
    changed = True
    while changed:
      print('.', end='')
      policy_eval(learner, rival)
      changed = policy_update(learner, rival)

  print('')

........
........
......
.....
...
..


In [10]:
# Play one demonstration game between the two trained policies, printing the
# board after every stone. Player 1 opens with a random move so that
# repeated runs show different games; all later moves follow the policies.
state = agent1.states[0]
state = agent1.put(state, random.choice(agent1.get_actions(state)), 1)
print(state)
while state.info == StateInfo.PLAYER2:
  # Player 2 answers.
  state = agent1.put(state, agent2.policy[state], 2)
  print(state)
  if state.info != StateInfo.PLAYER1:
    break
  # Player 1 continues with its learned policy.
  state = agent1.put(state, agent1.policy[state], 1)
  print(state)

print(state.info)

...
.o.
...

x..
.o.
...

xo.
.o.
...

xo.
.o.
.x.

xo.
oo.
.x.

xo.
oox
.x.

xoo
oox
.x.

xoo
oox
xx.

xoo
oox
xxo

StateInfo.TIE
