<a href="https://colab.research.google.com/github/hinsley/colabs/blob/master/TDt3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TDt3

A temporal-difference (TD) learning model for tic-tac-toe.

Inspired by [Sutton & Barto - Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html)

It learns entirely from self-play.

Board positions are indexed 0–8 in row-major order:

$\begin{bmatrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8\end{bmatrix} \to \begin{pmatrix} 0 & 1 & 2 \\ 3 & 4 & 5 \\ 6 & 7 & 8 \end{pmatrix}$

In [0]:
import numpy as np
from time import time
from typing import Dict, List, Optional, Tuple

In [0]:
# An action value lies in [0.0, 1.0]; higher is better for the evaluating player.
ActionValue = Optional[float]

# Contents of one square, enumerated by the constants below.
PositionState = int
VALUE_BLANK, VALUE_X, VALUE_O = 0, 1, 2

# The nine squares of a 3 x 3 board, flattened in row-major order.
BoardState = List[PositionState]

# A whole board packed into a single base-3 (trinary) integer.
EncodedBoardState = int

In [0]:
#@title f_eq(a: float, b: float, epsilon: float=1e-4) -> bool
def f_eq(a: float, b: float, epsilon: float=1e-4) -> bool:
  """
  Approximate float equality: True when a and b differ by strictly less than
  epsilon, which absorbs rounding error.
  """

  difference = a - b
  return -epsilon < difference < epsilon

In [0]:
#@title eval_state(state: BoardState, x_player: bool) -> ActionValue
def eval_state(state: BoardState, x_player: bool) -> ActionValue:
  """
  Score a board from one player's perspective using only end-of-game rules.

  Used to generate a fresh action value for a board seen for the first time.

  Args:
    state: The board to score: nine squares in row-major order.
    x_player: Truthy to score from X's perspective, falsy for O's.

  Returns:
    1.0 if the scoring player has three in a row, 0.0 if the opponent does,
    0.75 for a draw (no empty square and no completed line), and 0.5 when the
    game is still in progress.

  Only the first completed line found is considered: in a legally played game
  both players can never have three in a row at the same time.
  """

  winning_lines = (
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # Rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # Columns
    (0, 4, 8), (2, 4, 6),             # Diagonals
  )

  for a, b, c in winning_lines:
    owner = state[a]
    if owner != state[b] or owner != state[c]:
      continue
    # A line of three identical blanks is not a win; keep looking.
    if owner == VALUE_X:
      return 1.0 if x_player else 0.0
    if owner == VALUE_O:
      return 0.0 if x_player else 1.0

  if VALUE_BLANK in state:
    return 0.5  # An empty square remains and nobody has won yet.
  return 0.75  # Full board with no winner: draw.

In [0]:
#@title encode_board(state: BoardState) -> EncodedBoardState
def encode_board(state: BoardState) -> EncodedBoardState:
  """
  Pack a board into one hashable base-3 integer. Square i contributes
  value * 3**i, so the top-left square is the least significant digit.
  """

  encoded, place_value = 0, 1
  for value in state:
    encoded += value * place_value
    place_value *= 3
  return encoded

In [0]:
#@title decode_board(encoded_state: EncodedBoardState) -> BoardState
def decode_board(encoded_state: EncodedBoardState) -> BoardState:
  """
  Unpack a base-3 board encoding into a list of nine square values, least
  significant digit (the top-left square) first.
  """

  squares = []
  place_value = 1
  for _ in range(9):
    squares.append(encoded_state // place_value % 3)
    place_value *= 3
  return squares

In [0]:
#@title pprint(state: Optional[BoardState])
def pprint(state: Optional[BoardState]):
  """
  Pretty-print a board as a blank line followed by three rows of
  space-separated symbols: "X", "O", or "-" for an empty square.

  Does nothing when state is None.
  """

  if state is None:
    return

  symbols = {
    VALUE_BLANK: "-",
    VALUE_X: "X",
    VALUE_O: "O",
  }

  print()
  for row in (state[:3], state[3:6], state[6:]):
    print(*[symbols[position] for position in row])

In [0]:
#@title class ActionEvaluator()

from random import choice, choices

class ActionEvaluator():
  """
  Tabular temporal-difference learner for one side of tic-tac-toe.

  Keeps a value estimate for every encoded board it has evaluated. Moves are
  chosen greedily by those values, with optional random exploration, and
  estimates are updated with V(prev) += learning_rate * (V(new) - V(prev)).
  """

  _action_values: Dict[EncodedBoardState, ActionValue]  # Learned value per encoded board.
  _prev_state: EncodedBoardState  # Board whose value the next back-up adjusts.
  _x_player: bool  # True if this evaluator plays X, False if it plays O.


  def __init__(self, x_player: bool=True):

    # Seed the empty board (encoding 0) at the neutral in-progress value 0.5.
    self._action_values = {0: 0.5}
    self.reset_board_state()
    self._x_player = x_player


  def reset_board_state(self):
    """Start a new game: the next back-up adjusts the empty board's value."""

    self._prev_state = encode_board([VALUE_BLANK] * 9)


  def evaluate(self, state: BoardState) -> ActionValue:
    """
    Return the learned value of state. A board seen for the first time is
    seeded with eval_state() from this player's perspective and stored.
    """

    encoded_state = encode_board(state)

    if encoded_state not in self._action_values: # Unexplored board.
      self._action_values[encoded_state] = eval_state(state, self._x_player)

    return self._action_values[encoded_state]


  def back_up_value(self,
                    new_state: BoardState,
                    learning_rate: float=0.25):
    """
    Move the previous board's value toward new_state's value, then make
    new_state the previous board for the next back-up.

    Args:
      new_state: The board that followed the previous one.
      learning_rate: Fraction of the value difference to apply.
    """

    # Evaluate the new board first, then the previous one. evaluate() also
    # seeds the previous board if it is missing from the table (e.g. a table
    # loaded from elsewhere); indexing it directly would raise KeyError.
    new_value = self.evaluate(new_state)
    prev_value = self.evaluate(decode_board(self._prev_state))
    self._action_values[self._prev_state] = (
      prev_value + learning_rate * (new_value - prev_value))
    self._prev_state = encode_board(new_state)


  def best_move(self, state: BoardState, beta: float) -> BoardState:
    """
    Return the board after this player's chosen move. state is not modified.

    Args:
      state: The current board.
      beta: Probability of choosing a uniformly random legal move instead of
        a highest-valued one (exploration during training).

    Raises:
      ValueError: If state has no empty square.
    """

    mark = VALUE_X if self._x_player else VALUE_O
    possible_states = []

    for i, position_value in enumerate(state):
      if position_value == VALUE_BLANK:
        state_after_move = state.copy()
        state_after_move[i] = mark
        possible_states.append(state_after_move)

    if not possible_states:
      raise ValueError("No legal moves: the board is already full.")

    # Exploitation-Exploration selection (for training purposes).
    if choices([True, False], weights=[beta, 1.0-beta])[0]:
      return choice(possible_states)

    # Collect every move tied for the highest value. Starting from -inf keeps
    # argmax non-empty even if the table holds values below 0.0.
    max_action_value, argmax = float("-inf"), []
    for new_state in possible_states:
      action_value = self.evaluate(new_state)
      if action_value >= max_action_value: # Found a better/equal move!
        if action_value != max_action_value:
          argmax = []
        max_action_value = action_value
        argmax.append(new_state)
      if f_eq(max_action_value, 1.0): # A certain win; stop searching.
        break

    return choice(argmax)

In [0]:
#@title move(x_player: ActionEvaluator, o_player: Optional[ActionEvaluator], state: BoardState, x_move: bool, beta: float=0.0, announce: bool=True) -> bool
def move(x_player: ActionEvaluator, o_player: Optional[ActionEvaluator], state: BoardState, x_move: bool, beta: float=0.0, announce: bool=True) -> bool:
  """
  Make one AI player's chosen move in place and update the learners.

  Args:
    x_player: Evaluator playing X.
    o_player: Evaluator playing O. May be None when O is not an AI (e.g. a
      human plays O); it is then skipped in the end-of-game updates.
    state: The current board; overwritten with the board after the move.
    x_move: True if X moves now, False if O does.
    beta: Probability of a random exploratory move (see best_move).
    announce: If True, print the result when the move ends the game.

  Returns:
    True if the move ended the game (win or draw), False otherwise.
  """

  mover, opponent = (x_player, o_player) if x_move else (o_player, x_player)

  state[:] = mover.best_move(state, beta)
  mover.back_up_value(state)

  # Scored from the mover's perspective: 1.0 means the mover just won.
  state_value = eval_state(state, x_move)
  if f_eq(1.0, state_value):
    result = "X wins!" if x_move else "O wins!"
  elif f_eq(0.75, state_value):
    result = "Draw!"
  else:
    return False

  if announce:
    print(result)
  # The opponent also learns from the final board.
  if opponent is not None:
    opponent.back_up_value(state)

  for player in (x_player, o_player):
    if player is not None:
      player.reset_board_state()

  return True

In [0]:
#@title move_o(x_player: ActionEvaluator, o_player: ActionEvaluator, state: BoardState, row: int, col: int) -> Optional[bool]
def move_o(x_player: ActionEvaluator, o_player: ActionEvaluator, state: BoardState, row: int, col: int) -> Optional[bool]:
  """
  Place a (human) O at (row, col) in place and update both learners.

  Row and col are zero-indexed.

  Returns:
    True if O's move won the game, False if play continues, or None if the
    move was rejected (square off the board or already taken); the board is
    left unchanged in that case.
  """

  if not (0 <= row < 3 and 0 <= col < 3):
    print("That's not valid -- that square is off the board!")
    return None

  index = row * 3 + col
  if state[index] != VALUE_BLANK:
    print("That's not valid -- someone has already moved there!")
    return None

  state[index] = VALUE_O

  # From X's perspective 0.0 means O has three in a row. No draw check: in
  # this notebook X always opens, so X (not O) makes the ninth move.
  game_over = f_eq(0.0, eval_state(state, x_player=True))

  if game_over:
    print("O wins!")
    x_player.back_up_value(state)
    o_player.back_up_value(state)
    x_player.reset_board_state()
    o_player.reset_board_state()
  else:
    o_player.back_up_value(state)

  return game_over

In [13]:
#@title Training by Self-Play

# Two independent learners, one per side, trained against each other.
x_player = ActionEvaluator(x_player=True)
o_player = ActionEvaluator(x_player=False)

verbose = False  # Print the board after every move.

games = 750 #@param {type: "number"}
start_time = time()
for game in range(games):
  board_state = [VALUE_BLANK] * 9
  if verbose:
    pprint(board_state)
  x_to_move = True  # X always opens.
  while True:
    if verbose:
      print()
    game_over = move(x_player, o_player, board_state, x_move=x_to_move, beta=0.1)
    if verbose:
      pprint(board_state)
    if game_over:
      break
    x_to_move = not x_to_move  # Hand the turn to the other side.
time_elapsed = time() - start_time

print(f"Played {games:,} games in {time_elapsed:.2f} seconds.")

X wins!
X wins!
X wins!
X wins!
O wins!
X wins!
X wins!
X wins!
X wins!
X wins!
O wins!
X wins!
O wins!
O wins!
O wins!
O wins!
X wins!
O wins!
O wins!
O wins!
X wins!
Draw!
X wins!
X wins!
X wins!
O wins!
X wins!
X wins!
X wins!
X wins!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
X wins!
O wins!
X wins!
Draw!
Draw!
X wins!
X wins!
X wins!
X wins!
O wins!
X wins!
X wins!
Draw!
Draw!
X wins!
X wins!
X wins!
Draw!
Draw!
Draw!
Draw!
X wins!
Draw!
Draw!
Draw!
X wins!
Draw!
X wins!
Draw!
Draw!
O wins!
Draw!
Draw!
Draw!
Draw!
X wins!
Draw!
Draw!
Draw!
X wins!
X wins!
X wins!
Draw!
X wins!
X wins!
Draw!
Draw!
X wins!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
Draw!
X wins!
Draw!
X wins!
Draw!
X wins!
Draw!
Draw!
X wins!
Draw!
X wins!
Draw!
Draw!
Draw!
Draw!
Draw!
X wins!
Draw!
Draw!
X wins!
X wins!
Draw!
Draw!
X wins!
Draw!
Draw!
X wins!
X wins!
X wins!
Draw!
Draw!
X wins!
X wins!
Draw!
X wins!
O wins!
X wins!
Draw!
X wins!
Draw!
Draw!
X wins!
X wins!
X wins!
Draw!


In [25]:
#@title Play!

#@markdown Game restarts automatically -- just re-run the cell once a game ends.

row =  "Middle"#@param ["Top", "Middle", "Bottom"]
column =  "Left"#@param ["Left", "Middle", "Right"]

row = {
    "Top": 0,
    "Middle": 1,
    "Bottom": 2,
}[row]

column = {
    "Left": 0,
    "Middle": 1,
    "Right": 2,
}[column]

# `turn` persists between runs of this cell; start at 1 on the first run.
try:
  turn
except NameError:
  turn = 1

print(f"Turn {turn}")

if turn == 1:
  # New game: the AI (X) opens, so the selected row/column are not used yet.
  board_state = [VALUE_BLANK] * 9
  move(x_player, None, board_state, x_move=True)
  turn += 1
else:
  o_result = move_o(x_player, o_player, board_state, row, column)
  if o_result is None:
    # move_o rejected the square; X must not move again. Pick another square
    # and re-run the cell.
    pass
  elif o_result:
    turn = 1  # O won; the next run starts a new game.
  elif move(x_player, o_player, board_state, x_move=True):
    turn = 1  # X's reply ended the game.
  else:
    turn += 1

pprint(board_state)

Turn 1

- - -
- X -
- - -


In [18]:
#@title Run and copy the result of this cell to save/share your model's learned knowledge.
# Only X's table is printed: in the Play cell the AI always plays X.
learned_values = x_player._action_values
print(learned_values)

{0: 0.6980851406860542, 1: 0.6323852944356076, 3: 0.634667270835799, 9: 0.6182453775545582, 27: 0.5, 81: 0.7337002782915856, 243: 0.5491795539855957, 729: 0.5952906238555392, 2187: 0.5772919654846191, 6561: 0.5, 13126: 0.59375, 13134: 0.5, 13152: 0.5, 13206: 0.625, 13368: 0.5, 13854: 0.5, 15312: 0.5, 17509: 1.0, 734: 0.5, 740: 0.625, 758: 0.5, 812: 0.8301233057765636, 974: 0.55877685546875, 2918: 0.625, 7292: 0.625, 5189: 0.5, 5195: 1.0, 34: 0.5, 42: 0.5, 114: 0.8440080881118774, 276: 0.5, 762: 0.4944371413439512, 2220: 0.5, 6594: 0.5, 781: 1.0, 5104: 0.5, 5161: 0.5, 5167: 0.5, 5239: 0.5, 5401: 0.5, 11719: 0.5, 5416: 0.5, 5488: 0.625, 11968: 0.5, 12067: 1.0, 406: 0.5, 408: 0.5, 414: 0.5, 432: 0.5, 1134: 0.5, 2592: 0.5, 6966: 0.71875, 2647: 0.5, 2649: 0.5, 2655: 0.5, 3375: 0.5, 9207: 0.5, 10668: 0.375, 10670: 0.0, 4456: 0.5, 4458: 0.5, 4464: 0.8367114067077637, 4482: 0.5, 4698: 0.5, 5184: 0.5, 11016: 0.5, 4519: 0.5, 4521: 0.5, 4761: 0.5, 5247: 1.0, 5106: 0.5, 5112: 0.5, 5130: 0.81151789