<a href="https://colab.research.google.com/github/hinsley/colabs/blob/master/TDt3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TDt3

Temporal difference learning model for tic-tac-toe

Inspired by Sutton & Barto

The model plays as X; the human opponent plays as O.

Board position encodings:

$\begin{bmatrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8\end{bmatrix} \to \begin{pmatrix} 0 & 1 & 2 \\ 3 & 4 & 5 \\ 6 & 7 & 8 \end{pmatrix}$

In [0]:
import numpy as np
from typing import Dict, List, Optional, Tuple

In [0]:
# Estimated value of a board for X, in [0.0, 1.0]. Higher values are better.
# Draws are scored like losses (0.0).
ActionValue = Optional[float]

# Contents of a single square; one of the VALUE_* constants below.
PositionState = int
VALUE_BLANK = 0
VALUE_X = 1
VALUE_O = 2

# Flat list of the nine squares in row-major order (index = row * 3 + col).
BoardState = List[PositionState]

# A whole board packed into one int: square i is the base-3 (ternary) digit
# with weight 3 ** i.
EncodedBoardState = int

In [0]:
#@title f_eq(a: float, b: float, epsilon: float=1e-4) -> bool
def f_eq(a: float, b: float, epsilon: float=1e-4) -> bool:
  """
  Approximate float equality: True when a and b differ by strictly less than
  epsilon, so that small rounding errors don't count as a difference.
  """

  difference = a - b
  return -epsilon < difference < epsilon

In [0]:
#@title eval_state(state: BoardState) -> ActionValue
def eval_state(state: BoardState) -> ActionValue:
  """
  Score a board from X's point of view using only the rules of the game.

  Returns 1.0 if X owns a complete line, 0.0 if O owns one, 0.5 if nobody has
  a line and at least one square is still blank, and 0.0 for a full board with
  no line (draws are considered losses).

  Lines are checked rows first, then columns, then diagonals, and the first
  completed one decides the result. That is enough because a legal game of
  Tic-Tac-Toe never has a completed line for both players.
  """

  lines = (
    (0, 1, 2), (3, 4, 5), (6, 7, 8), # Rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8), # Columns
    (0, 4, 8), (2, 4, 6),            # Diagonals
  )

  for first, second, third in lines:
    owner = state[first]
    if not (owner == state[second] == state[third]):
      continue
    if owner == VALUE_X:
      return 1.0
    if owner == VALUE_O:
      return 0.0
    # Three matching blanks are not a win; keep looking.

  for position in state:
    if position == VALUE_BLANK:
      return 0.5 # Board contains an empty space and nobody has won yet

  return 0.0 # Full board without a line: draw

In [0]:
#@title encode_board(state: BoardState) -> EncodedBoardState
def encode_board(state: BoardState) -> EncodedBoardState:
  """
  Pack a board into one hashable integer by reading its squares as base-3
  digits. Square 0 (top left) is the least significant digit.
  """

  encoded = 0
  weight = 1 # 3 ** (index of the current square)
  for value in state:
    encoded += value * weight
    weight *= 3

  return encoded

In [0]:
#@title decode_board(encoded_state: EncodedBoardState) -> BoardState
def decode_board(encoded_state: EncodedBoardState) -> BoardState:
  """
  Unpack an encoded board into a list of its nine squares by peeling off
  base-3 digits, least significant (top left square) first.
  """

  squares = []
  remaining = encoded_state
  for _ in range(9):
    remaining, digit = divmod(remaining, 3)
    squares.append(digit)

  return squares

In [0]:
#@title pprint(state: Optional[BoardState])
def pprint(state: Optional[BoardState]):
  """
  Pretty prints a board state as an empty line followed by three rows of
  "X", "O" and "-" (blank) symbols.

  state may be None, in which case nothing is printed; this lets callers pass
  along the result of a move that was rejected. Raises KeyError if a square
  holds something other than one of the VALUE_* constants.
  """

  if state is None:
    return

  symbols = {
    VALUE_BLANK: "-",
    VALUE_X: "X",
    VALUE_O: "O",
  }

  print()
  for row_start in (0, 3, 6):
    print(*[symbols[square] for square in state[row_start:row_start + 3]])

In [0]:
#@title class ActionEvaluator()

from random import randint

class ActionEvaluator():
  """
  Tabular value estimator and greedy move chooser for X.

  Every board the evaluator has looked at gets an entry in a table keyed by
  its encoded form. New entries start from eval_state(); back_up_value() then
  nudges the entry of the previously recorded board toward the value of the
  board that followed it.
  """

  # Learned value for each encoded board seen so far.
  _action_values: Dict[EncodedBoardState, ActionValue]
  # Encoded board recorded by the latest back_up_value() call (or the blank
  # board after reset_board_state()); the next backup adjusts this entry.
  _prev_state: EncodedBoardState


  def __init__(self):
    """ Start with a table that only knows the blank board, valued 0.5. """

    self._action_values = {0: 0.5}
    self.reset_board_state()
  

  def reset_board_state(self):
    """ Record the blank board as the previous board, as for a new game. """

    self._prev_state = encode_board([VALUE_BLANK] * 9)


  def evaluate(self, state: BoardState) -> ActionValue:
    """
    Look up the learned value of a board. A board seen for the first time is
    added to the table with the value eval_state() gives it.
    """

    encoded_state = encode_board(state)

    if encoded_state not in self._action_values: # Unexplored action.
      self._action_values[encoded_state] = eval_state(state)
    
    return self._action_values[encoded_state]


  def back_up_value(self,
                    new_state: BoardState,
                    learning_rate: float=0.5):
    """
    Move the value of the previously recorded board toward the value of
    new_state by a fraction learning_rate of their difference, then record
    new_state as the previous board for the next call.
    """

    # Read the old estimate through evaluate() instead of indexing the table
    # directly: the table can be replaced from outside (the load cell assigns
    # _action_values), in which case the previous board may have no entry and
    # direct indexing would raise KeyError.
    prev_value = self.evaluate(decode_board(self._prev_state))
    new_value = self.evaluate(new_state)

    self._action_values[self._prev_state] = (
      prev_value + learning_rate * (new_value - prev_value)
    )
    self._prev_state = encode_board(new_state)


  def best_move(self, state: BoardState) -> BoardState:
    """
    Return the board that results from X's highest-valued move, breaking ties
    at random. state itself is not modified, but every candidate board is
    passed through evaluate() and so ends up in the value table.

    Raises ValueError if state has no blank square to move into.
    """

    # Start below any possible value so that a candidate is always kept, even
    # if every stored value happens to be negative.
    max_action_value, argmax = float("-inf"), []
    for i, position_value in enumerate(state):
      if position_value != VALUE_BLANK:
        continue

      state_after_move = state.copy()
      state_after_move[i] = VALUE_X

      action_value = self.evaluate(state_after_move)
      if action_value >= max_action_value: # Found a better/equal move!
        if action_value != max_action_value:
          argmax = []
        max_action_value = action_value
        argmax.append(state_after_move)
      if f_eq(max_action_value, 1.0): # No point in looking for a better move.
        break

    if not argmax:
      raise ValueError("best_move() needs a board with at least one blank square.")

    return argmax[randint(0, len(argmax)-1)]

In [0]:
#@title move_o(action_evaluator: ActionEvaluator, state: BoardState, row: int, col: int) -> Optional[BoardState]
def move_o(action_evaluator: ActionEvaluator, state: BoardState, row: int, col: int) -> Optional[BoardState]:
  """
  Makes a move for O in place. Row and col are zero-indexed.

  Returns the updated state, or None (after printing why) when the move is
  rejected because the square is off the board or already taken; a rejected
  move leaves state untouched. If the move leaves the board with a value of
  0.0, "O wins!" is printed and the result is backed up into action_evaluator.
  """

  # Without this check a column of 3 would silently land on the next row and
  # negative indices would wrap around to the far side of the board.
  if not (0 <= row < 3 and 0 <= col < 3):
    print("That's not valid -- that square is not on the board!")
    return None

  position = row * 3 + col
  if state[position] != VALUE_BLANK:
    print("That's not valid -- someone has already moved there!")
    return None

  state[position] = VALUE_O

  if eval_state(state) == 0.0:
    print("O wins!")
    action_evaluator.back_up_value(state)

  return state

In [0]:
#@title move_x(action_evaluator: ActionEvaluator, state: BoardState) -> BoardState
def move_x(action_evaluator: ActionEvaluator, state: BoardState) -> BoardState:
  """
  Makes the best known move for X in place and returns state.

  The new board is backed up into action_evaluator, and "X wins!" or "Draw!"
  is printed if eval_state() scores the resulting board 1.0 or 0.0.
  """

  # Slice-assign so the caller's list object is updated, not just rebound.
  state[:] = action_evaluator.best_move(state)
  action_evaluator.back_up_value(state)

  state_value = eval_state(state)
  if state_value == 1.0:
    print("X wins!")
  elif state_value == 0.0:
    print("Draw!")

  return state

In [548]:
#@title Delete knowledge of the game
# Drop the game counter and the evaluator (with its learned value table).
# Only NameError is handled: it is what `del` raises when the name was never
# defined, and anything else (e.g. an interrupt) should not be reported as
# "doesn't exist".
try:
  del initialized
  del action_evaluator
  print("Knowledge deleted successfully.")
except NameError:
  print("Cannot delete knowledge -- it doesn't exist!")

Knowledge deleted successfully.


In [772]:
#@title Load knowledge of the game

knowledge = "{0: 0.5, 1: 0.5, 3: 0.5, 9: 0.5, 27: 0.5, 81: 0.5, 243: 0.5, 729: 0.5, 2187: 0.5, 6561: 0.5, 406: 0.5, 408: 0.5, 414: 0.5, 432: 0.5, 1134: 0.5, 2592: 0.5, 6966: 0.5, 1891: 0.5, 1893: 0.5, 1899: 0.5, 4077: 0.0, 8451: 0.25, 8469: 0.0, 172: 0.5, 174: 0.5, 198: 0.5, 900: 0.5, 2358: 0.5, 6732: 0.5, 3817: 0.5, 3819: 0.5, 3843: 0.5, 4059: 0.5, 10377: 0.5, 4330: 0.5, 4332: 0.5, 10890: 0.25, 10897: 0.0, 166: 0.5, 190: 0.5, 892: 0.5, 2350: 0.5, 6724: 0.5, 2365: 0.5, 2383: 0.5, 2599: 0.5, 3085: 0.5, 8917: 0.5, 2644: 0.5, 3346: 1.0, 9178: 0.25, 10636: 0.0, 192: 0.5, 894: 0.5, 2352: 0.5, 6726: 0.5, 211: 0.5, 427: 0.5, 913: 0.5, 2371: 0.0, 6745: 0.5, 1210: 0.5, 3154: 0.0, 7528: 0.25, 8014: 0.0, 205: 0.0, 421: 0.0, 907: 0.0, 6739: 0.0, 2662: 0.25, 3148: 0.5, 8980: 0.5, 16513: 0.0, 2376: 0.5, 3078: 0.5, 8910: 0.5, 16201: 0.5, 16203: 0.25, 16209: 0.0, 16227: 0.0, 16443: 0.5, 16205: 0.0, 13204: 0.5, 13206: 0.5, 13212: 0.5, 13230: 0.5, 13446: 0.5, 13932: 0.5, 15390: 0.5, 15397: 0.5, 15405: 0.5, 15423: 0.5, 15639: 0.5, 16125: 0.5, 15460: 0.5, 15694: 0.25, 16180: 0.5, 16441: 0.0, 14017: 0.5, 14025: 0.25, 14043: 0.5, 14259: 0.5, 14027: 0.0, 6750: 0.5, 7452: 0.5, 10369: 0.0, 10371: 0.25, 10395: 0.0, 10611: 0.0, 10389: 0.0, 7219: 0.5, 7221: 0.5, 7245: 0.5, 7947: 0.0, 9405: 0.25, 9459: 0.0, 100: 0.5, 102: 0.5, 126: 0.5, 342: 0.5, 828: 0.5, 2286: 0.5, 6660: 0.5, 13225: 0.25, 13249: 0.5, 13465: 0.5, 13951: 0.5, 15409: 0.5, 13711: 0.0, 11125: 0.0, 11127: 0.5, 11133: 0.5, 11367: 0.25, 11853: 0.0, 11373: 0.0, 4573: 0.5, 4575: 0.5, 4815: 0.5, 5301: 0.5, 4820: 0.25, 5306: 0.5, 11138: 0.5, 17942: 0.0, 11131: 0.0, 439: 0.5, 1141: 0.5, 6973: 0.5, 4795: 0.0, 10387: 0.0, 9961: 1.0, 10629: 0.0, 918: 0.5, 16229: 0.0, 13531: 0.5, 13539: 0.0, 13557: 0.5, 15717: 0.5, 13541: 0.0, 88: 0.5, 96: 0.5, 114: 0.5, 330: 0.5, 816: 0.5, 2274: 0.5, 6648: 0.5, 835: 0.5, 861: 0.0, 1077: 0.5, 3021: 0.5, 7395: 0.5, 863: 0.0, 11113: 0.0, 2839: 0.0, 2847: 0.0, 2865: 0.5, 3567: 0.5, 9399: 0.5, 2901: 0.0, 1555: 
0.5, 1581: 0.5, 1797: 1.0, 3741: 0.5, 8115: 0.5, 1852: 0.5, 4038: 0.5, 8412: 1.0, 3811: 0.5, 3837: 0.5, 4053: 0.5, 3848: 0.0, 4064: 0.5, 10382: 0.5, 16970: 0.0, 1867: 0.5, 1873: 1.0, 4051: 0.5, 8425: 0.5, 14998: 1.0, 15310: 0.5, 15312: 0.5, 15318: 0.5, 15336: 0.5, 15552: 0.5, 16038: 0.5, 16057: 0.5, 16059: 0.0, 16083: 0.5, 16137: 0.5, 16299: 0.5, 16545: 0.0, 13366: 0.5, 13368: 0.5, 13374: 0.5, 13392: 0.5, 14094: 0.5, 15715: 0.5, 15723: 0.0, 15741: 0.5, 15725: 0.0, 262: 0.5, 264: 0.5, 288: 0.5, 990: 0.5, 2448: 0.5, 6822: 0.5, 397: 0.5, 399: 0.5, 1125: 0.5, 2583: 0.5, 6957: 0.5, 6962: 0.5, 7688: 0.0, 9146: 0.5, 7694: 0.0, 10936: 0.5, 10938: 0.5, 10944: 0.5, 10962: 0.5, 11016: 0.5, 11178: 0.5, 11664: 0.5, 11101: 0.5, 11107: 0.0, 11341: 0.5, 11827: 0.0, 11829: 0.5, 11835: 0.5, 12069: 0.5, 11833: 0.0, 16: 0.5, 42: 0.5, 258: 0.5, 744: 0.5, 2202: 0.5, 6576: 0.5, 7063: 0.5, 7089: 0.5, 7143: 0.5, 7791: 1.0, 9249: 0.5, 12166: 0.5, 12192: 0.5, 12246: 1.0, 32: 0.5, 38: 0.5, 110: 0.5, 272: 0.5, 758: 0.5, 2216: 0.5, 6590: 0.5, 13163: 0.5, 13235: 0.5, 13397: 0.5, 13883: 0.5, 15341: 0.0, 15503: 0.0, 3646: 0.5, 3648: 0.5, 3654: 0.5, 3672: 0.5, 3726: 0.5, 3888: 0.5, 10206: 0.5, 3835: 0.5, 4095: 0.0, 1486: 0.5, 1488: 0.5, 1494: 0.5, 1566: 0.5, 1728: 0.5, 8046: 0.5, 14617: 0.5, 14619: 0.0, 14697: 0.5, 14859: 0.5, 16803: 0.5, 18993: 0.0, 8001: 0.0, 3733: 0.5, 3759: 0.5, 3975: 0.5, 10293: 0.5, 4246: 0.0, 4254: 0.5, 10806: 0.5, 17377: 0.0, 3640: 0.0, 15475: 0.5, 15483: 0.0, 15501: 0.5, 15485: 0.0, 22: 0.5, 48: 0.5, 750: 0.5, 2208: 0.5, 6582: 0.5, 13873: 0.0, 13899: 0.0, 13953: 0.5, 14115: 0.5, 14278: 0.0, 14304: 0.5, 16464: 0.5, 16519: 0.0, 16258: 0.0, 16264: 0.5, 16498: 0.5, 16744: 0.0, 14023: 0.5, 14049: 0.0, 14265: 0.5, 16211: 0.0, 451: 0.5, 1153: 0.5, 2611: 0.5, 6985: 0.0, 8443: 0.0, 3850: 0.5, 4066: 0.0, 10384: 0.5, 17215: 0.0, 1139: 0.5, 1145: 0.5, 1163: 0.0, 3323: 0.5, 7697: 0.5, 14285: 0.0, 13171: 0.5, 13251: 0.5, 13413: 0.5, 15357: 0.5, 14385: 0.0, 2397: 0.5, 2613: 0.5, 3099: 
0.5, 8931: 0.5, 3829: 0.0, 1130: 0.5, 3314: 0.0, 3320: 0.0, 949: 0.5, 955: 0.0, 1189: 0.5, 3133: 0.5, 7507: 0.5, 1441: 0.0, 5281: 0.0, 8185: 0.0, 8193: 0.5, 8211: 0.5, 8427: 0.5, 8203: 0.0, 14359: 0.0, 4579: 0.0, 1237: 0.5, 1263: 0.5, 1317: 0.0, 3423: 0.5, 7797: 0.5, 14439: 0.0, 2845: 0.5, 2863: 0.5, 3565: 0.5, 9397: 0.5, 2893: 0.0, 10413: 0.0, 203: 0.5, 419: 0.5, 905: 0.0, 2363: 0.5, 6737: 0.5, 3893: 0.5, 3899: 0.5, 3917: 0.5, 3971: 0.5, 10451: 0.0, 10505: 0.0, 14611: 0.0, 14689: 0.5, 14851: 0.5, 16795: 0.5, 16960: 0.5, 16966: 0.0, 17200: 0.5, 14051: 0.0, 11859: 0.0, 14691: 0.5, 14853: 0.5, 16797: 0.5, 18985: 0.0, 4456: 0.5, 4458: 0.5, 4464: 0.5, 4482: 0.5, 4698: 0.5, 5184: 0.5, 4969: 0.5, 4971: 0.5, 4977: 0.5, 5697: 0.5, 11529: 0.5, 18094: 0.0, 18102: 0.5, 18822: 0.5, 18112: 0.0}" #@param {type: "string"}

import ast

# Make sure there is an evaluator to load into. Checking the evaluator itself
# (rather than the game counter) also covers the case where only the
# evaluator has been deleted.
try:
  action_evaluator
except NameError:
  action_evaluator = ActionEvaluator()

initialized = 0

# NOTE: knowledge is text pasted in by the user. It used to be passed to
# eval(), which would run any code hidden in the pasted text; literal_eval
# only accepts plain Python literals such as the dict printed by the save
# cell. Anything that is not a dict literal is rejected instead of being
# installed as the value table.
try:
  loaded_values = ast.literal_eval(knowledge)
  if not isinstance(loaded_values, dict):
    raise ValueError("knowledge must be a dict literal")
  action_evaluator._action_values = loaded_values
  print("Knowledge loaded successfully.")
except (ValueError, SyntaxError, TypeError):
  print("Cannot load knowledge -- invalid.")

Knowledge loaded successfully.


In [768]:
#@title Turn 1: New game

# Bump the game counter and reset the evaluator's previous-board marker.
# A NameError here means the counter or the evaluator hasn't been created yet
# (fresh runtime, or the delete cell was run), so build both from scratch.
try:
  initialized += 1
  action_evaluator.reset_board_state()
except NameError:
  print("Initializing ActionEvaluator")
  initialized = 1
  action_evaluator = ActionEvaluator()

board_state = [VALUE_BLANK] * 9

print(f"Game {initialized} - Turn 1")
pprint(move_x(action_evaluator, board_state))

Game 53 - Turn 1

- X -
- - -
- - -


In [764]:
#@title Turns 2-3

row =  "Bottom"#@param ["Top", "Middle", "Bottom"]
column =  "Middle"#@param ["Left", "Middle", "Right"]

# Translate the form's labels into zero-based board indices.
row = {"Top": 0, "Middle": 1, "Bottom": 2}[row]
column = {"Left": 0, "Middle": 1, "Right": 2}[column]

# move_o returns None when it rejects the move; nothing is drawn then.
# Otherwise show the board as it stands if O's move brought its value to 0.0,
# or let X reply and show the board after X's move.
if move_o(action_evaluator, board_state, row, column) is not None:
  pprint(
    board_state
    if f_eq(0.0, eval_state(board_state))
    else move_x(action_evaluator, board_state)
  )


- - -
X X -
- O -


In [765]:
#@title Turns 4-5

row =  "Middle"#@param ["Top", "Middle", "Bottom"]
column =  "Right"#@param ["Left", "Middle", "Right"]

# Translate the form's labels into zero-based board indices.
row = {"Top": 0, "Middle": 1, "Bottom": 2}[row]
column = {"Left": 0, "Middle": 1, "Right": 2}[column]

# move_o returns None when it rejects the move; nothing is drawn then.
# Otherwise show the board as it stands if O's move brought its value to 0.0,
# or let X reply and show the board after X's move.
if move_o(action_evaluator, board_state, row, column) is not None:
  pprint(
    board_state
    if f_eq(0.0, eval_state(board_state))
    else move_x(action_evaluator, board_state)
  )


- X -
X X O
- O -


In [766]:
#@title Turns 6-7

row =  "Bottom"#@param ["Top", "Middle", "Bottom"]
column =  "Right"#@param ["Left", "Middle", "Right"]

# Translate the form's labels into zero-based board indices.
row = {"Top": 0, "Middle": 1, "Bottom": 2}[row]
column = {"Left": 0, "Middle": 1, "Right": 2}[column]

# move_o returns None when it rejects the move; nothing is drawn then.
# Otherwise show the board as it stands if O's move brought its value to 0.0,
# or let X reply and show the board after X's move.
if move_o(action_evaluator, board_state, row, column) is not None:
  pprint(
    board_state
    if f_eq(0.0, eval_state(board_state))
    else move_x(action_evaluator, board_state)
  )


X X -
X X O
- O O


In [767]:
#@title Turns 8-9: Showdown!

row =  "Top"#@param ["Top", "Middle", "Bottom"]
column =  "Right"#@param ["Left", "Middle", "Right"]

# Translate the form's labels into zero-based board indices.
row = {"Top": 0, "Middle": 1, "Bottom": 2}[row]
column = {"Left": 0, "Middle": 1, "Right": 2}[column]

# move_o returns None when it rejects the move; nothing is drawn then.
# Otherwise show the board as it stands if O's move brought its value to 0.0,
# or let X reply and show the board after X's move.
if move_o(action_evaluator, board_state, row, column) is not None:
  pprint(
    board_state
    if f_eq(0.0, eval_state(board_state))
    else move_x(action_evaluator, board_state)
  )

O wins!

X X O
X X O
- O O


In [771]:
#@title Run and copy the result of this cell to save/share your model's learned knowledge.
# Prints the evaluator's whole value table as a dict literal, which is the
# text the "Load knowledge" cell expects in its knowledge field.
print(action_evaluator._action_values)

{0: 0.5, 1: 0.5, 3: 0.5, 9: 0.5, 27: 0.5, 81: 0.5, 243: 0.5, 729: 0.5, 2187: 0.5, 6561: 0.5, 406: 0.5, 408: 0.5, 414: 0.5, 432: 0.5, 1134: 0.5, 2592: 0.5, 6966: 0.5, 1891: 0.5, 1893: 0.5, 1899: 0.5, 4077: 0.0, 8451: 0.25, 8469: 0.0, 172: 0.5, 174: 0.5, 198: 0.5, 900: 0.5, 2358: 0.5, 6732: 0.5, 3817: 0.5, 3819: 0.5, 3843: 0.5, 4059: 0.5, 10377: 0.5, 4330: 0.5, 4332: 0.5, 10890: 0.25, 10897: 0.0, 166: 0.5, 190: 0.5, 892: 0.5, 2350: 0.5, 6724: 0.5, 2365: 0.5, 2383: 0.5, 2599: 0.5, 3085: 0.5, 8917: 0.5, 2644: 0.5, 3346: 1.0, 9178: 0.25, 10636: 0.0, 192: 0.5, 894: 0.5, 2352: 0.5, 6726: 0.5, 211: 0.5, 427: 0.5, 913: 0.5, 2371: 0.0, 6745: 0.5, 1210: 0.5, 3154: 0.0, 7528: 0.25, 8014: 0.0, 205: 0.0, 421: 0.0, 907: 0.0, 6739: 0.0, 2662: 0.25, 3148: 0.5, 8980: 0.5, 16513: 0.0, 2376: 0.5, 3078: 0.5, 8910: 0.5, 16201: 0.5, 16203: 0.25, 16209: 0.0, 16227: 0.0, 16443: 0.5, 16205: 0.0, 13204: 0.5, 13206: 0.5, 13212: 0.5, 13230: 0.5, 13446: 0.5, 13932: 0.5, 15390: 0.5, 15397: 0.5, 15405: 0.5, 15423: 0.