In [2]:
import numpy as np
import time

In [3]:
from enum import Enum
import numpy as np
from scipy import signal
from typing import Callable, Optional, Tuple

BoardPiece = np.int8  # The data type (dtype) of the board
NO_PLAYER = BoardPiece(
    0)  # board[i, j] == NO_PLAYER where the position is empty
PLAYER1 = BoardPiece(
    1
)  # board[i, j] == PLAYER1 where player 1 (player to move first) has a piece
PLAYER2 = BoardPiece(
    -1
)  # board[i, j] == PLAYER2 where player 2 (player to move second) has a piece
rows = BoardPiece(6)
columns = BoardPiece(7)

BoardPiecePrint = str  # dtype for string representation of BoardPiece
NO_PLAYER_PRINT = BoardPiecePrint(' ')
PLAYER1_PRINT = BoardPiecePrint(f'X')
PLAYER2_PRINT = BoardPiecePrint('O')

PlayerAction = np.int8  # The column to be played

class GameState(Enum):
    IS_WIN = 1
    IS_DRAW = 2
    STILL_PLAYING = 0

class SavedState:
    pass

GenMove = Callable[
    [np.ndarray, BoardPiece, Optional[SavedState]],  # Arguments for the generate_move function
    Tuple[PlayerAction, Optional[SavedState]]  # Return type of the generate_move function
]


def initialize_game_state() -> np.ndarray:
    """
    Returns an ndarray, shape (6, 7) and data type (dtype) BoardPiece, initialized to 0 (NO_PLAYER).
    """
    return np.full((rows, columns), NO_PLAYER, dtype=BoardPiece)


def pretty_print_board(board: np.ndarray) -> str:
    """
    Should return `board` converted to a human readable string representation,
    to be used when playing or printing diagnostics to the console (stdout). The piece in
    board[0, 0] should appear in the lower-left. Here's an example output, note that we use
    PLAYER1_Print to represent PLAYER1 and PLAYER2_Print to represent PLAYER2):
    |==============|
    |              |
    |              |
    |    X X       |
    |    O X X     |
    |  O X O O     |
    |  O O X X     |
    |==============|
    |0 1 2 3 4 5 6 |
    """
    board = board[::-1] 
    pp_board = f"|==============|\n"  # add top border of board
    for row in range(board.shape[0]):
        pp_board += f"|" + np.array2string(board[row, :]).replace(
            '[', '').replace(']', '').replace('-1', PLAYER2_PRINT) + " |\n"  # replace player values with strings: -1 value allows use of convolutions

        if np.any(board[row, :] == PLAYER2):  # handle both -1 and 1 cases, since -1 is two characters
            pp_board = pp_board.replace(' 0', NO_PLAYER_PRINT).replace(
                ' 1', PLAYER1_PRINT)
        else:
            pp_board = pp_board.replace('0', NO_PLAYER_PRINT).replace(
                '1', PLAYER1_PRINT)

    pp_board += f"|==============|\n|0 1 2 3 4 5 6 |"  # add bottom border
    return pp_board


def string_to_board(pp_board: str) -> np.ndarray:
    """
    Takes the output of pretty_print_board and turns it back into an ndarray.
    This is quite useful for debugging, when the agent crashed and you have the last
    board state as a string.
    """
    board = pp_board[18:-33].replace('|\n|', '').replace('|', '')[0:-1:2]  # remove borders of pretty print board
    board = np.reshape([
        PLAYER1 if board[i] == PLAYER1_PRINT else  # retrieve player values from string board
        PLAYER2 if board[i] == PLAYER2_PRINT else NO_PLAYER
        for i in range(len(board))
    ], (rows, columns))
    return board[::-1]


def apply_player_action(board: np.ndarray, action: PlayerAction,
                        player: BoardPiece) -> np.ndarray:
    """
    Sets board[i, action] = player, where i is the lowest open row. Raises a ValueError
    if action is not a legal move. If it is a legal move, the modified version of the
    board is returned and the original board should remain unchanged (i.e., either set
    back or copied beforehand).
    """

    if action > columns - 1 or action < 0:
        raise ValueError('Action outside of board.')

    if ~np.any(board[:, action] == NO_PLAYER):
        raise ValueError('Column is already full.')

    modified_board = board.copy()
    for i in range(rows):
        if modified_board[:, action][i] == NO_PLAYER:  # find first non-filled space in column
            modified_board[:, action][i] = player
            break

    return modified_board


def connected_four(board: np.ndarray, player: BoardPiece) -> bool:
    """
    Returns True if there are four adjacent pieces equal to `player` arranged
    in either a horizontal, vertical, or diagonal line. Returns False otherwise.
    """
    vertical_kernel = np.ones((4, 1), dtype=BoardPiece)
    horizontal_kernel = np.ones((1, 4), dtype=BoardPiece)
    diagonal_kernel = np.eye(4, dtype=BoardPiece)
    off_diagonal_kernel = np.eye(4, dtype=BoardPiece)[::-1]

    is_win = False
    if np.any(signal.convolve2d(board, vertical_kernel, 'same') == player * 4):
        is_win = True
    elif np.any(signal.convolve2d(board, horizontal_kernel, 'same') == player * 4):
        is_win = True
    elif np.any(signal.convolve2d(board, diagonal_kernel, 'same') == player * 4):
        is_win = True
    elif np.any(signal.convolve2d(board, off_diagonal_kernel, 'same') == player * 4):
        is_win = True
    return is_win


def check_end_state(board: np.ndarray, player: BoardPiece) -> GameState:
    """
    Returns the current game state for the current `player`, i.e. has their last
    action won (GameState.IS_WIN) or drawn (GameState.IS_DRAW) the game,
    or is play still on-going (GameState.STILL_PLAYING)?
    """
    if connected_four(board, player):
        return GameState.IS_WIN

    elif np.all(board != NO_PLAYER):
        return GameState.IS_DRAW

    else:
        return GameState.STILL_PLAYING


def is_terminal_board(board: np.ndarray, player: BoardPiece) -> bool:
    """
    Returns True only if the game is over at the current state.
    """
    if check_end_state(board, player) != GameState.STILL_PLAYING:
        return True
    else:
        return False


In [4]:
disable_jit = True
if disable_jit:
    import os

    os.environ["NUMBA_DISABLE_JIT"] = "1"

from agents.common import (
    PLAYER1,
    PLAYER2,
    NO_PLAYER,
    GameState,
    BoardPiece,
    SavedState,
    PlayerAction,
    check_end_state,
    is_terminal_board,
    apply_player_action,
)
import numpy as np
import time
from numba import njit
from typing import Optional, Tuple


def get_valid_actions(board):
    return np.where(board[-1, :] == NO_PLAYER)[0]


def other_player(player):
    return PLAYER1 if player is PLAYER2 else PLAYER2


class Node:
    def __init__(self, board, parent, player):

        # State details
        self.board = board
        self.player = player
        self.parent = parent
        self.children = {}
        self.unplayed_actions = get_valid_actions(board)
        if is_terminal_board(board, other_player(player)) or is_terminal_board(
            board, player
        ):
            self.is_terminal = True
        else:
            self.is_terminal = False

        # Monte Carlo Metrics
        self.visits = 0
        self.wins = 0


# Monte Carlo Tree Search
class MCTS(object):
    def __init__(
        self,
        current_player,
        current_board,
        iterations = 2000,
        timeout = 5,
        exploration_const=1 / np.sqrt(2),
    ):

        self.iterations = iterations
        self.exploration_const = exploration_const
        self.current_player = current_player
        self.rootnode = Node(current_board, None, current_player)

    @njit()
    def get_best_action(self):
        for _ in range(self.iterations):
            node = self.select()
            result = self.simulate(node)
            self.backpropogate(node, result)

        most_visits = -1
        for action in self.rootnode.children:
            child_visits = self.rootnode.children[action].visits
            if child_visits > most_visits:
                best_action = action
                most_visits = child_visits

        return best_action

    @njit()
    def select(self):
        node = self.rootnode
        while not node.is_terminal:
            if len(node.unplayed_actions) == 0:
                node = self.get_best_child(node)
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        action = np.random.choice(node.unplayed_actions)
        child_board = apply_player_action(node.board, action, node.player)
        node.unplayed_actions = node.unplayed_actions[node.unplayed_actions != action]
        child_node = Node(child_board, node, other_player(node.player))
        node.children[action] = child_node
        return child_node

    @njit()
    def simulate(self, node):
        board = node.board
        player = node.player
        while len(get_valid_actions(board)) != 0 and not is_terminal_board(
            board, other_player(player)
        ):
            action = np.random.choice(get_valid_actions(board))
            board = apply_player_action(board, action, player)
            player = other_player(player)
        result = (
            1 if check_end_state(board, self.current_player) is GameState.IS_WIN else 0
        )
        return result

    @njit()
    def backpropogate(self, node, result):
        while node is not None:
            node.visits += 1
            if other_player(node.player) == self.current_player:
                node.wins += result
            node = node.parent

    @njit()
    def get_best_child(self, node):
        best_score = np.NINF
        for child in node.children.values():
            move_score = child.wins / child.visits + self.exploration_const * np.sqrt(
                np.log(node.visits / child.visits)
            )
            if move_score > best_score:
                best_score = move_score
                best_child = child

        return best_child


def generate_move_mcts(
    board: np.ndarray,
    player: BoardPiece,
    saved_state: Optional[SavedState],
    iterations=2000,
    timeout=4.5
) -> Tuple[PlayerAction, Optional[SavedState]]:
    """
    Runs the mcts algorithm and returns best action.
    """
    mcts_search = MCTS(player, board, iterations, timeout)
    action = mcts_search.get_best_action()

    return action, saved_state


In [48]:
disable_jit = True
if disable_jit:
    import os

    os.environ["NUMBA_DISABLE_JIT"] = "1"

import numpy as np
import time
from numba import njit
from functools import cache
from typing import Optional, Tuple


def get_valid_actions(board):
    return np.where(board[-1, :] == NO_PLAYER)[0]


def other_player(player):
    return PLAYER1 if player is PLAYER2 else PLAYER2


class Node:
    def __init__(self, board, parent, player):

        # State details
        self.board = board
        self.player = player
        self.parent = parent
        self.children = {}
        self.unplayed_actions = get_valid_actions(board)
        if is_terminal_board(board, other_player(player)) or is_terminal_board(
            board, player
        ):
            self.is_terminal = True
        else:
            self.is_terminal = False

        # Monte Carlo Metrics
        self.visits = 0
        self.wins = 0


# Monte Carlo Tree Search
class MCTS(object):
    def __init__(
        self,
        current_player,
        current_board,
        iterations = 2000,
        timeout = 5,
        exploration_const=1 / np.sqrt(2),
    ):

        self.iterations = iterations
        self.exploration_const = exploration_const
        self.current_player = current_player
        self.rootnode = Node(current_board, None, current_player)

    @cache()
    @njit()
    def get_best_action(self):
        for _ in range(self.iterations):
            node = self.select()
            result = self.simulate(node)
            self.backpropogate(node, result)

        most_visits = -1
        for action in self.rootnode.children:
            child_visits = self.rootnode.children[action].visits
            if child_visits > most_visits:
                best_action = action
                most_visits = child_visits

        return best_action

    @njit()
    def select(self):
        node = self.rootnode
        while not node.is_terminal:
            if len(node.unplayed_actions) == 0:
                node = self.get_best_child(node)
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        action = np.random.choice(node.unplayed_actions)
        child_board = apply_player_action(node.board, action, node.player)
        node.unplayed_actions = node.unplayed_actions[node.unplayed_actions != action]
        child_node = Node(child_board, node, other_player(node.player))
        node.children[action] = child_node
        return child_node

    @njit()
    def simulate(self, node):
        board = node.board
        player = node.player
        while len(get_valid_actions(board)) != 0 and not is_terminal_board(
            board, other_player(player)
        ):
            action = np.random.choice(get_valid_actions(board))
            board = apply_player_action(board, action, player)
            player = other_player(player)
        result = (
            1 if check_end_state(board, self.current_player) is GameState.IS_WIN else 0
        )
        return result

    @njit()
    def backpropogate(self, node, result):
        while node is not None:
            node.visits += 1
            if other_player(node.player) == self.current_player:
                node.wins += result
            node = node.parent

    @njit()
    def get_best_child(self, node):
        best_score = np.NINF
        for child in node.children.values():
            move_score = child.wins / child.visits + self.exploration_const * np.sqrt(
                np.log(node.visits / child.visits)
            )
            if move_score > best_score:
                best_score = move_score
                best_child = child

        return best_child


def generate_move_mcts(
    board: np.ndarray,
    player: BoardPiece,
    saved_state: Optional[SavedState],
    iterations=2000,
    timeout=4.5
) -> Tuple[PlayerAction, Optional[SavedState]]:
    """
    Runs the mcts algorithm and returns best action.
    """
    mcts_search = MCTS(player, board, iterations, timeout)
    action = mcts_search.get_best_action()

    return action, saved_state


TypeError: cache() missing 1 required positional argument: 'user_function'

In [41]:
mcts = MCTS(PLAYER1, initialize_game_state(), iterations=2000, timeout=5)
a = mcts.get_best_action()


In [32]:
mcts.select()

<__main__.Node at 0x22bc5def370>

In [54]:
board[np.where(board == -PLAYER1)] = 0

In [55]:
board

array([[0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0]], dtype=int8)