In [None]:
import math
import random


class MCTSNode:
    """
    A single node of the Monte Carlo search tree.

    Attributes:
        game_state: the state object this node was created with. Only a
            reference is kept, so the caller should not mutate it afterwards.
        move: the move that led from the parent to this node (None for a root).
        parent: the parent MCTSNode, or None for a root.
        children: child nodes created so far through add_child().
        wins: running sum of every result passed to update().
        visits: number of times update() has been called.
        untried_moves: moves from this state that have no child node yet.
    """

    def __init__(self, game_state, move=None, parent=None):
        """
        Initialize the node with the game state, move, and parent.

        game_state must provide get_legal_moves(); its return value seeds
        the list of untried moves.
        """
        self.game_state = game_state
        self.move = move
        self.parent = parent
        self.children = []
        self.wins = 0
        self.visits = 0
        # Private copy: add_child() removes entries from this list and must
        # never modify a list that is owned by the game object.
        self.untried_moves = list(game_state.get_legal_moves())

    def select_child(self):
        """
        Select the child node with the highest UCT score.

        The score is wins / visits + sqrt(2 * ln(parent visits) / visits).
        A child that has never been visited is given an infinite score so it
        is tried first instead of causing a division by zero.

        Raises:
            ValueError: if the node has no children.
        """
        # log(0) is undefined; an unvisited parent contributes no exploration bonus.
        log_parent_visits = math.log(max(self.visits, 1))

        def uct_score(child):
            if child.visits == 0:
                return math.inf
            exploitation = child.wins / child.visits
            exploration = math.sqrt(2 * log_parent_visits / child.visits)
            return exploitation + exploration

        return max(self.children, key=uct_score)

    def add_child(self, move, state):
        """
        Remove the move from untried moves, create a new child node, and add it to children.

        Args:
            move: the move being expanded; must currently be in untried_moves.
            state: the game state reached after playing move.

        Returns:
            The newly created child node.

        Raises:
            ValueError: if move is not in untried_moves.
        """
        self.untried_moves.remove(move)
        child = MCTSNode(game_state=state, move=move, parent=self)
        self.children.append(child)
        return child

    def update(self, result):
        """
        Update this node - increment the visit count and add result to the win total.
        """
        self.visits += 1
        self.wins += result


def mcts(root_state, iterations):
    """
    MCTS algorithm. Selection, expansion, simulation, and backpropagation.

    Args:
        root_state: the position to search from. It is never modified; every
            iteration plays on a copy. It must provide copy(),
            claim_island(move), is_game_over(), get_legal_moves(),
            get_result(player) and a player_turn attribute.
        iterations: number of search iterations to run.

    Returns:
        The move of the most-visited root child, or None when the root has no
        children (no iterations were run, or the root position is already over).

    Side effects:
        Prints the visit count and average result of every root child.
    """
    root_node = MCTSNode(game_state=root_state)

    for _ in range(iterations):
        node = root_node
        state = root_state.copy()

        # selection: walk down while the node is fully expanded and has children
        while not node.untried_moves and node.children:
            node = node.select_child()
            state.claim_island(node.move)

        # expansion: never grow the tree past a finished game
        if node.untried_moves and not state.is_game_over():
            move = random.choice(node.untried_moves)
            state.claim_island(move)
            # The node gets its own snapshot because `state` keeps changing
            # during the rollout below.
            node = node.add_child(move, state.copy())

        # simulation: random playout until the game ends
        while not state.is_game_over():
            possible_moves = state.get_legal_moves()
            if not possible_moves:
                # Nothing left to play; random.choice([]) would raise.
                break
            state.claim_island(random.choice(possible_moves))

        # backpropagation: score each node for the player who made the move
        # into it (the player to move at its parent), so that select_child()
        # maximises the outcome of whoever is actually choosing at that point.
        while node is not None:
            parent = node.parent
            if parent is None:
                # The root's own score is never used for selection.
                mover = root_state.player_turn
            else:
                mover = parent.game_state.player_turn
            node.update(state.get_result(mover))
            node = parent

    for child in root_node.children:
        print(f"Move {child.move}: {child.visits} visits, {child.wins / child.visits:.2f} win rate")

    if not root_node.children:
        return None
    return sorted(root_node.children, key=lambda c: c.visits)[-1].move

In [None]:
class IslandConquest:
    """
    Two-player game played on four islands numbered 1-4.

    Each island is 0 (unclaimed), 1 (Player 1) or -1 (Player 2). Players
    alternate claiming one unclaimed island, with Player 1 moving first.
    """

    def __init__(self):
        """
        Initialize the game with all islands unclaimed
        0 for unclaimed, 1 for Player 1, -1 for Player 2
        Player 1 starts
        """
        self.islands = [0, 0, 0, 0]
        self.player_turn = 1

    def claim_island(self, island):
        """
        Player claims an island. Island numbers are 1-4.
        Switch turns after a valid move.

        Returns:
            True if the island was claimed. False if the number is out of
            range, the game is already over, or the island is already
            claimed; the game is left unchanged in those cases.
        """
        # Reject out-of-range numbers explicitly: 0 or a negative number would
        # otherwise wrap around through negative indexing and claim the wrong island.
        if not 1 <= island <= len(self.islands):
            return False
        if self.is_game_over():
            return False
        if self.islands[island - 1] != 0:
            return False  # Island already claimed

        self.islands[island - 1] = self.player_turn
        self.player_turn *= -1
        return True

    def check_win(self):
        """
        Return the result of the game (winner/tie/game continues).

        A pair of islands sums to +2 when Player 1 holds both and to -2 when
        Player 2 holds both.
        - Pairs (1, 4) and (1, 2): whoever holds both wins.
        - Pairs (2, 3) and (3, 4): whoever holds both loses.
        - Tie: islands 1 and 3 share an owner and islands 2 and 4 share an owner.
        The pairs are checked in that order and the first match decides.
        """
        win_conditions = [(1, 4), (1, 2)]
        lose_conditions = [(2, 3), (3, 4)]

        def pair_sum(condition):
            return sum(self.islands[i - 1] for i in condition)

        for condition in win_conditions:
            total = pair_sum(condition)
            if total == len(condition):
                return "Player 1 wins"
            if total == -len(condition):
                return "Player 2 wins"

        for condition in lose_conditions:
            total = pair_sum(condition)
            if total == len(condition):
                return "Player 2 wins"
            if total == -len(condition):
                return "Player 1 wins"

        odd_pair = self.islands[0] + self.islands[2]
        even_pair = self.islands[1] + self.islands[3]
        if odd_pair in (-2, 2) and even_pair in (-2, 2):
            return "It's a tie"

        return "Game continues"

    def is_game_over(self):
        """
        Check if the game is over.
        """
        return self.check_win() != "Game continues"

    def current_state(self):
        """
        Print the current state of the game.
        """
        state = " ".join(["1" if x == 1 else "2" if x == -1 else "0" for x in self.islands])
        print(f"Islands: {state}")
        print(f"Player {'1' if self.player_turn == 1 else '2'}'s turn")
        print(self.check_win())

    def get_legal_moves(self):
        """
        Get the list of unclaimed islands (1-based).

        Returns an empty list once the game is over, so that neither a search
        nor a caller keeps playing a decided game.
        """
        if self.is_game_over():
            return []
        return [i + 1 for i, x in enumerate(self.islands) if x == 0]

    def copy(self):
        """
        Used in MCTS to copy the game state.
        """
        new_game = IslandConquest()
        new_game.islands = self.islands[:]
        new_game.player_turn = self.player_turn
        return new_game

    def get_result(self, player):
        """
        Get score of the game for the player for MCTS.

        Args:
            player: 1 for Player 1, -1 for Player 2.

        Returns:
            1 if player has won, -1 if the other player has won, and 0 for a
            tie, an unfinished game, or any other player value.
        """
        result = self.check_win()
        if result == "Player 1 wins":
            winner = 1
        elif result == "Player 2 wins":
            winner = -1
        else:
            return 0  # Tie or game continues

        if player == winner:
            return 1
        if player == -winner:
            return -1
        return 0


def play_game_with_mcts(game, mcts_iterations=1000):
    """
    Play the game to its end, letting MCTS choose every move.

    Each turn prints the current position, runs a search of mcts_iterations
    iterations on a copy of the game, prints the recommended island and
    claims it on the real game. When that move ends the game, the final
    position is printed. If the game is already over on entry, nothing is
    printed.
    """
    while True:
        if game.is_game_over():
            return

        print("\nCurrent game state:")
        game.current_state()

        chosen_island = mcts(game.copy(), mcts_iterations)
        print(f"Recommended move: Claim island {chosen_island}")
        game.claim_island(chosen_island)

        if game.is_game_over():
            print("\nFinal game state:")
            game.current_state()
            return

In [None]:
# Start from an empty board and let the MCTS driver play both sides to the end
# (using the driver's default number of iterations per move).
game = IslandConquest()
play_game_with_mcts(game=game)