In [69]:
import tqdm as tqdm
import numpy as np

In [70]:
class env:
    """
    This class is used to define a tic tac toe environment
    """

    def __init__(self):
        self.previous_board = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.board = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.turn = 1
        self.winner = 0

    def reset(self, position = None):
        if position is None:
            position = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        if type(position) != list:
            raise ValueError("The position must be a list of 9 elements")
        if len(position) != 9:
            raise ValueError("The position must be a list of 9 elements")
        self.board = position
        self.turn = 1
        self.winner = 0
        return self.board

    def pop(self):
        """
        This function is used to remove the last move
        """
        self.board = self.previous_board
        self.turn = -self.turn
        self.winner = 0

    def step(self, action):
        """
        This function is used to make a move on the board
        Returns:
            - The new board
            - The reward (0 if there is no winner)
            - A boolean indicating if the game is finished
        If the action is not possible, the game is forfeited with a reward of -10 * self.turn
        """
        self.previous_board = self.board.copy()
        if self.board[action] == 0:
            self.board[action] = self.turn
            self.turn = -self.turn
            self.winner = self.check_winner()
            done = self.winner != 0 or 0 not in self.board
            return self.board, abs(self.winner), done
        else:
            return self.board, -10 * self.turn, True

    def check_winner(self, board = None):
        """
        This function is used to check if there is a winner
        """
        if board is None:
            board = self.board
        for i in range(3):
            if board[i] == board[i + 3] == board[i + 6] != 0:
                return board[i]
            if board[i * 3] == board[i * 3 + 1] == board[i * 3 + 2] != 0:
                return board[i * 3]
        if board[0] == board[4] == board[8] != 0:
            return board[0]
        if board[2] == board[4] == board[6] != 0:
            return board[2]
        if 0 not in board:
            return 0
        return 0
    
    def get_actions(self):
        """
        This function is used to get the possible actions
        """
        return [i for i in range(9) if self.board[i] == 0]
    
    def __str__(self):
        """
        This function is used to print the board
        """
        symbols = {0: " ", 1: "X", -1: "O"}
        return f"{symbols[self.board[0]]}|{symbols[self.board[1]]}|{symbols[self.board[2]]}\n-----\n{symbols[self.board[3]]}|{symbols[self.board[4]]}|{symbols[self.board[5]]}\n-----\n{symbols[self.board[6]]}|{symbols[self.board[7]]}|{symbols[self.board[8]]}"

In [71]:
class Agent():
    """
    This class is used to define a tic tac toe agent
    """

    def __init__(self, env):
        self.env = env
        self.init_policy()

    def init_policy(self):
        """
        This function is used to initialize the policy
        """
        self.s_plus, self.s = enumerate_states()
        self.terminals = {s for s in self.s_plus if s not in self.s}
        self.values = {}
        self.policy = {}

        # initialize the values of the terminal states to 0
        # initialize the policy to random and value to arbitrary value for the non
        # terminal states
        for state in self.s_plus:
            if state in self.s:
                self.values[state] = .1
                self.policy[state] = 0
            else:
                self.values[state] = 0

        print("Initialization done!")

    def get_value(self, state):
        """
        This function is used to get the value of a state
        """
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.values:
                return self.values[tuple(iso)]
        else:
            raise ValueError("The state is not in the state space")
    
    def get_policy(self, state):
        symmetry0 = [0, 1, 2, 3, 4, 5, 6, 7, 8] # e
        symmetry1 = [2, 5, 8, 1, 4, 7, 0, 3, 6] # r1
        symmetry2 = [8, 7, 6, 5, 4, 3, 2, 1, 0] # r2
        symmetry3 = [6, 3, 0, 7, 4, 1, 8, 5, 2] # r3
        symmetry4 = [2, 1, 0, 5, 4, 3, 8, 7, 6] # s
        symmetry5 = [6, 7, 8, 3, 4, 5, 0, 1, 2] # sr1
        symmetry6 = [8, 5, 2, 7, 4, 1, 6, 3, 0] # sr2
        symmetry7 = [0, 3, 6, 1, 4, 7, 2, 5, 8] # sr3
        symmetries = [symmetry0, symmetry1, symmetry2, symmetry3, symmetry4, symmetry5, symmetry6, symmetry7] 

        for s, iso in enumerate(get_isomorphisms(state)):
            if tuple(iso) in self.policy:
                # the symmetry
                sym = symmetries[s]
                # inverse of the symmetry
                sym_inv = [sym.index(i) for i in range(9)]
                return sym_inv[self.policy[tuple(iso)]]
        else:
            raise ValueError("The state is not in the state space")
            
    def set_value(self, state, value):
        """
        This function is used to set the value of a state
        """
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.values:
                self.values[tuple(iso)] = value
                return True
        else:
            raise ValueError("The state is not in the state space")
        
    def set_policy(self, state, action):
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.policy:
                self.policy[tuple(iso)] = action
                return True
        else:
            raise ValueError("The state is not in the state space")
       
    def value_iteration(self, gamma = 1, theta = 1e-5):
        delta = np.inf
        while delta > theta:
            delta = 0
            for state in tqdm.tqdm(self.s):
                self.env.reset(list(state))
                old_v = self.get_value(state)
                v = - np.inf
                for action in self.env.get_actions():
                    next_state, reward, done = self.env.step(action)
                    v = max(v, reward + gamma * -1 * self.get_value(tuple(-1 * i for i in next_state)))
                    self.env.pop()
                self.set_value(state, v)
                delta = max(delta, abs(old_v - v))
        
        # output policy according to the value function
        for state in self.s:
            self.env.reset(list(state))
            v = - np.inf
            for action in self.env.get_actions():
                next_state, reward, done = self.env.step(action)
                if reward + gamma * -1 * self.get_value(tuple(-1 * i for i in next_state)) > v:
                    v = reward + gamma * -1 * self.get_value(tuple(-1 * i for i in next_state))
                    best_action = action
                self.env.pop()
            self.set_policy(state, best_action)

    def train(self, gamma = .9, theta = 1e-5):
        self.value_iteration(gamma, theta)
        print("Value iteration done!")

    def self_test(self):
        """
        This function is used to test the agent
        """
        self.env.reset()
        print(self.env)
        state = self.env.board
        done = False
        while not done:
            action = self.get_policy(tuple(state))
            state, _, done = self.env.step(action)
            state = [-1 * i for i in state]
            print(self.env)
            print()
        if self.env.check_winner() == 0:
            print("It's a draw!")
        elif self.env.check_winner() == 1:
            print("X wins!")
        else:
            print("O wins!")
    
    def play_random(self, n=10000):
        """
        This function is used to play random games
        """
        wins = 0
        draws = 0
        losses = 0
        for game in tqdm.tqdm(range(n)):
            symbol = game % 2 * 2 - 1

            self.env.reset()
            state = self.env.board
            done = False
            while not done:
                if symbol == self.env.turn:
                    action = self.get_policy(tuple(state))
                else:
                    action = np.random.choice(self.env.get_actions())
                state, _, done = self.env.step(action)
                state = [-1 * i for i in state]
            
            if self.env.check_winner() == 0:
                draws += 1
            elif self.env.check_winner() == symbol:
                wins += 1
            else:
                losses += 1
        print(f"Wins: {wins}, Draws: {draws}, Losses: {losses}")
        return wins, draws, losses

def check_symmetry(board1, board2):
    if board1 in get_isomorphisms(board2):
        return True
        
def get_isomorphisms(board):
    # Abstract algebra! Lol who would've thought I would ever use that
    symmetry0 = [0, 1, 2, 3, 4, 5, 6, 7, 8] # e
    symmetry1 = [2, 5, 8, 1, 4, 7, 0, 3, 6] # r1
    symmetry2 = [8, 7, 6, 5, 4, 3, 2, 1, 0] # r2
    symmetry3 = [6, 3, 0, 7, 4, 1, 8, 5, 2] # r3
    symmetry4 = [2, 1, 0, 5, 4, 3, 8, 7, 6] # s
    symmetry5 = [6, 7, 8, 3, 4, 5, 0, 1, 2] # sr1
    symmetry6 = [8, 5, 2, 7, 4, 1, 6, 3, 0] # sr2
    symmetry7 = [0, 3, 6, 1, 4, 7, 2, 5, 8] # sr3
    symmetries = [symmetry0, symmetry1, symmetry2, symmetry3, symmetry4, symmetry5, symmetry6, symmetry7]
    return [[board[i] for i in symmetry] for symmetry in symmetries]

def enumerate_states():
    """
    This function is used to enumerate all the possible states  and corresponding actions of the tic tac toe game
    UPTO ISOMORPHISM

    1 indicates agent's symbol when it's their turn so the number of 1s is at least the number of -1s minus 1
    and at most the number of -1s
    """
    isomorphic_boards = {}
    non_terminal = set()
    for i in range(3 ** 9):
        board = []
        for j in range(9):
            board.append((i // (3 ** j)) % 3 - 1)
        if board.count(1) < board.count(-1) - 1 or board.count(1) > board.count(-1):
            continue
        iso = False
        if (board.count(1), board.count(-1)) not in isomorphic_boards:
            isomorphic_boards[(board.count(1), board.count(-1))] = []

        for ib in isomorphic_boards[(board.count(1), board.count(-1))]:
            if check_symmetry(board, ib):
                iso = True
                break
        if iso:
            continue
        
        # check that the board is valid (i.e. if both players have 3 in a row, the board
        # is invalid)
        
        winners = num_winners(board)
        if winners > 1: 
            continue
        if winners == 0 and board.count(0) > 0:
            non_terminal.add(tuple(board))
        
        isomorphic_boards[(board.count(1), board.count(-1))].append(board)

    s_plus = set()
    for k, v in isomorphic_boards.items():
        for board in v:
            s_plus.add(tuple(board))
    s = non_terminal
    return s_plus, s

def num_winners(board):
    winners = set()
    for i in range(3):
        if board[i] == board[i + 3] == board[i + 6] != 0:
            winners.add(board[i])
        if board[i * 3] == board[i * 3 + 1] == board[i * 3 + 2] != 0:
            winners.add(board[i * 3])
        if board[0] == board[4] == board[8] != 0:
            winners.add(board[0])
        if board[2] == board[4] == board[6] != 0:
            winners.add(board[2])
    return len(winners)

In [72]:
ttt = env()
a = Agent(ttt)

Initialization done!


In [77]:
a.train(gamma = .9, theta = 1e-10)

100%|██████████| 627/627 [00:00<00:00, 721.05it/s] 


Value iteration done!


In [None]:
a.self_test()

In [78]:
a.play_random(10000)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:07<00:00, 1301.33it/s]

Wins: 6146, Draws: 3689, Losses: 165





(6146, 3689, 165)