In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tic_tac_toe import Board, new_board, transition, next_marker_to_place, game_state, pretty_format, available_plays, GameState, Marker
from egocentric import EgocentricBoard, EgocentricMarker, canonicalize, remap_to_egocentric_board
from random import random, seed, choice
from tqdm import tqdm

from time import time_ns
import json
from pathlib import Path


In [3]:
QTable = dict[str, float]

class QLearner:
    def __init__(self):
        self._canonical_q_table: QTable = {}
        self._q_table: QTable = {}
    
    @property
    def canonical_q_table(self):
        return self._canonical_q_table
    
    @property
    def q_table(self):
        return self._q_table

    def serialize_state_action(self, state: Board, action: int) -> tuple[str, str]:
        """Returns the serialized canonical state action pair and the serialized state action pair"""
        def _serialize_state_action(ego_state: EgocentricBoard, action: int):
            return "".join([str(m) for m in ego_state] + [str(action)])

        ego_state: EgocentricBoard = remap_to_egocentric_board(state)
        canonical_state: EgocentricBoard = canonicalize(ego_state)
        canonical_state_action: str = _serialize_state_action(canonical_state, action)

        # NOTE: we only initialize the canonical q table entry because we
        #       will only ever set entries in the regular q table
        if canonical_state_action not in self._canonical_q_table:
            self._canonical_q_table[canonical_state_action] = 0.0

        return canonical_state_action, _serialize_state_action(ego_state, action)

    def get_action(self, state_t: Board, epsilon: float = 0.0) -> int:
        if random() < epsilon:
            # explore
            return choice(available_plays(state_t))

        # greedy
        next_qs: list[tuple[int, float]] = []
        for action in available_plays(state_t):
            canon_state_action_t, _ = self.serialize_state_action(state=state_t, action=action)
            next_qs.append((action, self._canonical_q_table[canon_state_action_t],))

        # TODO: consider randomly selecting from ties
        return max(next_qs, key = lambda t: t[1])[0]
    
    def _update_q_tables(self, state: Board, action: int, q: float):
        canonical_state_action_t, state_action_t = self.serialize_state_action(state=state, action=action)
        self._canonical_q_table[canonical_state_action_t] = q
        self._q_table[state_action_t] = q

    def update(self, state_t: Board, reward: float, action: int, state_t_next: Board, learning_rate: float, discount_factor: float = 1.0):
        canonical_state_action_t, _ = self.serialize_state_action(state=state_t, action=action)
        q_t = self._canonical_q_table[canonical_state_action_t]

        if game_state(state_t_next) != GameState.INCOMPLETE:
            # next state is terminal so all q values at next state will be 0 
            self._update_q_tables(
                state=state_t,
                action=action,
                q=q_t + learning_rate * (reward - q_t)
            )
            return

        next_transition_qs: list[float] = []
        for action_next in available_plays(state_t_next):
            canonical_state_action_t_next, _ = self.serialize_state_action(state=transition(board=state_t_next, idx=action_next), action=action_next)
            next_transition_qs.append(self._canonical_q_table[canonical_state_action_t_next])

        # TODO: consider randomly selecting from ties
        max_q_next = max(next_transition_qs)

        td_error = (
            reward +
            discount_factor * max_q_next -
            q_t
        )

        self._update_q_tables(state=state_t, action=action, q=q_t + learning_rate * td_error)



In [4]:
def save(q_learner: QLearner):
    dir = Path("./outputs") / f"{time_ns()}"
    dir.mkdir(parents=True)

    canonical_q_table_path: Path = dir / "canonical_q_table.json"
    print(f"Saving to {canonical_q_table_path}")
    with open(canonical_q_table_path, "w") as f:
        json.dump({"".join(str(x) for x in k): v for k, v in q_learner.canonical_q_table.items()}, f)

    q_table_path: Path = dir / "q_table.json"
    print(f"Saving to {q_table_path}")
    with open(q_table_path, "w") as f:
        json.dump({"".join(str(x) for x in k): v for k, v in q_learner.q_table.items()}, f)

def load(path):
    raw = json.load(open(path, "r"))
    return {tuple(int(c) for c in k): v for k, v in raw.items()}

In [5]:
def q_learner_play(q_learner: QLearner, state_t: Board, epsilon: float, learning_rate: float) -> Board:
    action = q_learner.get_action(state_t=state_t, epsilon=epsilon)
    next_board = transition(state_t, action)
    q_learner.update(
        state_t=state_t,
        reward=reward_from_board_transition(board=next_board),
        action=action,
        state_t_next=next_board,
        learning_rate=learning_rate
        )
    return next_board

def random_agent_play(state_t: Board) -> Board:
    return transition(board=state_t, idx=choice(available_plays(state_t)))

def reward_from_board_transition(board: Board) -> float:
    # NOTE: if next marker is second player, then first player just played
    first_player: bool = next_marker_to_place(board) == Marker.SECOND_PLAYER
    outcome = game_state(board)
    if outcome == GameState.FIRST_PLAYER_WON:
        return 1.0 if first_player else -1.0
    if outcome == GameState.SECOND_PLAYER_WON:
        return -1.0 if first_player else 1.0
    return 0.0

def rollout_self_play(q_learner: QLearner, epsilon: float, learning_rate: float):
    board: Board = new_board()
    while True:
        board = q_learner_play(
            q_learner=q_learner,
            state_t=board,
            epsilon=epsilon,
            learning_rate=learning_rate
        )
        if game_state(board) != GameState.INCOMPLETE:
            # terminal
            break
    return game_state(board)

def rollout_random_opponent(
    q_learner: QLearner,
    epsilon: float,
    learning_rate: float,
    agent_is_first_player: bool
):
    board: Board = new_board()
    while True:
        if next_marker_to_place(board) == Marker.FIRST_PLAYER:
            board = (
                q_learner_play(
                    q_learner=q_learner,
                    state_t=board,
                    epsilon=epsilon,
                    learning_rate=learning_rate
                ) if agent_is_first_player else random_agent_play(state_t=board)
            )
        else:
            board = (
                q_learner_play(
                    q_learner=q_learner,
                    state_t=board,
                    epsilon=epsilon,
                    learning_rate=learning_rate
                ) if not agent_is_first_player else random_agent_play(state_t=board)
            )
        if game_state(board) != GameState.INCOMPLETE:
            # terminal
            break
    return game_state(board)

In [6]:
seed(42)
q_learner = QLearner()

In [7]:
outcomes = {GameState.FIRST_PLAYER_WON: 0, GameState.SECOND_PLAYER_WON: 0, GameState.TIED: 0}

for _ in tqdm(range(5000000)):
    outcome = rollout_self_play(q_learner=q_learner, epsilon=0.1, learning_rate=0.05)
    outcomes[outcome] += 1

print(outcomes)

100%|██████████| 5000000/5000000 [15:28<00:00, 5383.81it/s]

{<GameState.FIRST_PLAYER_WON: 1>: 3375366, <GameState.SECOND_PLAYER_WON: 2>: 1212047, <GameState.TIED: 3>: 412587}





In [8]:
# q learner is first player

outcomes = {GameState.FIRST_PLAYER_WON: 0, GameState.SECOND_PLAYER_WON: 0, GameState.TIED: 0}

for _ in tqdm(range(5000000)):
    outcome = rollout_random_opponent(q_learner=q_learner, epsilon=0.1, learning_rate=0.05, agent_is_first_player=True)
    outcomes[outcome] += 1

print(outcomes)

100%|██████████| 5000000/5000000 [09:00<00:00, 9247.41it/s]

{<GameState.FIRST_PLAYER_WON: 1>: 3837175, <GameState.SECOND_PLAYER_WON: 2>: 517016, <GameState.TIED: 3>: 645809}





In [9]:
# q learner is second player

outcomes = {GameState.FIRST_PLAYER_WON: 0, GameState.SECOND_PLAYER_WON: 0, GameState.TIED: 0}

for _ in tqdm(range(5000000)):
    outcome = rollout_random_opponent(q_learner=q_learner, epsilon=0.1, learning_rate=0.05, agent_is_first_player=False)
    outcomes[outcome] += 1

print(outcomes)

100%|██████████| 5000000/5000000 [07:33<00:00, 11013.66it/s]

{<GameState.FIRST_PLAYER_WON: 1>: 1190328, <GameState.SECOND_PLAYER_WON: 2>: 2827542, <GameState.TIED: 3>: 982130}





In [10]:
save(q_learner=q_learner)

Saving to outputs/1753565858454676581/canonical_q_table.json
Saving to outputs/1753565858454676581/q_table.json
