In [None]:
import numpy as np
import matplotlib.pyplot as plt
import copy
import random
import json
import os

# Path of the JSON file that the computed policy table is written to
# (relative to the notebook's working directory)
file_path = 'phase_1_policy.json'

In [2]:
class TicTacToe:
    """
    Minimal tic-tac-toe environment on a flat list of 9 cells.

    Cell values: 0 = empty, -1 = first player, 1 = second player.
    `current_player` holds the mark of the player who moves next; the game starts with -1.
    """

    def __init__(self):
        # Empty board, player -1 to move.
        self.board = [0] * 9
        self.current_player = -1

    def print_board(self):
        """Prints the board as three rows separated by dashed lines, followed by a blank line."""
        for row_start in (0, 3, 6):
            row = self.board[row_start:row_start + 3]
            print("|".join(str(cell) for cell in row))
            # No separator after the bottom row.
            if row_start < 6:
                print("-" * 5)

        print()

    def check_win(self, player):
        """Returns True if `player` occupies all three cells of any row, column, or diagonal."""
        lines = (
            (0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
            (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
            (0, 4, 8), (2, 4, 6),              # diagonals
        )

        for a, b, c in lines:
            if self.board[a] == player and self.board[b] == player and self.board[c] == player:
                return True
        return False

    def step(self, position):
        """
        Places the current player's mark at `position` and advances the game.

        Returns a tuple (board, player, terminated):
        - win:  (board, winning player, True)
        - draw: (board, 0, True) once no empty cell is left
        - otherwise: (board, player to move next, False)
        If the cell is already occupied, a message is printed and the state is
        returned unchanged with terminated = False.
        """
        if self.board[position] != 0:
            print("Cell already occupied. Try again.")
            return self.board, self.current_player, False

        mover = self.current_player
        self.board[position] = mover

        if self.check_win(mover):
            return self.board, mover, True
        if 0 not in self.board:
            return self.board, 0, True

        # Hand the turn to the other player.
        self.current_player = 1 if mover == -1 else -1
        return self.board, self.current_player, False

    def reset(self):
        """Restores the initial state (empty board, player -1 to move)."""
        self.__init__()

In [3]:
# This cell defines all the helper functions needed for this environment
# Index triples into the flat 9-cell state: 3 rows, 3 columns, then the 2 diagonals.
# win_state treats a state as won when one player fills all three cells of any triple.
win_conditions = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]

def generate_state_space(steps):
    """
    Generates the state space of a 2d tic-tac-toe, stage by stage.
    1) Iterates over each stage of the game assuming "Player -1" starts first
    2) Alternates between "Player -1" and "Player +1" for each stage
    3) Does not expand a state that was already won by the player who moved last
       (such a state is terminal and has no successors)
    4) Removes duplicates by keeping only the unique states present at each stage
    5) Returns a list of steps + 1 arrays: the empty board (shape = (steps,)) followed by
       one 2d array per stage (shape = (number of states in the stage, steps))

    Note: win detection goes through win_state, whose win_conditions index cells 0..8,
    so the function is meant to be called with steps = 9.
    """
    state_space = [np.array([0 for _ in range(steps)])]

    for j in range(steps):
        if j == 0:
            # First stage: "Player -1" places a mark on each cell of the empty board.
            curr_action = -1
            step_state = []
            for i in range(steps):
                step = [0 for _ in range(steps)]
                step[i] = curr_action
                step_state.append(step)
            step_state = np.array(step_state)
        else:
            curr_action = -1 * prev_action
            children = []
            for parent in prev_step_state:
                # The win check depends only on the parent, so do it once per
                # parent instead of once per candidate cell.
                if win_state(parent, prev_action):
                    continue
                for k in range(steps):
                    if parent[k] == -1 or parent[k] == 1:
                        continue
                    child = parent.copy()
                    child[k] = curr_action
                    children.append(child)

            if children:
                # np.unique with axis=0 drops duplicate boards (different move
                # orders reach the same board) and sorts the rows.
                step_state = np.unique(np.array(children), axis=0)
            else:
                # No parent could be expanded: keep a well-formed empty stage.
                step_state = np.empty((0, steps), dtype=int)

        state_space.append(step_state)
        prev_action = curr_action
        prev_step_state = step_state

    return state_space

def get_value_table(num_states, total_state_space, terminal_states, gamma, max_iterations=100, tol=0.0):
    """
    Runs value iteration for the 2d tic-tac-toe.
    1) Avoids updating terminal states (their value stays 0)
    2) Assumes the system is deterministic, so the successors and rewards of every
       state are computed once up front instead of on every sweep
    3) "Player -1" minimizes and "Player +1" maximizes reward + gamma * value of the next state
    4) Values are updated in place, in state-space order, within each sweep
    5) Stops after max_iterations sweeps, or earlier once a full sweep changes no value
       by more than tol (with tol = 0.0 the result equals running all sweeps)
    6) Returns a value table (shape = (num_states,)) with the value of each state in state space

    Arguments:
    num_states        -- number of rows in total_state_space
    total_state_space -- 2d array with one state per row
    terminal_states   -- 2d array with one terminal state per row
    gamma             -- discount factor
    max_iterations    -- maximum number of sweeps over the state space (default 100)
    tol               -- largest per-sweep change that still counts as converged (default 0.0)
    """
    value_table = np.zeros(num_states)

    # Terminal membership never changes between sweeps, so evaluate it once.
    is_terminal = [
        bool(np.any(np.all(state == terminal_states, axis=1)))
        for state in total_state_space
    ]

    # For every non-terminal state: (action, immediate rewards, successor row indices).
    # Dict insertion order keeps the sweep in state-space order.
    transitions = {}
    for i, state in enumerate(total_state_space):
        if is_terminal[i]:
            continue

        action, position = get_possible_actions(state)
        rewards = []
        successors = []
        for pos in position:
            next_state = find_next_state(state, action, pos)
            successors.append(find_next_state_index(next_state, total_state_space))
            rewards.append(get_reward(next_state, action))

        transitions[i] = (action, np.array(rewards, dtype=float), np.array(successors, dtype=int))

    for _ in range(max_iterations):
        largest_change = 0.0

        for i, (action, rewards, successors) in transitions.items():
            candidates = rewards + gamma * value_table[successors]

            if action == -1:
                new_value = candidates.min()
            elif action == 1:
                new_value = candidates.max()
            else:
                continue

            largest_change = max(largest_change, abs(new_value - value_table[i]))
            value_table[i] = new_value

        # A sweep that changes nothing means every later sweep would change nothing either.
        if largest_change <= tol:
            break

    return value_table

def get_policy(state_space, terminal_states, value_table, gamma):
    """
    Builds the greedy policy for every state of the state space.
    1) Each entry is a list [state, action value, position where the value is placed]
    2) "Player -1" picks the position with the smallest reward + gamma * value,
       "Player +1" the position with the largest; ties go to the first such position
    3) Terminal states get the placeholder entry [state, 0, 0]
    """
    policy_table = []

    for state in state_space:
        is_terminal = np.any(np.all(state == terminal_states, axis=1))

        if is_terminal:
            # Nothing left to play from a terminal state.
            policy_table.append([state.tolist(), int(0), int(0)])
            continue

        action, position = get_possible_actions(state)

        # One-step lookahead value of every legal position.
        lookahead = []
        for pos in position:
            successor = find_next_state(state, action, pos)
            successor_idx = find_next_state_index(successor, state_space)
            lookahead.append(get_reward(successor, action) + gamma * value_table[successor_idx])

        if action == -1:
            best = min(lookahead)
            idx = lookahead.index(best)
        elif action == 1:
            best = max(lookahead)
            idx = lookahead.index(best)

        policy_table.append([state.tolist(), int(action), int(position[idx])])

    return policy_table

def win_state(state, action):
    """
    Tells whether `action` has completed a winning line in `state`.
    1) Takes a state and the mark (action value) to check for
    2) Goes through win_conditions and stops at the first line fully occupied by that mark
    3) Returns True if such a line exists, False otherwise
    """
    return any(
        all(state[cell] == action for cell in line)
        for line in win_conditions
    )

def get_reward(state, action):
    """
    Reward of reaching `state` by playing `action`.
    1) Takes a state and the action value that produced it
    2) Returns the action value itself (-1 for player 1, +1 for player 2) when that
       player has a winning line in the state, and 0 otherwise
    """
    return action if win_state(state, action) else 0

def find_next_state_index(state, ss):
    """
    Looks up a state in the state space.
    1) Takes a state and the state space "ss" (one state per row)
    2) Returns the index of the first row equal to the state
    3) Raises IndexError when no row matches
    """
    row_matches = np.all(state == ss, axis=1)
    matching_rows = np.flatnonzero(row_matches)
    return matching_rows[0]

def get_possible_actions(state):
    """
    Determines whose turn it is and where a mark can be placed.
    1) Takes a state as an argument
    2) The action is -1 unless the state holds more -1 marks than +1 marks, in which case it is +1
    3) The valid positions are the indices of the cells equal to 0
    4) Returns the action and an array of positions
    """
    minus_marks = np.count_nonzero(state == -1)
    plus_marks = np.count_nonzero(state == 1)

    action = -1 if minus_marks <= plus_marks else 1
    position = np.nonzero(state == 0)[0]
    return action, position

def find_next_state(state, action, pos):
    """
    Applies an action to a state without modifying the state passed in.
    1) Takes a state, the action value, and the position where it is placed
    2) Returns a deep copy of the state with the action value written at that position
    """
    successor = copy.deepcopy(state)
    successor[pos] = action
    return successor

def get_terminal_states(ss):
    """
    Collects the terminal states of the state space.
    1) Takes the state space as a list of stages
    2) For every stage except stage 9, keeps the states in which either player has a winning line
    3) Keeps every state of stage 9, won or not
    4) Returns a numpy array of the collected terminal states
    """
    last_stage = 9
    terminal_states = []

    def has_winner(board):
        # True when either player has completed a line on the board.
        return win_state(board, -1) or win_state(board, 1)

    for stage, stage_states in enumerate(ss):
        if stage == 0:
            # Stage 0 is handled as one single state rather than a collection of states.
            if has_winner(stage_states):
                terminal_states.append(stage_states)
            continue

        for state in stage_states:
            if stage == last_stage or has_winner(state):
                terminal_states.append(state)

    return np.array(terminal_states)

def get_position(state, policy_table):
    """
    Returns the optimum position for the player from the policy table.
    1) Takes a state and a policy table whose entries are [state, action value, position]
    2) Returns the position of the first entry whose state equals the given state
       (policy tables built by get_policy hold each state once, so there is at most one match)
    3) Raises ValueError if the state is not present in the policy table
    """
    for entry in policy_table:
        if np.array_equal(entry[0], state):
            # Stop at the match instead of scanning the rest of the table.
            return entry[2]

    raise ValueError(f"State {np.asarray(state).tolist()} not found in policy table")

In [4]:
# This cell computes the state space, the terminal states, and the value table for this problem

# One array of states per stage of the game, for a board of 9 cells
ss = generate_state_space(9)
# Stack all stages into a single 2d array with one state per row
total_state_space = np.vstack(ss)
num_states = np.shape(total_state_space)[0]

terminal_states = get_terminal_states(ss)
# Discount factor used by value iteration and by the policy extraction
gamma = 0.99

value_table = get_value_table(num_states, total_state_space, terminal_states, gamma)

In [5]:
# This cell computes the policy for each state and saves it as JSON
policy = get_policy(total_state_space, terminal_states, value_table, gamma)

# Overwrites file_path with the policy table (a list of [state, action, position] entries)
with open(file_path, 'w') as file:
    json.dump(policy, file)

In [11]:
# Self-play evaluation: both players take their moves from the computed policy.
env = TicTacToe()
terminated = False
player_1_wins = 0
player_2_wins = 0
draws = 0

for i in range(1000):
    # The environment is fresh for the first game; afterwards clear the finished board.
    if terminated:
        env.reset()
        terminated = False

    # Keep playing policy moves until env.step reports the end of the game.
    while not terminated:
        position = get_position(env.board, policy)
        board, player, terminated = env.step(position)

    # On termination env.step returns the winner (-1 or 1), or 0 for a draw.
    if player == -1:
        player_1_wins += 1
    elif player == 1:
        player_2_wins += 1
    elif player == 0:
        draws += 1

print(f"Draws: {draws}")
print(f"Player 1 wins: {player_1_wins}")
print(f"Player 2 wins: {player_2_wins}")

Draws: 1000
Player 1 wins: 0
Player 2 wins: 0
