In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
from tic_tac_toe import Board, new_board, transition, next_marker_to_place, game_state, pretty_format, available_plays, GameState, Marker
from egocentric import EgocentricBoard, EgocentricMarker, canonicalize, remap_to_egocentric_board
from random import random, seed, choice
from itertools import cycle

In [21]:
QTable = dict[str, float]

class QLearner:
    def __init__(self):
        self._canonical_q_table: QTable = {}
        self._q_table: QTable = {}

    def serialize_state_action(self, state: Board, action: int) -> tuple[str, str]:
        """Returns the serialized canonical state action pair and the serialized state action pair"""
        def _serialize_state_action(ego_state: EgocentricBoard, action: int):
            return "".join([str(m) for m in ego_state] + [str(action)])

        ego_state: EgocentricBoard = remap_to_egocentric_board(state)
        canonical_state: EgocentricBoard = canonicalize(ego_state)
        canonical_state_action: str = _serialize_state_action(canonical_state, action)

        # NOTE: we only initialize the canonical q table entry because we
        #       will only ever set entries in the regular q table
        if canonical_state_action not in self._canonical_q_table:
            self._canonical_q_table[canonical_state_action] = 0.0

        return canonical_state_action, _serialize_state_action(ego_state, action)

    def get_action(self, state_t: Board, epsilon: float = 0.0) -> int:
        if random() < epsilon:
            # explore
            return choice(available_plays(state_t))

        # greedy
        next_qs: list[tuple[int, float]] = []
        for action in available_plays(state_t):
            canon_state_action_t, _ = self.serialize_state_action(state=state_t, action=action)
            next_qs.append((action, self._canonical_q_table[canon_state_action_t],))

        # TODO: consider randomly selecting from ties
        return max(next_qs, key = lambda t: t[1])[0]
    
    def _update_q_tables(self, state: Board, action: int, q: float):
        canonical_state_action_t, state_action_t = self.serialize_state_action(state=state, action=action)
        self._canonical_q_table[canonical_state_action_t] = q
        self._q_table[state_action_t] = q

    def update(self, state_t: Board, reward: float, action: int, state_t_next: Board, learning_rate: float, discount_factor: float = 1.0):
        canonical_state_action_t, _ = self.serialize_state_action(state=state_t, action=action)
        q_t = self._canonical_q_table[canonical_state_action_t]

        if game_state(state_t_next) != GameState.INCOMPLETE:
            # next state is terminal so all q values at next state will be 0 
            self._update_q_tables(
                state=state_t,
                action=action,
                q=q_t + learning_rate * (reward - q_t)
            )
            return

        next_transition_qs: list[float] = []
        for action_next in available_plays(state_t_next):
            canonical_state_action_t_next, _ = self.serialize_state_action(state=transition(board=state_t_next, idx=action_next), action=action_next)
            next_transition_qs.append(self._canonical_q_table[canonical_state_action_t_next])

        # TODO: consider randomly selecting from ties
        max_q_next = max(next_transition_qs)

        td_error = (
            reward +
            discount_factor * max_q_next -
            q_t
        )

        self._update_q_tables(state=state_t, action=action, q=q_t + learning_rate * td_error)



In [22]:
def reward_from_board_transition(board: Board) -> float:
    # NOTE: if next marker is second player, then first player just played
    first_player: bool = next_marker_to_place(board) == Marker.SECOND_PLAYER
    outcome = game_state(board)
    if outcome == GameState.FIRST_PLAYER_WON:
        return 1.0 if first_player else -1.0
    if outcome == GameState.SECOND_PLAYER_WON:
        return -1.0 if first_player else 1.0
    return 0.0

def rollout_self_play(q_learner: QLearner, epsilon: float, learning_rate: float):
    board: Board = new_board()
    while True:
        action = q_learner.get_action(state_t=board, epsilon=epsilon)
        next_board = transition(board, action)
        q_learner.update(
            state_t=board,
            reward=reward_from_board_transition(board=next_board),
            action=action,
            state_t_next=next_board,
            learning_rate=learning_rate
        )
        board = next_board
        if game_state(board) != GameState.INCOMPLETE:
            # terminal
            break
    return game_state(board)

In [23]:
seed(42)
q_learner = QLearner()

In [25]:
for i in range(10):
    outcome = rollout_self_play(q_learner=q_learner, epsilon=0.1, learning_rate=0.05)
    print(f"episode {i}: {outcome}")

episode 0: GameState.FIRST_PLAYER_WON
episode 1: GameState.FIRST_PLAYER_WON
episode 2: GameState.FIRST_PLAYER_WON
episode 3: GameState.FIRST_PLAYER_WON
episode 4: GameState.FIRST_PLAYER_WON
episode 5: GameState.FIRST_PLAYER_WON
episode 6: GameState.TIED
episode 7: GameState.SECOND_PLAYER_WON
episode 8: GameState.FIRST_PLAYER_WON
episode 9: GameState.FIRST_PLAYER_WON
