In [7]:
import numpy as np
from typing import Self
from functools import cache
from tqdm import tqdm
import json

In [8]:
class TicTacToeBoard:
    EMPTY_CELL, FIRST_PLAYER_CELL, SECOND_PLAYER_CELL = 0, 1, 2

    WIN_COMBOS = (
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),
        (0, 4, 8),
        (2, 4, 6),
    )

    WIN_ARRAY = np.array(WIN_COMBOS, dtype=int)  # shape (8,3)

    def __init__(self):
        # NOTE: using EMPTY in case that value changes
        #       for whatever reason, but obv if EMPTY=0
        #       then this is equivalent to np.zeros((9,))
        self._state = np.ones((9,)) * TicTacToeBoard.EMPTY_CELL

    @property
    def state(self) -> tuple[int, ...]:
        return tuple(self._state.tolist())

    # NOTE: __eq__ and __hash__ need to be implemented for
    #       cache functionality to work 
    def __eq__(self, other: Self):
        return (self._state == other._state).all()
    
    def __hash__(self):
        return hash(self.state)

    @cache
    def player_to_move(self) -> int:
        return TicTacToeBoard.FIRST_PLAYER_CELL if len(self.available_cell_indices()) % 2 == 1 else TicTacToeBoard.SECOND_PLAYER_CELL

    def reset(self) -> None:
        self._state *= TicTacToeBoard.EMPTY_CELL

    @cache
    def available_cell_indices(self) -> tuple[int, ...]:
        return tuple((self._state == TicTacToeBoard.EMPTY_CELL).nonzero()[0].tolist())
    
    @cache
    def terminated(self) -> bool:
        return self.tied() or self.first_player_won() or self.second_player_won()

    @cache
    def tied(self) -> bool:
        return (self._state != TicTacToeBoard.EMPTY_CELL).all()

    @cache
    def first_player_won(self) -> bool:
        return self.player_won(player=TicTacToeBoard.FIRST_PLAYER_CELL)

    @cache
    def second_player_won(self) -> bool:
        return self.player_won(player=TicTacToeBoard.SECOND_PLAYER_CELL)

    @cache
    def player_won(self, player: int) -> bool:
        cells = self._state[TicTacToeBoard.WIN_ARRAY]
        return bool(np.any(np.all(cells == player, axis=1)))

    @cache
    def transition(self, idx: int) -> Self:
        if self.terminated():
            raise RuntimeError("move attempted on completed game")

        if idx not in self.available_cell_indices():
            raise RuntimeError("illegal move")
        new_board = self.__class__()
        new_board._state = self._state.copy()
        new_board._state[idx] = self.player_to_move()
        return new_board
    
    def display(self):
        keys = {TicTacToeBoard.FIRST_PLAYER_CELL: "X", TicTacToeBoard.SECOND_PLAYER_CELL: "O", TicTacToeBoard.EMPTY_CELL: "_"}
        for i in range(0, 9, 3):
            print(f"{keys[self._state[i]]} {keys[self._state[i+1]]} {keys[self._state[i+2]]}")
        print()

In [9]:
AGENT_CELL, OPPONENT_CELL, EMPTY_CELL = "A", "O", "_"
# NOTE: ties are considered same as default value
WON_VALUE, LOST_VALUE, DEFAULT_VALUE = 1.0, 0.0, 0.5

@cache
def serialize_board(board: TicTacToeBoard, player: int) -> str:
    return "".join([
        EMPTY_CELL if c == TicTacToeBoard.EMPTY_CELL
        else (AGENT_CELL if c == player else OPPONENT_CELL)
        for c in board.state
    ])


def populate_value_if_non_existent(value_table: dict, board: TicTacToeBoard, player: int):
    key = serialize_board(board, player=player)
    if key not in value_table:
        value_table[key] = DEFAULT_VALUE
    if board.first_player_won():
        value_table[key] = (WON_VALUE if player == TicTacToeBoard.FIRST_PLAYER_CELL else LOST_VALUE)
    elif board.second_player_won():
        value_table[key] = (WON_VALUE if player == TicTacToeBoard.SECOND_PLAYER_CELL else LOST_VALUE)

def get_value(value_table: dict, board: TicTacToeBoard, player: int):
    populate_value_if_non_existent(value_table, board, player)
    return value_table[serialize_board(board, player)]

def sample_policy(value_table: dict, board: TicTacToeBoard, player: int, eps: float = 0.0) -> tuple[float, bool]:
    """Returns action sampled from policy and bool indicating whether or not a greedy move was selected"""
    populate_value_if_non_existent(value_table=value_table, board=board, player=player)
    possible_plays: list[tuple[int, float]] = []
    for idx in board.available_cell_indices():
        next_board = board.transition(idx)
        populate_value_if_non_existent(value_table=value_table, board=next_board, player=player)
        possible_plays.append((idx, get_value(value_table, next_board, player)))

    greedy_idx, _ = max(possible_plays, key=lambda t: t[1])
    if len(possible_plays) > 1 and np.random.rand() < eps:
        return np.random.choice([idx for idx, _ in possible_plays if idx != greedy_idx]), False

    return greedy_idx, True

def update_value(value_table: dict, last_board: TicTacToeBoard, current_board: TicTacToeBoard, player: int, learning_rate: float):
    key_last = serialize_board(last_board, player)
    # these two calls will auto‐create missing entries 
    v_last = get_value(value_table, last_board, player)
    v_curr = get_value(value_table, current_board, player)
    value_table[key_last] = v_last + learning_rate * (v_curr - v_last)


In [10]:
def save(value_table, path):
    # dump keys as strings for JSON
    export = {"".join(str(x) for x in k): v for k, v in value_table.items()}
    with open(path, "w") as f:
        json.dump(export, f)

def load(path):
    raw = json.load(open(path, "r"))
    return {tuple(int(c) for c in k): v for k, v in raw.items()}

In [11]:
# NOTE: self play loop, agent plays both first and second player, changing its
#       role after each turn

FIRST_PLAYER_WON="first player won"
SECOND_PLAYER_WON="second player won"
TIE="tie"

# TODO: this is so gross lol make it not gross pls 
# TODO: replace debug with using logs instead of prints so we can set log level at top of notebook
def train(value_table, n_episodes: int, learning_rate: float, epsilon: float, debug=False) -> dict:
    outcomes = {FIRST_PLAYER_WON: 0, SECOND_PLAYER_WON: 0, TIE: 0}
    for i in tqdm(range(n_episodes)):
        current = TicTacToeBoard()
        p1_greedy = False
        p2_greedy = False
        last_p1_turn = None
        last_p2_turn = None

        while True:
            action, p1_greedy = sample_policy(value_table=value_table, board=current, player=TicTacToeBoard.FIRST_PLAYER_CELL, eps=epsilon)
            last_p1_turn = current

            current = current.transition(action)
            if p2_greedy:
                update_value(value_table=value_table, last_board=last_p2_turn, current_board=current, player=TicTacToeBoard.SECOND_PLAYER_CELL, learning_rate=learning_rate)
                p2_greedy = False

            if debug:
                current.display()

            if current.terminated():
                break

            action, p2_greedy = sample_policy(value_table=value_table, board=current, player=TicTacToeBoard.SECOND_PLAYER_CELL, eps=epsilon)
            last_p2_turn = current

            current = current.transition(action)
            if p1_greedy:
                update_value(value_table=value_table, last_board=last_p1_turn, current_board=current, player=TicTacToeBoard.FIRST_PLAYER_CELL, learning_rate=learning_rate)
                p1_greedy = False

            if debug:
                current.display()

            if current.terminated():
                break

        if p1_greedy:
            update_value(value_table=value_table, last_board=last_p1_turn, current_board=current, player=TicTacToeBoard.FIRST_PLAYER_CELL, learning_rate=learning_rate)
        if p2_greedy:
            update_value(value_table=value_table, last_board=last_p2_turn, current_board=current, player=TicTacToeBoard.SECOND_PLAYER_CELL, learning_rate=learning_rate)

        assert current.terminated(), "should be terminated at end of episode"
        outcomes[FIRST_PLAYER_WON if current.first_player_won() else SECOND_PLAYER_WON if current.second_player_won() else TIE] += 1

    return outcomes

In [12]:
np.random.seed(42)

value_table = {}
train(value_table=value_table, n_episodes=1000000, learning_rate=0.05, epsilon=0.1)

100%|██████████| 1000000/1000000 [04:22<00:00, 3806.15it/s]


{'first player won': 852998, 'second player won': 133820, 'tie': 13182}

In [13]:
def sample_eval_policy(board: TicTacToeBoard) -> int:
    player_to_move = board.player_to_move()
    avail = board.available_cell_indices()
    if len(avail) == 0:
        raise ValueError("No moves left")
    for c in avail:
        if board.transition(c).player_won(player_to_move):
            return c
    return int(np.random.choice(avail))

In [18]:
def eval_policy(value_table, n_episodes, player):
    outcomes = {FIRST_PLAYER_WON: 0, SECOND_PLAYER_WON: 0, TIE: 0}
    # TODO: same as above training loop, make loop cleaner using cycle to switch through current player or somethign
    for i in tqdm(range(n_episodes)):
        current = TicTacToeBoard()

        while True:
            action = None
            if player == TicTacToeBoard.FIRST_PLAYER_CELL:
                action, _ = sample_policy(value_table=value_table, board=current, player=player, eps=0.0)
            else:
                action = sample_eval_policy(board=current)
            current = current.transition(action)

            # TODO: change to use logging
            # current.display()

            if current.terminated():
                break


            if player == TicTacToeBoard.SECOND_PLAYER_CELL:
                action, _ = sample_policy(value_table=value_table, board=current, player=player, eps=0.0)
            else:
                action = sample_eval_policy(board=current)
            current.transition(action)
            # TODO: change to use logging
            # current.display()
            if current.terminated():
                break
        
        assert current.terminated(), "should be terminated at end of episode"
        outcomes[FIRST_PLAYER_WON if current.first_player_won() else SECOND_PLAYER_WON if current.second_player_won() else TIE] += 1

    return outcomes

        

In [19]:
eval_policy(value_table=value_table, n_episodes=10000, player=TicTacToeBoard.FIRST_PLAYER_CELL)

100%|██████████| 10000/10000 [00:02<00:00, 4623.70it/s]


{'first player won': 10000, 'second player won': 0, 'tie': 0}

In [20]:
save(value_table=value_table, path="table.json")

In [21]:
def serialize_to_ui_format(value_table):
    return {"".join(["1" if c == AGENT_CELL else "2" if c == OPPONENT_CELL else "0" for c in s]): v for s, v in value_table.items()}

save(value_table=serialize_to_ui_format(value_table=value_table), path="ui_value_table.json")