In [1]:
import numpy as np
from typing import Self
from functools import cache
from tqdm import tqdm
import json

from time import time_ns
from itertools import cycle
from tic_tac_toe_board import TicTacToeBoard

%load_ext autoreload
%autoreload 2

In [15]:
WON_VALUE, DID_NOT_WIN_VALUE, DEFAULT_VALUE = 1.0, 0.0, 0.5

@cache
def serialize_board(board: TicTacToeBoard) -> str:
    return "".join([str(c) for c in board.state])


def populate_value_if_non_existent(value_table: dict, board: TicTacToeBoard):
    key = serialize_board(board)
    if key not in value_table:
        value_table[key] = DEFAULT_VALUE
    if board.first_player_won():
        value_table[key] = WON_VALUE
    elif board.second_player_won():
        value_table[key] = WON_VALUE 
    elif board.tied():
        value_table[key] = DID_NOT_WIN_VALUE

def get_value(value_table: dict, board: TicTacToeBoard):
    populate_value_if_non_existent(value_table, board)
    return value_table[serialize_board(board)]

def sample_value_table_policy(value_table: dict, board: TicTacToeBoard, eps: float = 0.0) -> tuple[float, bool]:
    """Returns action sampled from policy and bool indicating whether or not a greedy move was selected"""
    populate_value_if_non_existent(value_table=value_table, board=board)
    possible_plays: list[tuple[int, float]] = []
    for idx in board.available_cell_indices():
        next_board = board.transition(idx)
        populate_value_if_non_existent(value_table=value_table, board=next_board)
        possible_plays.append((idx, get_value(value_table, next_board)))

    greedy_idx, _ = max(possible_plays, key=lambda t: t[1])
    if len(possible_plays) > 1 and np.random.rand() < eps:
        return np.random.choice([idx for idx, _ in possible_plays if idx != greedy_idx]), False

    return greedy_idx, True

def update_value(value_table: dict, last_board: TicTacToeBoard, current_board: TicTacToeBoard, learning_rate: float):
    key_last = serialize_board(last_board)
    # these two calls will auto‐create missing entries 
    v_last = get_value(value_table, last_board)
    v_curr = (
        get_value(value_table, current_board)
        if current_board.player_to_move() == last_board.player_to_move()
        else DID_NOT_WIN_VALUE # this is a terminal state backup, the last state did not lead to a winning terminal state
    )
    value_table[key_last] = v_last + learning_rate * (v_curr - v_last)


In [16]:
def save(value_table, path):
    print(f"Saving to {path}")
    # dump keys as strings for JSON
    export = {"".join(str(x) for x in k): v for k, v in value_table.items()}
    with open(path, "w") as f:
        json.dump(export, f)

def load(path):
    raw = json.load(open(path, "r"))
    return {tuple(int(c) for c in k): v for k, v in raw.items()}

In [17]:
# NOTE: self play loop, agent plays both first and second player, changing its
#       role after each turn

FIRST_PLAYER_WON="first player won"
SECOND_PLAYER_WON="second player won"
TIE="tie"

# TODO: this is so gross lol make it not gross pls 
# TODO: replace debug with using logs instead of prints so we can set log level at top of notebook
def train(value_table, n_episodes: int, learning_rate: float, epsilon: float, debug=False) -> dict:
    outcomes = {FIRST_PLAYER_WON: 0, SECOND_PLAYER_WON: 0, TIE: 0}
    for i in tqdm(range(n_episodes)):
        current = TicTacToeBoard()
        player_cycle = cycle([TicTacToeBoard.FIRST_PLAYER_CELL, TicTacToeBoard.SECOND_PLAYER_CELL])
        last_greedy_state = None


        while True:
            player = next(player_cycle)
            action, greedy = sample_value_table_policy(value_table=value_table, board=current, eps=epsilon)
            last = current
            current = current.transition(action)
            if greedy:
                update_value(value_table=value_table, last_board=last, current_board=current, learning_rate=learning_rate)
            if current.terminated():
                break
            if greedy:
                # NOTE: we do it after termination so that the player who went before this iter
                #       has last_greedy_state set to their state, not updated to this
                #       state (which already one)
                last_greedy_state = current

        if last_greedy_state is not None:
            # backup other player's last greedy state to incorporate the loss or tie signal
            update_value(value_table=value_table, last_board=last, current_board=current, learning_rate=learning_rate)

        assert current.terminated(), "should be terminated at end of episode"
        outcomes[FIRST_PLAYER_WON if current.first_player_won() else SECOND_PLAYER_WON if current.second_player_won() else TIE] += 1

    return outcomes

In [18]:
np.random.seed(42)

value_table = {}
train(value_table=value_table, n_episodes=1000000, learning_rate=0.05, epsilon=0.1)

100%|██████████| 1000000/1000000 [06:03<00:00, 2747.26it/s]


{'first player won': 749998, 'second player won': 183419, 'tie': 66583}

In [19]:
def sample_eval_policy(board: TicTacToeBoard) -> int:
    player_to_move = board.player_to_move()
    avail = board.available_cell_indices()
    if len(avail) == 0:
        raise ValueError("No moves left")
    for c in avail:
        if board.transition(c).player_won(player_to_move):
            return c
    return int(np.random.choice(avail))

In [23]:
def eval_policy(value_table, n_episodes, player):
    outcomes = {FIRST_PLAYER_WON: 0, SECOND_PLAYER_WON: 0, TIE: 0}
    # TODO: same as above training loop, make loop cleaner using cycle to switch through current player or somethign
    for i in tqdm(range(n_episodes)):
        current = TicTacToeBoard()

        while True:
            action = None
            if player == TicTacToeBoard.FIRST_PLAYER_CELL:
                action, _ = sample_value_table_policy(value_table=value_table, board=current, eps=0.0)
            else:
                action = sample_eval_policy(board=current)
            current = current.transition(action)

            # TODO: change to use logging
            # current.display()

            if current.terminated():
                break


            if player == TicTacToeBoard.SECOND_PLAYER_CELL:
                action, _ = sample_value_table_policy(value_table=value_table, board=current, eps=0.0)
            else:
                action = sample_eval_policy(board=current)
            current.transition(action)
            # TODO: change to use logging
            # current.display()
            if current.terminated():
                break
        
        assert current.terminated(), "should be terminated at end of episode"
        outcomes[FIRST_PLAYER_WON if current.first_player_won() else SECOND_PLAYER_WON if current.second_player_won() else TIE] += 1

    return outcomes

        

In [24]:
eval_policy(value_table=value_table, n_episodes=10000, player=TicTacToeBoard.FIRST_PLAYER_CELL)

100%|██████████| 10000/10000 [00:02<00:00, 3619.15it/s]


{'first player won': 10000, 'second player won': 0, 'tie': 0}

In [25]:
save(value_table=value_table, path=f"table{time_ns}.json")

Saving to table<built-in function time_ns>.json


In [None]:
# TODO: how we serialize the table here breaks how the web app expects it

# def serialize_to_ui_format(value_table):
#     return {"".join(["1" if c == AGENT_CELL else "2" if c == OPPONENT_CELL else "0" for c in s]): v for s, v in value_table.items()}
# save(value_table=serialize_to_ui_format(value_table=value_table), path=f"ui_value_table{time_ns}.json")