In [1]:
import numpy as np
import matplotlib.pyplot as plt
import copy

In [2]:
class TicTacToe:
    '''
    Minimal 3x3 Tic Tac Toe environment.

    The board is a flat list of 9 cells (row-major): 0 is empty, -1 is X and
    1 is O. X (-1) always makes the first move.
    '''

    # Every triple of cell indices that forms a winning line
    _WIN_LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    def __init__(self):
        # Empty board, X (-1) to move
        self.board = [0] * 9
        self.current_player = -1

    def print_board(self):
        '''Prints the raw cell values row by row, followed by a blank line.'''
        for row_start in (0, 3, 6):
            row = self.board[row_start:row_start + 3]
            print("|".join(str(cell) for cell in row))

            # Separator goes between rows only, not after the last one
            if row_start < 6:
                print("-" * 5)

        print()

    def check_win(self, player):
        '''Returns True if `player` owns all three cells of any winning line.'''
        return any(
            all(self.board[cell] == player for cell in line)
            for line in self._WIN_LINES
        )

    def step(self, position):
        '''
        Places the current player's marker at `position` and advances the game

        Parameters
        ----------
        position: int
            Index (0-8) of the cell to mark

        Returns
        -------
        board: list
            The board after the move (the same list object the game holds)
        current_player: int
            The winner if the move won the game, 0 if it produced a draw,
            otherwise the player who moves next. Unchanged if the cell was
            already occupied
        done: bool
            True once the game has ended in a win or a draw
        '''

        # Rejected move: report it and leave the state untouched
        if self.board[position] != 0:
            print("Cell already occupied. Try again.")
            return self.board, self.current_player, False

        mover = self.current_player
        self.board[position] = mover

        if self.check_win(mover):
            return self.board, mover, True

        # Full board without a winner is a draw
        if 0 not in self.board:
            return self.board, 0, True

        # Hand the turn over to the other player
        self.current_player = -1 if mover != -1 else 1
        return self.board, self.current_player, False

    def reset(self):
        '''Restores the empty board with X to move.'''
        self.__init__()

In [3]:
from enum import Enum

# TODO: Generate all possible states

class State(Enum):
    '''
    Outcome codes for a board, as returned by is_board_terminal.
    The values for X and O match the player markers stored in the board cells.
    '''
    X = -1  # X has completed a winning line
    O = 1  # O has completed a winning line
    DRAW = 0  # board is full and nobody has won
    NOT_TERMINAL = 2  # no winner yet and at least one empty cell remains

def has_player_won(board, player):
    '''
    Returns True if `player` occupies every cell of at least one winning line
    on the flat, row-major 3x3 `board`, otherwise False
    '''
    lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    return any(all(board[cell] == player for cell in line) for line in lines)

def is_board_terminal(board):
    '''
    Classifies the board and returns the matching State value:
    the winner's value if X or O has a winning line (X is checked first),
    the DRAW value if the board is full without a winner,
    and the NOT_TERMINAL value otherwise
    '''
    # A completed line decides the game regardless of how full the board is
    for mark in (State.X, State.O):
        if has_player_won(board, mark.value):
            return mark.value

    # No winner: the game is a draw only once no empty cell is left
    return State.DRAW.value if 0 not in board else State.NOT_TERMINAL.value

def reward_function(board):
    '''
    Returns the reward attached to the board:
    -1 if X has won, 1 if O has won, and 0 for a draw or an unfinished game
    '''
    outcome = is_board_terminal(board)

    # Terminal outcome codes double as rewards; unfinished games are worth nothing yet
    return 0 if outcome == State.NOT_TERMINAL.value else outcome
    
def check_invalid_after_win(board):
    '''
    Returns True if the board is a finished game whose piece counts are impossible
    (for example because a player kept moving after the game was already won)

    With X stored as -1 and O as 1, sum(board) is (number of O) - (number of X).
    A finished board is valid only when:
    1. X has won and there is exactly one more X than O (sum is -1)
    2. O has won and there are as many X as O (sum is 0)
    3. The game is a draw and there is exactly one more X than O (sum is -1)
    A board on which both players have a winning line is always invalid.
    Boards that are not finished are never reported as invalid here.
    '''
    x_won = has_player_won(board, State.X.value)
    o_won = has_player_won(board, State.O.value)
    outcome = is_board_terminal(board)

    # Two winners cannot coexist
    if x_won and o_won:
        return True

    # X moved last, so X must be exactly one piece ahead
    if x_won:
        return sum(board) != -1

    # O moved last, so the piece counts must be level
    if o_won:
        return sum(board) != 0

    # A full board always ends on X's move
    if outcome == State.DRAW.value:
        return sum(board) != -1

    # Unfinished game: nothing to object to
    return False
    

# TODO: Generate all possible states
def generate_boards(board, index=0):
    '''
    Enumerates every way of filling the cells of `board` from `index` onwards
    with X (-1), empty (0) or O (1), without applying any game rules

    Parameters
    ----------
    board: list
        Template board. Cells before `index` are kept as they are; the values
        of the remaining cells are ignored. The list itself is not modified.
    index: int
        Position of the first cell to enumerate (0 enumerates the whole board)

    Returns
    -------
    all_boards: list
        3 ** (len(board) - index) new lists, ordered with the last cell
        varying fastest and each cell cycling through -1, 0, 1
    '''
    # Local import keeps this function usable on its own
    from itertools import product

    # Cells before `index` are fixed and shared by every generated board
    prefix = list(board[:index])
    free_cells = len(board) - index

    # product() yields the tails in the same order as the nested loops it replaces
    return [prefix + list(tail) for tail in product((-1, 0, 1), repeat=free_cells)]

def generate_all_states():
    '''
    Builds the list of all valid 3x3 board states together with the player to move in each

    Every possible filling of the board is generated first and then filtered so that:
    1. The sum of the cells is -1 or 0 (X is -1 and always moves first)
    2. No state appears twice
    3. No move has been made after the game had already ended

    Returns
    -------
    all_boards: list
        The valid boards, each a list of 9 cells
    turns: list
        For each board, -1 if X is to move (sum is 0) or 1 if O is to move (sum is -1)
    '''
    # Unconstrained enumeration of every -1/0/1 filling
    candidates = generate_boards([0] * 9)

    # Keep only piece counts reachable with X moving first
    balanced = [cells for cells in candidates if sum(cells) in [-1, 0]]

    # Deduplicate through hashable tuples, then go back to lists
    unique_keys = list(set(tuple(cells) for cells in balanced))
    unique_boards = [list(key) for key in unique_keys]

    # Drop boards where play continued after the game was over
    valid_boards = [cells for cells in unique_boards if not check_invalid_after_win(cells)]

    # Level piece counts mean X (-1) is to move, otherwise O (1)
    movers = [-1 if sum(cells) == 0 else 1 for cells in valid_boards]

    return valid_boards, movers

# Enumerate every valid board once, together with the player to move in each
states, turns = generate_all_states()
# Sanity check on the enumeration: 5478 valid states are expected
assert len(states) == 5478, f"There are {len(states)} possible states, expected 5478"

In [4]:
def get_actions(state):
    '''
    Return the indices of all empty cells (value 0) in `state`, in ascending order
    '''
    empty_cells = []
    for index, mark in enumerate(state):
        if mark == 0:
            empty_cells.append(index)

    return empty_cells

def get_reward(state):
    '''
    Return reward for state: -1 if X has won, 1 if O has won,
    0 for a draw or a game that is still in progress
    '''
    outcome = is_board_terminal(state)

    # Unfinished games carry no reward; terminal outcome codes are the rewards themselves
    if outcome == State.NOT_TERMINAL.value:
        return 0

    return outcome

def value_iteration(states, turns, max_iterations, gamma, epsilon_check=False):
    '''
    Performs Value Iteration and returns the final Value table, Q-table and policies

    Rewards follow get_reward: -1 when X wins, 1 when O wins, 0 otherwise.
    X (-1) therefore picks the action with the lowest value and O (1) the highest.
    Every sweep is synchronous: new values are computed only from the table
    produced by the previous sweep.

    Parameters
    ----------
    states: list
        All board states, each a sequence of 9 cells
    turns: list
        The player to move (-1 or 1) in the state at the same position
    max_iterations: int
        Maximum number of sweeps over all states. 0 returns the initial tables
    gamma: float
        Discount factor, applied to the reward and value of the successor state
    epsilon_check: bool
        If True, stop early once the largest value change in a sweep is negligible

    Returns
    -------
    v_table: dict
        Maps each state (as a tuple) to its value
    q_table: dict
        Maps each state (as a tuple) to a dict {action: value}; empty for terminal states
    policy: dict
        {-1: policy for X, 1: policy for O}, each mapping a non-terminal state
        (as a tuple) to the chosen cell index
    '''

    # Optimisation type for each player. Defined before the sweeps so the
    # policy extraction below also works when max_iterations is 0
    optim = {-1: min, 1: max}

    # Convergence threshold used when epsilon_check is set
    tolerance = 1e-17

    # Tuples are needed as dict keys; convert once instead of on every lookup
    keys = [tuple(state) for state in states]

    # Terminal status never changes between sweeps, so it is computed once
    not_terminal = State.NOT_TERMINAL.value
    is_terminal = {key: is_board_terminal(key) != not_terminal for key in keys}

    # Initialise tables (terminal states have no actions)
    v_table = {key: 0 for key in keys}
    q_table = {
        key: {} if is_terminal[key] else {action: 0 for action in get_actions(key)}
        for key in keys
    }

    for _ in range(max_iterations):
        # Work on copies so every update in this sweep reads the previous sweep's
        # values. Shallow copies suffice: values are numbers and the q rows are
        # replaced as a whole, never edited in place
        new_v = dict(v_table)
        new_q = dict(q_table)

        # largest value change seen in this sweep, to check for convergence
        max_change = 0

        for key, turn in zip(keys, turns):
            # skip if terminal state as no actions possible
            if is_terminal[key]:
                continue

            actions = get_actions(key)

            # value of each action = discounted (reward + value) of the state it leads to
            action_values = []
            for action in actions:
                successor = list(key)
                successor[action] = turn
                successor = tuple(successor)

                action_values.append(gamma * (get_reward(successor) + v_table[successor]))

            # choose best option for the player to move
            best_v = optim[turn](action_values)

            max_change = max(max_change, abs(new_v[key] - best_v))

            new_v[key] = best_v
            new_q[key] = dict(zip(actions, action_values))

        # publish the results of this sweep
        v_table = new_v
        q_table = new_q

        if epsilon_check and max_change < tolerance:
            break

    # find policies: for every state with actions, the first action whose
    # value is best for the player to move
    policy = {-1: {}, 1: {}}
    for key, turn in zip(keys, turns):
        q_row = q_table[key]

        # terminal states have no actions and get no policy entry
        if q_row:
            policy[turn][key] = optim[turn](q_row, key=q_row.get)

    return v_table, q_table, policy


# Solve the game: 100 sweeps with a discount of 0.99. epsilon_check keeps its
# default (False), so all sweeps are run
v_table, q_table, policy = value_iteration(states, turns, max_iterations=100, gamma=0.99)

In [5]:
# Self-play evaluation: both players follow the computed policy and the share
# of drawn games is reported
env = TicTacToe()
draws = 0
games = 100
for _ in range(games):
    env.reset()
    while True:
        # env.print_board() # Uncomment this line if you want to see the board

        # Look up the move for the player whose turn it is
        position = policy[env.current_player][tuple(env.board)]

        # The policy is deterministic: asking it again for the same board returns
        # the same move, so retrying after a bad move would loop forever.
        # Stop with an error instead. An occupied cell is checked here as well
        # because env.step would reject it without ending the game
        if not (0 <= position < 9) or env.board[position] != 0:
            raise ValueError(
                f"Policy returned invalid move {position} for player "
                f"{env.current_player} on board {env.board}"
            )

        board, player, terminated = env.step(position)

        if terminated:
            # On termination `player` is the winner, or 0 for a draw
            if player == -1:
                print("Player -1 wins")
            elif player == 1:
                print("Player 1 wins")
            elif player == 0:
                draws += 1
            break


print(f"Draws: {draws/games*100}%")

Draws: 100.0%


In [26]:
# Using the GUI version is optional. When testing your algorithm we will use code similar to the cell above.

import tkinter as tk
from tkinter import messagebox

class TicTacToeGUI:
    '''
    Tkinter front end for playing against the computed policy.

    The human clicks a cell to move as the current player of a fresh TicTacToe
    game (X moves first); the policy then answers immediately for the other player.
    '''

    def __init__(self, root):
        '''
        Parameters
        ----------
        root: tk.Tk
            The window that hosts the 3x3 grid of buttons
        '''
        self.root = root
        self.root.title("Tic Tac Toe")
        self.game = TicTacToe()
        self.buttons = []  # row-major, so buttons[i] shows game.board[i]
        self.create_board()

    def create_board(self):
        '''Creates the nine buttons and lays them out as a 3x3 grid.'''
        for i in range(3):
            for j in range(3):
                # row/col are bound as default arguments so that each button
                # keeps its own coordinates instead of the last loop values
                button = tk.Button(self.root, text=" ", font=('Arial', 20), width=5, height=2,
                                   command=lambda row=i, col=j: self.on_click(row, col))
                button.grid(row=i, column=j)
                self.buttons.append(button)

    def update_board(self):
        '''Redraws every button label from the current game board.'''
        symbols = {-1: 'X', 1: 'O'}
        for button, cell in zip(self.buttons, self.game.board):
            button['text'] = symbols.get(cell, ' ')

    def _finish_if_terminated(self, player, terminated):
        '''
        Shows the result dialog and quits the main loop if the game is over.
        Returns True if the game was over, False otherwise.
        '''
        if not terminated:
            return False

        # On termination `player` is the winner, or 0 for a draw
        messages = {-1: "Player -1 wins", 1: "Player 1 wins", 0: "It's a draw"}
        if player in messages:
            messagebox.showinfo("Game Over", messages[player])
        self.root.quit()
        return True

    def on_click(self, row, col):
        '''
        Plays the human move at (row, col) and then the policy's reply

        Parameters
        ----------
        row: int
            Row of the clicked cell (0-2)
        col: int
            Column of the clicked cell (0-2)
        '''
        position = row * 3 + col

        # Ignore clicks on occupied cells. The game would reject the move without
        # changing the turn, and the reply below would then look up a board on
        # which it is not the policy's turn
        if self.game.board[position] != 0:
            return

        board, player, terminated = self.game.step(position)
        self.update_board()

        if self._finish_if_terminated(player, terminated):
            return

        # The game continues, so `player` is now the side that replies
        board, player, terminated = self.game.step(policy[player][tuple(board)])
        self.update_board()

        self._finish_if_terminated(player, terminated)

# Launch the GUI: create the window, attach the game to it and hand control to Tk's event loop
if __name__ == '__main__':
    root = tk.Tk()
    tic_tac_toe_gui = TicTacToeGUI(root)
    root.mainloop()


You are required to solve Tic Tac Toe using **Value Iteration**. The optimal policy should never lose; when both players follow it, every game ends in a draw.  

It should work whether you are player -1 or player 1. It doesn't matter what your turn is.

It also follows that your code should win from any board state where a win is possible, and make optimal decisions everywhere else.

Our final goal is to solve for 3d Tic Tac Toe, 4 X 4 X 4, using reinforcement learning. But, we will start by solving the 2d case and then gradually build up to the 3d case.

There is no starter code available. You are free to choose your implementation. One suggestion is to give +1 reward for a win, 0 for a draw and -1 for a loss.