# Create environment

In [1]:
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import random

In [2]:
@dataclass(frozen=True)
class NimState:
    piles: list[int] = field(default_factory=list)
    player: int = 0

    def __post_init__(self):
        assert all(pile >= 0 for pile in self.piles)

    def __str__(self):
        return f"Player {self.player}: {self.piles}"

    def __hash__(self):
        return hash(str(self))


@dataclass
class NimAction:
    pile_index: int
    amount: int
    probability: float = -1

    def __str__(self):
        return f"Take {self.amount} from pile {self.pile_index} with probability {self.probability}"

In [3]:
class BaseNimPlayer(ABC):
    @abstractmethod
    def get_action(self, state: NimState, actual_valid_actions: list[NimAction]) -> NimAction:
        pass

In [4]:
class StochasticNimEnvironment:
    def __init__(
        self, zero_player: BaseNimPlayer, one_player: BaseNimPlayer, state: NimState | None = None, seed: int = 0
    ) -> None:
        # Save the players
        self.zero_player = zero_player
        self.one_player = one_player

        # Initialize the state
        self.current_state = state if state is not None else self.initial_state()
        self.current_dice = -1

        # Set the seed
        random.seed(seed)

    def _reroll_dice_if_needed(self) -> None:
        """
        Re-roll the dice if it is not rolled yet
        """
        if self.current_dice == -1:
            self.current_dice = random.randint(1, 6)

    def reset(self) -> None:
        """
        Reset the environment to the initial state
        """
        self.current_state = self.initial_state()
        self.current_player = 0
        self.current_dice = -1

    def get_next_max_amount(self) -> int:
        """
        Get the maximum amount of face that can be taken from the pile
        Randomly choose the face if the dice is not rolled yet
        """
        self._reroll_dice_if_needed()

        return self.current_dice

    def step(self, action: NimAction) -> tuple[NimState, bool]:
        """
        Take the action and return the next state and whether the game is over
        """
        # Roll the dice if needed
        self._reroll_dice_if_needed()

        # Check if left pile is valid
        if self.current_state.piles[action.pile_index] < action.amount:
            raise ValueError(
                "Pile is not enough, left: {}, take: {}".format(
                    self.current_state.piles[action.pile_index], action.amount
                )
            )

        # Check if take more than the dice face
        if action.amount > self.get_next_max_amount():
            raise ValueError(
                "Take more than the dice face, take: {}, dice: {}".format(action.amount, self.current_dice)
            )

        # Update the pile
        new_piles = self.current_state.piles.copy()
        new_piles[action.pile_index] -= action.amount
        self.current_state = NimState(new_piles, 1 - self.current_state.player)

        # Reset the dice
        self.current_dice = -1

        return self.current_state, self.is_game_over(self.current_state)

    def play(self, verbose: bool = True) -> list[NimState]:
        """
        Play the game until it is over
        """
        states = [self.current_state]
        while not self.is_game_over(self.current_state):
            # Get the actual valid actions
            actual_valid_actions = self.get_actual_valid_action()

            # Get the action from the player
            if self.current_state.player == 0:
                action = self.zero_player.get_action(self.current_state, actual_valid_actions)
            else:
                action = self.one_player.get_action(self.current_state, actual_valid_actions)

            # Take the action
            self.step(action)

            # Save the state
            states.append(self.current_state)

        if verbose:
            print("Player 0" if self.current_state.player == 1 else "Player 1", "wins!")

        return states

    @staticmethod
    def step_without_stochastic(state: NimState, action: NimAction) -> tuple[NimState, bool]:
        """
        Take the action and return the next state
        """
        # Check if left pile is valid
        if state.piles[action.pile_index] < action.amount:
            raise ValueError(
                "Pile is not enough, left: {}, take: {}".format(state.piles[action.pile_index], action.amount)
            )

        # Update the pile
        new_piles = state.piles.copy()
        new_piles[action.pile_index] -= action.amount
        new_state = NimState(new_piles, 1 - state.player)

        return new_state, StochasticNimEnvironment.is_game_over(new_state)

    @staticmethod
    def is_game_over(state: NimState) -> bool:
        """
        Return True if the game is over
        """
        return all([pile == 0 for pile in state.piles])

    @staticmethod
    def initial_state() -> NimState:
        """
        Return the initial state of the game
        """
        return NimState([3, 4, 5], 0)

    @staticmethod
    def get_stochastic_valid_actions(state: NimState) -> list[NimAction]:
        """
        Return a list of valid actions for the current state
        """
        valid_actions = []
        for pile_index, pile in enumerate(state.piles):
            for amount in range(1, pile + 1):
                # Each play will die which face 1-6 with equal probability
                # And player can only take 1-face from the pile
                # e.g. if the face is 3, the player can only take 1, 2, 3 from the pile
                probability = (7 - amount) / 6

                valid_actions.append(NimAction(pile_index, amount, probability))

        return valid_actions

    def get_actual_valid_action(self) -> list[NimAction]:
        """
        Return a list of valid actions for the current state
        """
        # Roll the dice if needed
        self._reroll_dice_if_needed()

        # Get the all possible actions
        stochastic_actions = self.get_stochastic_valid_actions(self.current_state)

        # Filter the actions based on the dice face
        valid_actions = [action for action in stochastic_actions if action.amount <= self.current_dice]

        # Set the probability to 1 if the action is valid
        for action in valid_actions:
            action.probability = 1

        return valid_actions

# Create an Agent class

In [5]:
class HumanNimPlayer(BaseNimPlayer):
    def get_action(self, state: NimState, actual_valid_actions: list[NimAction]) -> NimAction:
        print("Your turn!")
        print("Current state -----------")
        print(state)
        print("Valid actions -----------")
        print("\n - ".join(str(action) for action in actual_valid_actions))
        print("-------------------------")
        pile_index = int(input("Enter the pile index: "))
        amount = int(input("Enter the amount: "))

        return NimAction(pile_index, amount)

In [11]:
class RandomNimPlayer(BaseNimPlayer):
    """
    A player that plays randomly from the available actions
    """

    def __init__(self, seed: int = 0) -> None:
        random.seed(seed)

    def get_action(self, state: NimState, actual_valid_actions: list[NimAction]) -> NimAction:
        return random.choice(actual_valid_actions)

In [25]:
class MiniMaxNimPlayer(BaseNimPlayer):
    """
    A player that plays the best move based on the minimax algorithm
    """

    def __init__(self, max_depth: int = 3, player_id: int = 0) -> None:
        self.max_depth = max_depth
        self.player_id = player_id

    def get_quality_of_state(self, state: NimState) -> int:
        """
        Return the quality of the state
        -> 1 if the player wins
        -> -1 if the player loses
        -> 0 if the game is not over
        """
        if StochasticNimEnvironment.is_game_over(state):
            return -1 if state.player == self.player_id else 1
        else:
            return 0

    def minimax(self, state: NimState, depth: int, is_maximizing: bool) -> list[tuple[int, NimAction]]:
        """
        Return the action that minimize or maximize based on turn
        """
        if depth == 0 or StochasticNimEnvironment.is_game_over(state):
            return [(self.get_quality_of_state(state), None)]

        quality_and_actions = []
        for action in StochasticNimEnvironment.get_stochastic_valid_actions(state):
            new_state, _ = StochasticNimEnvironment.step_without_stochastic(state, action)
            current_state_quality = self.get_quality_of_state(new_state)
            next_state_quality, _ = self.minimax(new_state, depth - 1, not is_maximizing)[0]

            quality_and_actions.append((current_state_quality + next_state_quality, action))

        quality_and_actions.sort(key=lambda x: x[0], reverse=is_maximizing)

        return quality_and_actions

    def get_action(self, state: NimState, actual_valid_actions: list[NimAction]) -> NimAction:
        quality_and_actions = self.minimax(state, self.max_depth, state.player == self.player_id)

        for _, action in quality_and_actions:
            if action in actual_valid_actions:
                return action

# Run the agent in the environment

In [26]:
from tqdm import tqdm

In [58]:
GAME_AMOUNT = 50
MINIMAX_DEPTH = 6

MINIMAX_WINS = 0

looper = tqdm(range(GAME_AMOUNT))
for i in looper:
    is_minimax_first = random.choice([True, False])
    if is_minimax_first:
        zero_player = MiniMaxNimPlayer(max_depth=MINIMAX_DEPTH, player_id=0)
        one_player = RandomNimPlayer(seed=i)
    else:
        zero_player = RandomNimPlayer(seed=i)
        one_player = MiniMaxNimPlayer(max_depth=MINIMAX_DEPTH, player_id=1)

    env = StochasticNimEnvironment(zero_player, one_player, seed=i)
    history = env.play(verbose=False)

    if history[-1].player == is_minimax_first:
        MINIMAX_WINS += 1

    looper.set_description(f"Minimax wins {MINIMAX_WINS} out of {i+1}(%.2f%%)" % (MINIMAX_WINS / (i + 1) * 100))

print(f"Minimax wins {MINIMAX_WINS} out of {GAME_AMOUNT}")

Minimax wins 28 out of 50(56.00%): 100%|██████████| 50/50 [00:05<00:00,  9.73it/s]

Minimax wins 28 out of 50



