<a href="https://colab.research.google.com/github/hinsley/RL-depot/blob/master/temporal-difference/TDc4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TDc4

Temporal-difference learning for Connect 4

Uses the logistic NTD2(0) search algorithm described in [this paper](https://link.springer.com/article/10.1007/s10994-012-5280-0).

In [0]:
import collections
import math
import multiprocessing as mp
import time
import torch
from random import choice, choices
from typing import Counter, List, Tuple, Union

In [0]:
# Values stored in each cell of the board tensor.
VALUE_BLANK = 0
VALUE_X = 1
VALUE_O = 2

# Shape of the board tensor.
BOARD_SIZE = (6, 7) # Rows x Cols

In [0]:
# Multiset of feature keys, as produced by active_features():
#   location independent: (feature_size - 1, -1, encoding, sign)
#   location dependent:   (feature_size - 1, row, col, encoding, sign)
# `sign` is -1 when the encoding belongs to the X/O-swapped feature, else +1.
ActiveFeatures = Counter[Union[Tuple[int, int, int, int],
                               Tuple[int, int, int, int, int]]]

In [0]:
#@title Hyperparameter Configuration

# Default value of the `learning_rate` parameter of td_update().
LEARNING_RATE =  1.0#@param {type: "number"}

### Quantity of Location-Independent Features

For a given feature size $n \times n$, the number of distinct features, when each feature is identified with its horizontal mirror image, is given by the following function $f_{LI}$, where $s$ is the number of possible states per board position:

$f_{LI}(n) = \frac{s^{n^2} + s^n}{2} - 1$.

We subtract $1$ from $f_{LI}$ because, for each feature size, the single all-blank feature is "neutral" and is not counted.

In *Connect 4*, there are 3 possible position states. Combining each feature size $n \times n$ for $1 \leq n \leq 4$, we find there are a total of $21,533,300$ features.

### Quantity of Location-Dependent Features

For a given feature size $n \times n$, the number of distinct located features, when each is identified with its image under horizontally mirroring the entire board, is given by the following function $f_{LD}$, where $s$ is the number of possible states per board position:

$f_{LD}(n) = s^{n^2} \cdot (7-n) \cdot \left(4-\left\lfloor\frac{n}{2}\right\rfloor\right) - 1$.

Again we subtract $1$ from $f_{LD}$ here due to an all-blank feature being "neutral".

Combining each feature size $n \times n$ for $1 \leq n \leq 4$, we find there are a total of $258,517,805$ features.

In [0]:
# Multiprocessing manager; the weights cell uses it to create manager.dict()
# tables.
manager = mp.Manager()

In [0]:
MAX_FEATURE_SIZE = 6

# Weight tables. `feature_size` below is the 0-based index
# (0 ... MAX_FEATURE_SIZE - 1), i.e. the feature's side length minus one.
#
#   Location dependent:   weights[feature_size][row][col][encoding] -> weight
#   Location independent: weights[feature_size][-1][encoding]       -> weight
#
# Only the left half of the columns is stored.
#
# The location-dependent tables are plain dicts; only the location-independent
# tables are manager dicts.
# NOTE(review): plain dicts are presumably private to each worker process when
# training runs through mp.Pool -- confirm that this is intended.
weights = [
  {
    row: {
      col: dict()
      for col in range(math.ceil(BOARD_SIZE[1] / 2) - feature_size // 2)
    }
    for row in range(BOARD_SIZE[0] - feature_size + 1)
  }
  for feature_size in range(MAX_FEATURE_SIZE)
]
# Location independent.
for feature_size in range(MAX_FEATURE_SIZE):
  weights[feature_size][-1] = manager.dict()

In [0]:
def empty_board() -> torch.ByteTensor:
  """ Returns a new BOARD_SIZE uint8 tensor with every cell VALUE_BLANK. """
  return torch.full(BOARD_SIZE, VALUE_BLANK, dtype=torch.uint8)

In [0]:
def game_over(state: torch.ByteTensor,
              last_move_pos: Tuple[int, int]) -> bool:
  """
  Reports whether the piece at `last_move_pos` completes four in a row.

  Only lines passing through `last_move_pos` are examined (horizontal,
  vertical, and both diagonals).
  This does NOT check for draws.

  Args:
    state: 2-D board tensor indexed as [row, col].
    last_move_pos: (row, col) of the most recently placed piece.

  Returns:
    True if the cell at `last_move_pos` lies in a contiguous run of at least
    four cells holding the same value along any of the four directions,
    otherwise False.
  """
  n_rows, n_cols = state.size()
  lm_row, lm_col = last_move_pos
  last_move = state[lm_row, lm_col].item()

  # (row step, col step): horizontal, vertical, backslash, forward slash.
  directions = ((0, 1), (1, 0), (1, 1), (1, -1))

  for row_step, col_step in directions:
    run_length = 1 # The last move itself.
    # Walk outwards from the last move in both senses of the direction and
    # count matching cells until the run is interrupted or the board ends.
    for sense in (1, -1):
      row = lm_row + sense * row_step
      col = lm_col + sense * col_step
      while (0 <= row < n_rows and 0 <= col < n_cols
             and state[row, col].item() == last_move):
        run_length += 1
        row += sense * row_step
        col += sense * col_step
    if run_length >= 4:
      return True

  return False

In [0]:
def board_full(state: torch.ByteTensor) -> bool:
  """
  Reports whether no cell of `state` is VALUE_BLANK.

  The result is converted to a plain Python bool so it matches the annotated
  return type instead of being a zero-dimensional tensor.
  """
  return bool((state != VALUE_BLANK).all())

In [0]:
def pprint(state: torch.ByteTensor):
  """
  Prints a board: a top border, one numbered line per row, a bottom border and
  the column letters A-G.
  """
  # Cell value -> printed character.
  glyphs = {
      VALUE_BLANK: " ",
      VALUE_X: "X",
      VALUE_O: "O",
  }

  print(" __ _ _ _ _ _ __")
  for index, cells in enumerate(state):
    rendered = " ".join([glyphs[cell.item()] for cell in cells])
    print(f"{index+1}|{rendered}|")
  print(" -- - - - - - --")
  print("  A B C D E F G ")

In [13]:
"""
 __ _ _ _ _ _ __
1|  X X        |
2|  O O        |
3|  O X   X O X|
4|  O X   O X O|
5|O X O   X O O|
6|X O X   X O X|
 -- - - - - - --
  A B C D E F G 
Game 7 of 10,000 (0.0700%): O wins
"""

'\n __ _ _ _ _ _ __\n1|  X X        |\n2|  O O        |\n3|  O X   X O X|\n4|  O X   O X O|\n5|O X O   X O O|\n6|X O X   X O X|\n -- - - - - - --\n  A B C D E F G \nGame 7 of 10,000 (0.0700%): O wins\n'

In [14]:
# Build a sample position by hand. Each assignment indexes the board with a
# tuple of row indices and a matching tuple of column indices.
board_state = empty_board()
board_state[(0,0,2,2,2,3,3,4,4,5,5,5,5), (1,2,2,4,6,2,5,1,4,0,2,4,6)] = VALUE_X
board_state[(1,1,2,2,3,3,3,4,4,4,4,5,5), (1,2,1,5,1,4,6,0,2,5,6,1,5)] = VALUE_O
pprint(board_state)

 __ _ _ _ _ _ __
1|  X X        |
2|  O O        |
3|  O X   X O X|
4|  O X   O X O|
5|O X O   X O O|
6|X O X   X O X|
 -- - - - - - --
  A B C D E F G 


In [15]:
game_over(board_state, (4, 0))

False

In [0]:
def encode_feature(feature: torch.ByteTensor, check_inverted: bool=True) -> Tuple[int, bool]:
  """
  Encodes a feature as an integer and reports whether the X/O-swapped form of
  the feature was the one encoded.

  The code is the sum of cell_value * 3**k over the cells (k being the cell's
  row-major index), minus one. When `check_inverted` is true, the feature with
  VALUE_X and VALUE_O exchanged is encoded as well and the smaller of the two
  codes is returned; on a tie the swapped code is reported.
  """
  def base3_code(cells: torch.ByteTensor) -> int:
    place_values = 3 ** torch.arange(cells.numel())
    weighted = cells * place_values.view(cells.size())
    return weighted.sum().item() - 1

  plain_code = base3_code(feature)

  if not check_inverted:
    # Sentinel larger than any real code, so the plain code always wins.
    swapped_code = 1e21
  else:
    swapped = feature.clone()
    swapped[feature == VALUE_X] = VALUE_O
    swapped[feature == VALUE_O] = VALUE_X
    swapped_code = base3_code(swapped)

  if plain_code < swapped_code:
    return plain_code, False
  return swapped_code, True

In [0]:
def encode_feature_li(feature: torch.ByteTensor) -> Tuple[int, bool, bool]:
  """
  Encodes a feature up to horizontal flipping and X/O swapping.

  Four variants of the feature are encoded (as is, flipped, swapped, flipped
  and swapped) and the smallest code is returned together with two booleans
  telling whether the chosen variant was flipped and/or swapped respectively.
  On ties the earliest variant in that order is chosen.
  """
  def swap_players(cells: torch.ByteTensor) -> torch.ByteTensor:
    swapped = cells.clone()
    swapped[cells == VALUE_X] = VALUE_O
    swapped[cells == VALUE_O] = VALUE_X
    return swapped

  mirrored = feature.flip(1)

  # The position in this tuple determines the two flags returned below.
  variants = (
      feature,
      mirrored,
      swap_players(feature),
      swap_players(mirrored),
  )
  codes = [encode_feature(variant, False)[0] for variant in variants]

  # min() with a key returns the first minimal index.
  best = min(range(len(codes)), key=codes.__getitem__)

  return codes[best], best in (1, 3), best in (2, 3)

In [0]:
def active_features(state: torch.ByteTensor) -> ActiveFeatures:
  """
  Counts the encoded features present in a board position.

  Every square window with side length 1 ... MAX_FEATURE_SIZE is visited.
  Windows starting in the left half of the board are read from `state`; the
  remaining windows are read from the horizontally flipped board, so their
  column index is measured from the right edge. Totally blank windows are
  ignored.

  Each non-blank window contributes two keys:
    (feature_size - 1, -1, encoding, sign)        location independent
    (feature_size - 1, row, col, encoding, sign)  location dependent
  where `sign` is -1 if the encoding is that of the X/O-swapped window and +1
  otherwise.

  Returns:
    A Counter mapping each key to the number of times it occurs.
  """
  features = collections.Counter()

  n_rows, n_cols = state.size()
  half_cols = math.ceil(n_cols / 2)
  flipped_state = state.flip(1) # Does not depend on the loops below.

  def record(feature: torch.ByteTensor, size_index: int, row: int, col: int):
    """ Adds the two keys of one non-blank window to `features`. """
    # Location independent feature.
    encoding, _, inverted = encode_feature_li(feature)
    features[(
        size_index,
        -1,
        encoding,
        1 - 2 * int(inverted),
    )] += 1
    # Location dependent feature.
    encoding, inverted = encode_feature(feature)
    features[(
        size_index,
        row,
        col,
        encoding,
        1 - 2 * int(inverted),
    )] += 1

  for feature_size in range(1, MAX_FEATURE_SIZE + 1):
    # Windows taken from the left (including any centred window), and the
    # remaining ones taken from the flipped board.
    cols_left = half_cols - feature_size // 2
    cols_right = (n_cols - feature_size + 1) - cols_left

    for row in range(n_rows - feature_size + 1):
      for board, window_count in ((state, cols_left),
                                  (flipped_state, cols_right)):
        for col in range(window_count):
          feature = board[row : row + feature_size, col : col + feature_size]
          # We do not care about totally blank features.
          if (feature != VALUE_BLANK).any():
            record(feature, feature_size - 1, row, col)

  return features

In [19]:
active_features(board_state)

Counter({(0, -1, 0, -1): 13,
         (0, -1, 0, 1): 13,
         (0, 0, 1, 0, 1): 1,
         (0, 0, 2, 0, 1): 1,
         (0, 1, 1, 0, -1): 1,
         (0, 1, 2, 0, -1): 1,
         (0, 2, 0, 0, 1): 1,
         (0, 2, 1, 0, -1): 2,
         (0, 2, 2, 0, 1): 2,
         (0, 3, 0, 0, -1): 1,
         (0, 3, 1, 0, -1): 1,
         (0, 3, 1, 0, 1): 1,
         (0, 3, 2, 0, -1): 1,
         (0, 3, 2, 0, 1): 1,
         (0, 4, 0, 0, -1): 2,
         (0, 4, 1, 0, -1): 1,
         (0, 4, 1, 0, 1): 1,
         (0, 4, 2, 0, -1): 1,
         (0, 4, 2, 0, 1): 1,
         (0, 5, 0, 0, 1): 2,
         (0, 5, 1, 0, -1): 2,
         (0, 5, 2, 0, 1): 2,
         (1, -1, 8, 1): 1,
         (1, -1, 9, -1): 2,
         (1, -1, 9, 1): 2,
         (1, -1, 10, -1): 4,
         (1, -1, 10, 1): 3,
         (1, -1, 40, -1): 1,
         (1, -1, 43, -1): 1,
         (1, -1, 44, 1): 2,
         (1, -1, 45, -1): 1,
         (1, -1, 48, -1): 2,
         (1, -1, 49, 1): 2,
         (1, -1, 51, 1): 6,
         (1, 0

In [0]:
def evaluate(state: torch.ByteTensor,
             last_move_pos: Tuple[int, int]) -> float:
  """
  Estimates the value of a position.

  Returns 0.5 when either argument is None. If the last move ended the game,
  returns 1.0 when the piece at `last_move_pos` is VALUE_X and 0.0 otherwise.
  In all other cases returns the logistic sigmoid of the weighted sum of the
  position's active features. Features that have no weight yet are given a
  weight of 0 in `weights` as a side effect.
  """
  if state is None or last_move_pos is None:
    return 0.5

  if game_over(state, last_move_pos):
    return float(state[last_move_pos] == VALUE_X)

  total = 0
  for key, count in active_features(state).items():
    # The last element of a key is its sign; the elements before it are the
    # successive indices into the nested weight tables (three for location
    # independent keys, four for location dependent ones).
    *path, sign = key
    try:
      weight = weights
      for index in path:
        weight = weight[index]
      total += count * weight * sign
    except KeyError:
      table = weights
      for index in path[:-1]:
        table = table[index]
      table[path[-1]] = 0

  try:
    return 1 / (1 + math.exp(-total)) # Sigmoid to squash to [0.0, 1.0].
  except OverflowError:
    return 1.0 if total > 0 else 0.0

In [21]:
evaluate(board_state, (5, 3))

0.0

In [0]:
def td_update(agent: "Agent",
              new_state: torch.ByteTensor,
              last_move_pos: Tuple[int, int],
              learning_rate: float = LEARNING_RATE) -> float:
  """
  Moves the weights of the agent's previous features towards the value of
  `new_state`.

  Each feature in `agent.prev_active_features` has its weight changed by
    learning_rate * count / signal_power * (afterstate value - previous value)
  multiplied by the feature's sign, where `signal_power` is the sum of the
  squared feature counts. Missing weights are created.

  Returns the value of `new_state` as computed by evaluate() before the
  weights are changed.
  """
  # This may need to be updated to utilize weight sharing for
  # inversion-equivalent features.
  signal_power = sum([n * n for n in agent.prev_active_features.values()])
  afterstate_value = evaluate(new_state, last_move_pos)
  td_error = afterstate_value - agent.prev_state_value

  for key, count in agent.prev_active_features.items():
    step = learning_rate * count / signal_power * td_error
    # The last element of a key is its sign; the elements before it index the
    # nested weight tables.
    *path, sign = key
    change = step * sign

    table = weights
    for index in path[:-1]:
      table = table[index]
    try:
      table[path[-1]] += change
    except KeyError:
      table[path[-1]] = change

  return afterstate_value

In [0]:
def drop_piece(state: torch.ByteTensor,
               column: int,
               x_player: bool) -> Tuple[torch.ByteTensor, Tuple[int, int]]:
  """
  Drops a piece into `column` and returns the resulting board.

  Args:
    state: Board tensor indexed as [row, col]; it is not modified.
    column: Index of the column to play in.
    x_player: True to place VALUE_X, False to place VALUE_O.

  Returns:
    A copy of `state` with the new piece placed in the lowest blank cell of
    the column, and the (row, column) position of that piece.

  Raises:
    ValueError: If the column has no blank cell. Without this check the row
      index would become -1 and the piece in the bottom row would be
      overwritten.
  """
  row_count = state.size()[0]
  drop_row = row_count - 1 # In case entire column is empty.

  # The piece lands directly above the topmost occupied cell.
  for row in range(row_count):
    if state[row, column] != VALUE_BLANK:
      drop_row = row - 1
      break

  if drop_row < 0:
    raise ValueError(f"Column {column} is full.")

  new_state = state.clone()
  new_state[drop_row, column] = VALUE_X if x_player else VALUE_O

  return new_state, (drop_row, column)

In [0]:
class Agent():
  """
  A self-play participant that remembers the features and the value estimate
  of its previous afterstate so that td_update() can adjust the weights.
  """

  x_player: bool # X player will try to maximize reward, O player will try to
                 # minimize reward.
  prev_move_rollouts: int
  prev_active_features: ActiveFeatures
  prev_state_value: float

  def __init__(self, x_player: bool):
    self.x_player = x_player
    self.prev_move_rollouts = 0
    self.reset_prev_state()

  def reset_prev_state(self):
    """ Resets the remembered features and value to those of an empty board. """
    blank_board = empty_board()
    self.prev_active_features = active_features(blank_board)
    self.prev_state_value = evaluate(blank_board, None)

  def best_move(self,
                state: torch.ByteTensor,
                epsilon: float=0.15) -> Tuple[torch.ByteTensor, Tuple[int, int]]:
    """
    Selects a move with an epsilon-greedy policy.

    With probability `epsilon` a random legal move is chosen and a TD update
    is performed for its afterstate only. Otherwise a TD update is performed
    for the afterstate of every legal move, and one of the moves with the
    highest (X player) or lowest (O player) value is chosen at random.

    Returns the afterstate of the selected move and the (row, column) position
    of the piece that was placed.
    """
    # Afterstates of every column whose top cell is still blank.
    candidates = []
    for column in range(state.size()[1]):
      if state[0, column] == VALUE_BLANK:
        candidates.append(drop_piece(state, column, self.x_player))

    explore = choices((True, False), weights=[epsilon, 1.0-epsilon])[0]

    if explore:
      new_state, last_move_pos = choice(candidates)
      self.prev_state_value = td_update(self, new_state, last_move_pos)
    else:
      values = [td_update(self, afterstate, move_pos) for
                afterstate, move_pos in
                candidates]

      if self.x_player:
        optimum_value = max(values + [0.0])
      else:
        optimum_value = min(values + [1.0])

      # Every candidate that attains the optimum; ties are broken randomly.
      optimum_states = [candidate for
                        candidate, value in
                        zip(candidates, values) if
                        value == optimum_value]

      self.prev_state_value = optimum_value
      new_state, last_move_pos = choice(optimum_states)

    self.prev_active_features = active_features(new_state)

    return new_state, last_move_pos

In [27]:
#@title Self-Play Training

# Number of self-play games to run.
games = 5#@param {type: "number"}
# Passed to Agent.best_move() as its `epsilon` argument on every move.
epsilon = 0.2#@param {type: "number"}
# When True, games run one after another and every board is printed; when
# False, games are distributed over an mp.Pool and only results are printed.
show_games = False#@param {type: "boolean"}

# The two agents used by train_game(): X moves first, O second.
x_player = Agent(True)
o_player = Agent(False)

def train_game(initial_state: torch.ByteTensor) -> float:
  """
  Plays one self-play game between the global `x_player` and `o_player`,
  starting from a copy of `initial_state` with X moving first.

  Both agents have their previous-state memory reset before the game. Moves
  are chosen with Agent.best_move() using the global `epsilon`, which also
  performs the TD updates. When the global `show_games` is true, the board is
  printed before the first move and after every move.

  Returns:
    1.0 if X wins, 0.0 if O wins, 0.5 if the board fills up without a winner.
  """
  train_state = initial_state.clone()

  x_player.reset_prev_state()
  o_player.reset_prev_state()

  if show_games:
    pprint(train_state)

  # Each agent paired with the result returned when that agent wins.
  turn_order = (
    (x_player, 1.0), # X wins.
    (o_player, 0.0), # O wins.
  )

  while True:
    for agent, win_result in turn_order:
      if show_games:
        print()
      train_state, last_move_pos = agent.best_move(
        train_state,
        epsilon=epsilon
      )
      if show_games:
        pprint(train_state)
      if game_over(train_state, last_move_pos):
        return win_result
      # Checked after every move so that the next agent is never asked to
      # move on a board without a free column.
      if board_full(train_state):
        return 0.5 # Draw.

start_time = time.time()

# Game result -> text shown in the progress line.
outcome_labels = {
  0.0: "O wins",
  0.5: "Draw",
  1.0: "X wins",
}

def report_results(results):
  """ Prints one progress line per result, in the order the results arrive. """
  for game, result in enumerate(results):
    result_statement = outcome_labels[result]
    print(f"Game {game+1:,} of {games:,} ({(game+1)/games:.4%}): {result_statement}")

if show_games:
  # Sequential: each game is played only when its result is requested, so the
  # boards of a game are printed before its progress line.
  report_results(train_game(empty_board()) for _ in range(games))
else:
  # NOTE(review): games run in worker processes here. Weight tables that are
  # plain dicts are presumably not shared with the workers or with this
  # process -- confirm that the learned weights survive this path.
  with mp.Pool() as p:
    report_results(p.imap_unordered(train_game, (empty_board() for _ in range(games))))

time_elapsed = time.time() - start_time
print(f"Played {games:,} games in {time_elapsed:,.2f} seconds.")

Game 1 of 5 (20.0000%): X wins
Game 2 of 5 (40.0000%): X wins
Game 3 of 5 (60.0000%): X wins
Game 4 of 5 (80.0000%): X wins
Game 5 of 5 (100.0000%): O wins
Played 5 games in 5.54 seconds.
