In [71]:
import tqdm as tqdm

In [72]:
class env:
    """
    This class is used to define a tic tac toe environment
    """

    def __init__(self):
        self.board = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.turn = 1
        self.winner = 0

    def reset(self, position = None):
        if position is None:
            position = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        if type(position) != list:
            raise ValueError("The position must be a list of 9 elements")
        if len(position) != 9:
            raise ValueError("The position must be a list of 9 elements")
        self.board = position
        self.turn = 1
        self.winner = 0
        return self.board

    def step(self, action):
        """
        This function is used to make a move on the board
        Returns:
            - The new board
            - The reward (0 if there is no winner)
            - A boolean indicating if the game is finished
        If the action is not possible, the game is forfeited with a reward of -10 * self.turn
        """
        if self.board[action] == 0:
            self.board[action] = self.turn
            self.turn = -self.turn
            self.winner = self.check_winner()
            return self.board, self.winner, self.winner != 0
        else:
            return self.board, -10 * self.turn, True

    def check_winner(self, board = None):
        """
        This function is used to check if there is a winner
        """
        if board is None:
            board = self.board
        for i in range(3):
            if board[i] == board[i + 3] == board[i + 6] != 0:
                return board[i]
            if board[i * 3] == board[i * 3 + 1] == board[i * 3 + 2] != 0:
                return board[i * 3]
        if board[0] == board[4] == board[8] != 0:
            return board[0]
        if board[2] == board[4] == board[6] != 0:
            return board[2]
        if 0 not in board:
            return 0
        return 0
    
    def __str__(self):
        """
        This function is used to print the board
        """
        symbols = {0: " ", 1: "X", -1: "O"}
        return f"{symbols[self.board[0]]}|{symbols[self.board[1]]}|{symbols[self.board[2]]}\n-----\n{symbols[self.board[3]]}|{symbols[self.board[4]]}|{symbols[self.board[5]]}\n-----\n{symbols[self.board[6]]}|{symbols[self.board[7]]}|{symbols[self.board[8]]}"

In [73]:
ttt = env()
ttt.reset([1, 0, 0, 0, 1, 0, 0, 0, -1])
print(ttt)

X| | 
-----
 |X| 
-----
 | |O


In [74]:
class Agent():
    """
    This class is used to define a tic tac toe agent
    """

    def __init__(self, env):
        self.env = env
        self.init_policy()

    def init_policy(self):
        """
        This function is used to initialize the policy
        """
        self.state_actions, isos = enumerate_states()
        state_order = [(4,4),(3,4),(3,3),(2,3),(2,2),(1,2),(1,1),(0,1),(0,0)]
        ordered_non_terminal_states = []
        for counts in state_order:
            if sum(counts) < 9:
                ordered_non_terminal_states.extend(isos[counts])
        self.states = ordered_non_terminal_states
        self.values = {}
        self.policy = {}

        for state in self.state_actions:
            if self.state_actions[state]:
                self.policy[state] = self.state_actions[state][0]
                self.set_value(state, 0)
            else:
                self.set_value(state, self.env.check_winner(state))
        print("Initialization done!")

    def get_value(self, state):
        """
        This function is used to get the value of a state
        """
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.values:
                return self.values[tuple(iso)]
        else:
            return False
    
    def get_policy(self, state):
        symmetry0 = [0, 1, 2, 3, 4, 5, 6, 7, 8] # e
        symmetry1 = [2, 5, 8, 1, 4, 7, 0, 3, 6] # r1
        symmetry2 = [8, 7, 6, 5, 4, 3, 2, 1, 0] # r2
        symmetry3 = [6, 3, 0, 7, 4, 1, 8, 5, 2] # r3
        symmetry4 = [2, 1, 0, 5, 4, 3, 8, 7, 6] # s
        symmetry5 = [6, 7, 8, 3, 4, 5, 0, 1, 2] # sr1
        symmetry6 = [8, 5, 2, 7, 4, 1, 6, 3, 0] # sr2
        symmetry7 = [0, 3, 6, 1, 4, 7, 2, 5, 8] # sr3
        symmetries = [symmetry0, symmetry1, symmetry2, symmetry3, symmetry4, symmetry5, symmetry6, symmetry7] 

        for s, iso in enumerate(get_isomorphisms(state)):
            if tuple(iso) in self.policy:
                # the symmetry
                sym = symmetries[s]
                # inverse of the symmetry
                sym_inv = [sym.index(i) for i in range(9)]
                return sym_inv[self.policy[tuple(iso)]]
        else:
            return False
            
    def set_value(self, state, value):
        """
        This function is used to set the value of a state
        """
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.values:
                self.values[tuple(iso)] = value
                return True
        else:
            self.values[state] = value
            return False
        
    def set_policy(self, state, action):
        for iso in get_isomorphisms(state):
            if tuple(iso) in self.policy:
                self.policy[tuple(iso)] = action
                return True
        else:
            self.policy[state] = action
            return False

    def train(self, gamma = 0.9, theta = 0.1):
        """
        This function is used to train the agent
        """
        self.policy_iteration(gamma, theta)

    def test(self, n = 1000):
        """
        This function is used to test the agent against itself
        """
        player_1 = 0
        draws = 0
        for _ in tqdm.tqdm(range(n)):
            state = self.env.reset()
            turn = 1
            while True:
                state = state if turn == 1 else [-s for s in state]
                action = self.get_policy(tuple(state))
                state, reward, done = self.env.step(action)
                turn = -turn
                if done:
                    print(self.env)
                    print()
                    if reward == 1:
                        player_1 += 1
                    elif reward == 0:
                        draws += 1
                    break
        print(f"Player 1 wins: {player_1}, draws: {draws}, player 2 wins: {n - player_1 - draws}")

    # Todo: Add a function to have a random agent play against the trained agent

    def policy_iteration(self, gamma, theta):
        epochs = 0
        while True:
            self.iterative_policy_evaluation(gamma, theta)
            unstable = 0
            for state in tqdm.tqdm(self.states, desc=f"Epoch {epochs}"):
                old_action = self.get_policy(state)                  
                new_action = self.policy_improvement(state, gamma)
                if old_action != new_action:
                    unstable += 1
            epochs += 1
            print(f"Epoch {epochs}, unstability: {unstable}")
            if unstable == 0:
                print("Converged!")
                break

    def policy_improvement(self, state, gamma):
        state = tuple(state)
        best_action = self.get_policy(state)
        best_value = self.get_value(state)
        for action in self.state_actions[state]:
            self.env.reset(list(state))
            new_state, reward, _ = self.env.step(action)
            v = reward + gamma * self.get_value(tuple(new_state))
            if v > best_value:
                best_value = v
                best_action = action
        self.set_policy(state, best_action)
        return best_action
    
    def iterative_policy_evaluation(self, gamma, theta):
        while True:
            delta = 0
            for state in self.state_actions:
                if num_winners(list(state)) > 0:
                    continue
                v = self.get_value(state)
                new_v = 0
                for action in self.state_actions[state]:
                    self.env.reset(list(state))
                    new_state, reward, _ = self.env.step(action)
                    new_v += 1 / len(self.state_actions[state]) * (reward + gamma * self.get_value(tuple(new_state)))
                self.set_value(state, new_v)
                delta = max(delta, abs(v - new_v))
            if delta < theta:
                break

def check_symmetry(board1, board2):
    if board1 in get_isomorphisms(board2):
        return True
        
def get_isomorphisms(board):
    # Abstract algebra! Lol who would've thought I would ever use that
    symmetry0 = [0, 1, 2, 3, 4, 5, 6, 7, 8] # e
    symmetry1 = [2, 5, 8, 1, 4, 7, 0, 3, 6] # r1
    symmetry2 = [8, 7, 6, 5, 4, 3, 2, 1, 0] # r2
    symmetry3 = [6, 3, 0, 7, 4, 1, 8, 5, 2] # r3
    symmetry4 = [2, 1, 0, 5, 4, 3, 8, 7, 6] # s
    symmetry5 = [6, 7, 8, 3, 4, 5, 0, 1, 2] # sr1
    symmetry6 = [8, 5, 2, 7, 4, 1, 6, 3, 0] # sr2
    symmetry7 = [0, 3, 6, 1, 4, 7, 2, 5, 8] # sr3
    symmetries = [symmetry0, symmetry1, symmetry2, symmetry3, symmetry4, symmetry5, symmetry6, symmetry7]
    return [[board[i] for i in symmetry] for symmetry in symmetries]

def enumerate_states():
    """
    This function is used to enumerate all the possible states  and corresponding actions of the tic tac toe game
    UPTO ISOMORPHISM

    1 indicates agent's symbol when it's their turn so the number of 1s is at least the number of -1s minus 1
    and at most the number of -1s
    """
    state_actions = {}
    isomorphic_boards = {}
    for i in range(3 ** 9):
        board = []
        for j in range(9):
            board.append((i // (3 ** j)) % 3 - 1)
        if board.count(1) < board.count(-1) - 1 or board.count(1) > board.count(-1):
            continue
        iso = False
        if (board.count(1), board.count(-1)) not in isomorphic_boards:
            isomorphic_boards[(board.count(1), board.count(-1))] = []
        for ib in isomorphic_boards[(board.count(1), board.count(-1))]:
            if check_symmetry(board, ib):
                iso = True
                break
        if iso:
            continue
        
        # check that the board is valid (i.e. if both players have 3 in a row, the board
        # is invalid)
        
        winners = num_winners(board)
        if winners > 1: 
            continue
        if winners == 1:
            actions = []
        else:
            actions = [i for i in range(9) if board[i] == 0]
        state_actions[tuple(board)] = actions
        isomorphic_boards[(board.count(1), board.count(-1))].append(board)
    return state_actions, isomorphic_boards

def num_winners(board):
    winners = set()
    for i in range(3):
        if board[i] == board[i + 3] == board[i + 6] != 0:
            winners.add(board[i])
        if board[i * 3] == board[i * 3 + 1] == board[i * 3 + 2] != 0:
            winners.add(board[i * 3])
        if board[0] == board[4] == board[8] != 0:
            winners.add(board[0])
        if board[2] == board[4] == board[6] != 0:
            winners.add(board[2])
    return len(winners)

In [75]:
ttt = env()
a = Agent(ttt)

Initialization done!


In [76]:
a.train()

Epoch 0: 100%|██████████| 809/809 [00:00<00:00, 3405.28it/s]


Epoch 1, unstability: 211


Epoch 1: 100%|██████████| 809/809 [00:00<00:00, 3192.75it/s]

Epoch 2, unstability: 0
Converged!





In [78]:
a.test()

 95%|█████████▌| 950/1000 [00:00<00:00, 3140.35it/s]

X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
----

100%|██████████| 1000/1000 [00:00<00:00, 3080.35it/s]

X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
-----
 | | 
-----
 |O|O
X|X|X
----


