In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import heapq
import matplotlib.pyplot as plt
from collections import OrderedDict
from tqdm import tqdm

# Seed every RNG the notebook draws from (stdlib `random`, NumPy's global
# generator, PyTorch) so network initialisation and self-play are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---------------------------------------------
# Module 1: Quoridor Game Logic
# ---------------------------------------------
class QuoridorGame:
    """Manages the Quoridor game rules and state transitions."""

    def __init__(self, board_size=9, wall_count=10):
        self.board_size = board_size
        self.wall_count = wall_count

    def find_shortest_path(self, state, player):
        """Use A* algorithm to find the shortest path to the goal for a player."""
        board = state.board
        start = tuple(state.player_positions[player] * 2)  # Convert to tuple for heap
        queue = [(0, 0, start)]  # (f_score, g_score, position_tuple)
        visited = np.zeros((2 * self.board_size - 1, 2 * self.board_size - 1), dtype=np.int8)
        target_row = (2 * self.board_size - 2) * (1 - player)  # Goal row for player

        while queue:
            _, g, pos = heapq.heappop(queue)
            pos = np.array(pos)  # Convert back to NumPy array for indexing
            if pos[0] < 0 or pos[0] >= 2 * self.board_size - 1 or pos[1] < 0 or pos[1] >= 2 * self.board_size - 1:
                continue
            if board[tuple(pos)] == 1 or visited[tuple(pos)]:
                continue
            if pos[0] == target_row:
                return g
            visited[tuple(pos)] = 1
            for direction in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                new_pos = pos + np.array(direction)
                h = abs(target_row - new_pos[0])  # Manhattan distance heuristic
                heapq.heappush(queue, (g + 1 + h, g + 1, tuple(new_pos)))  # Store as tuple
        return -1

    # ... (rest of the QuoridorGame class remains unchanged, include from original code)

    def create_initial_state(self):
        return GameState(self.board_size, self.wall_count)

    def is_wall_placement_valid(self, state):
        return self.find_shortest_path(state, 0) != -1 and self.find_shortest_path(state, 1) != -1

    def get_valid_moves(self, state, player):
        board = state.board
        current_pos = state.player_positions[player] * 2
        valid_moves = []
        stack = [(current_pos, 0)]
        visited = np.zeros((2 * self.board_size - 1, 2 * self.board_size - 1), dtype=np.int8)

        while stack:
            pos, steps = stack.pop()
            if pos[0] < 0 or pos[0] > 2 * self.board_size - 2 or pos[1] < 0 or pos[1] > 2 * self.board_size - 2:
                continue
            if board[*pos] == 1 or visited[*pos]:
                continue
            if board[*pos] in (2, 3):
                steps = 0
            visited[*pos] = 1
            if steps == 2:
                valid_moves.append(pos // 2)
                continue
            for direction in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                stack.append((pos + np.array(direction), steps + 1))
        return valid_moves

    def is_move_valid(self, state, next_pos, player):
        return any(np.array_equal(next_pos, move) for move in self.get_valid_moves(state, player))

    def apply_action(self, state, action, player):
        action_type, value = action
        new_state = state.copy()
        if action_type == 0:
            if self.is_move_valid(new_state, value, player):
                new_state.board[*new_state.player_positions[player] * 2] = 0
                new_state.board[*np.array(value) * 2] = player + 2
                new_state.player_positions[player] = value
                return new_state
            return None
        else:
            if new_state.walls_left[player] == 0:
                return None
            orientation = action_type - 1
            row, col = value
            if new_state.wall_grid[orientation, row, col] != 0:
                return None
            new_state.wall_grid[orientation, row, col] = 1
            new_state.wall_grid[1 - orientation, row, col] = -1
            if orientation == 0 and col > 0:
                new_state.wall_grid[0, row, col - 1] = -1
            if orientation == 1 and row > 0:
                new_state.wall_grid[1, row - 1, col] = -1
            if orientation == 0 and col < self.board_size - 2:
                new_state.wall_grid[0, row, col + 1] = -1
            if orientation == 1 and row < self.board_size - 2:
                new_state.wall_grid[1, row + 1, col] = -1
            new_state.board[
                row * 2 - orientation + 1 : row * 2 + orientation + 2,
                col * 2 - (1 - orientation) + 1 : col * 2 + (1 - orientation) + 2
            ] = 1
            new_state.walls_left[player] -= 1
            if not self.is_wall_placement_valid(new_state):
                return None
            return new_state

    def get_possible_actions(self, state, player):
        moves = self.get_valid_moves(state, player)
        walls = [
            (hv, (r, c))
            for hv in range(2)
            for r in range(self.board_size - 1)
            for c in range(self.board_size - 1)
            if self.apply_action(state, (hv + 1, (r, c)), player) is not None
        ]
        action_tensor = np.zeros((3, self.board_size, self.board_size), dtype=np.float32)
        for move in moves:
            action_tensor[0, *move] = 1
        for hv, (r, c) in walls:
            action_tensor[1 + hv, r, c] = 1
        return action_tensor

    def has_won(self, state, player):
        return state.player_positions[player][0] == (self.board_size - 1) * (1 - player)

    def evaluate_state(self, state, player):
        player_dist = self.find_shortest_path(state, player)
        opponent_dist = self.find_shortest_path(state, 1 - player)
        return player_dist / (player_dist + opponent_dist + 1e-8)

    def check_game_status(self, state, player, turns):
        if turns > 50:
            return True, True, self.evaluate_state(state, player)
        if self.has_won(state, player):
            return True, False, 1.0
        return False, False, 0.0

class GameState:
    """Mutable snapshot of a Quoridor position.

    Attributes (all NumPy arrays except ``size``):
        size: number of cells per board side.
        player_positions: ``[[row0, col0], [row1, col1]]`` in cell coordinates.
        walls_left: remaining walls for players 0 and 1.
        wall_grid: ``(2, size-1, size-1)`` int8. Index 0 is the wall
            orientation. ``QuoridorGame.apply_action`` writes ``1`` for a
            placed wall and ``-1`` for slots it blocks.
        board: ``(2*size-1, 2*size-1)`` int8 grid. ``1`` marks blocked
            entries (junctions and walls) and ``2``/``3`` mark the pawns.
    """

    def __init__(self, size, walls, copy=None):
        """Build the initial position, or deep-copy ``copy`` when given.

        ``walls`` sets both players' wall counts and is ignored when cloning.
        """
        self.size = size
        if copy is not None:
            self.player_positions = copy.player_positions.copy()
            self.walls_left = copy.walls_left.copy()
            self.wall_grid = copy.wall_grid.copy()
            self.board = copy.board.copy()
        else:
            # Pawns start in the middle column of the first and last rows.
            self.player_positions = np.array([[0, size // 2], [size - 1, size // 2]])
            self.walls_left = np.array([walls, walls])
            self.wall_grid = np.zeros((2, size - 1, size - 1), dtype=np.int8)
            self.board = self.initialize_board()

    def initialize_board(self):
        """Return the board grid with junctions blocked and both pawns placed."""
        board = np.zeros((2 * self.size - 1, 2 * self.size - 1), dtype=np.int8)
        # Odd/odd entries are wall junctions, which are never walkable.
        board[1::2, 1::2] = 1
        board[self.player_positions[0, 0] * 2, self.player_positions[0, 1] * 2] = 2
        board[self.player_positions[1, 0] * 2, self.player_positions[1, 1] * 2] = 3
        return board

    def copy(self):
        """Return an independent copy of this state (arrays are not shared)."""
        return GameState(self.size, 0, copy=self)

    def encode_state(self, player):
        """Encode the position as ``(4, size, size)`` float32 planes for the network.

        Planes:

        * 0: ``player``'s pawn.
        * 1: the opponent's pawn.
        * 2 and 3: placed walls of orientation 0 and 1, zero-padded from
          ``size-1`` to ``size``.

        The board is not flipped for player 1; only the pawn planes are
        swapped.
        """
        encoded = np.zeros((4, self.size, self.size), dtype=np.float32)
        encoded[player, self.player_positions[0, 0], self.player_positions[0, 1]] = 1
        encoded[1 - player, self.player_positions[1, 0], self.player_positions[1, 1]] = 1
        encoded[2, :, :] = np.pad(self.wall_grid[0] == 1, ((0, 1), (0, 1)), mode='constant')
        encoded[3, :, :] = np.pad(self.wall_grid[1] == 1, ((0, 1), (0, 1)), mode='constant')
        return encoded


In [15]:

# ---------------------------------------------
# Module 2: Neural Network for AlphaZero
# ---------------------------------------------
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU followed by Conv-BN, plus an identity skip and a final ReLU.

    Both convolutions are 3x3 with padding 1 and equal in/out channels, so the
    output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        # Creation order (conv1, bn1, conv2, bn2) fixes the seeded
        # initialisation and the state_dict keys.
        self.conv1 = self._same_conv(channels)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = self._same_conv(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    @staticmethod
    def _same_conv(channels):
        """3x3 convolution that preserves spatial size and channel count."""
        return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)

class GameNetwork(nn.Module):
    """AlphaZero-style network with a shared residual trunk and two heads.

    Input: ``(batch, 4, board_size, board_size)`` planes (``MCTS`` feeds it
    ``GameState.encode_state``).

    Output: ``(policy, value)``.

    * ``policy`` has shape ``(batch, 3, board_size, board_size)``, with a
      sigmoid applied to each entry.
    * ``value`` has shape ``(batch, 1)`` and lies in (0, 1).
    """

    def __init__(self, game, num_blocks, channels):
        super().__init__()
        self.game = game
        board_area = game.board_size * game.board_size

        # Submodules are registered in the same order and with the same names
        # as the checkpoints saved by the trainer expect.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(4, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

        self.res_blocks = nn.Sequential()
        for index in range(num_blocks):
            self.res_blocks.add_module(f'res_block_{index}', ResidualBlock(channels))

        # NOTE(review): MCTS applies a softmax over these sigmoid outputs. Their
        # range is only (0, 1), so the resulting distribution stays close to
        # uniform. Consider emitting raw logits.
        self.policy_output = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, 3, kernel_size=1),
            nn.Sigmoid(),
        )

        self.value_output = nn.Sequential(
            nn.Conv2d(channels, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * board_area, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.res_blocks(self.initial_conv(x))
        return self.policy_output(features), self.value_output(features)


In [16]:

# ---------------------------------------------
# Module 3: Monte Carlo Tree Search (MCTS)
# ---------------------------------------------
class TreeNode:
    def __init__(self, game, params, state, player, turn, parent=None, action=None):
        self.game = game
        self.params = params
        self.state = state
        self.player = player
        self.turn = turn
        self.parent = parent
        self.action = action
        self.children = []
        self.valid_actions = game.get_possible_actions(state, player)
        self.visits = 0
        self.value_sum = 0

    def is_expanded(self):
        return len(self.children) > 0

    def select_child(self):
        best_score = -np.inf
        best_child = None
        for child in self.children:
            score = self.compute_ucb(child)
            if score > best_score:
                best_score = score
                best_child = child
        return best_child

    def compute_ucb(self, child):
        q = child.value_sum / (child.visits + 1e-8)
        exploration = self.params['C'] * np.sqrt(np.log(self.visits + 1) / (child.visits + 1e-8))
        return q + exploration

    def expand_node(self, valid_actions, policy):
        for idx in zip(*np.where(valid_actions == 1)):
            if policy[*idx] > 0:
                action_type, row, col = idx
                new_state = self.game.apply_action(self.state, (action_type, (row, col)), 1 - self.player)
                if new_state:
                    child = TreeNode(self.game, self.params, new_state, 1 - self.player, self.turn + 1, self, idx)
                    self.children.append(child)

    def backprop(self, value, is_draw):
        self.visits += 1
        self.value_sum += value * self.params['draw_discount'] if is_draw else value
        if self.parent:
            self.parent.backprop(value, is_draw)

class MCTS:
    def __init__(self, game, params, network):
        self.game = game
        self.params = params
        self.network = network

    def run_search(self, state, player):
        root = TreeNode(self.game, self.params, state, 1 - player, 0)
        for _ in tqdm(range(self.params['n_searches'])):
            node = root
            while node.is_expanded():
                node = node.select_child()
            is_terminal, is_draw, value = self.game.check_game_status(node.state, 1 - node.player, node.turn)
            if not is_terminal:
                policy, value = self.network(torch.tensor(node.state.encode_state(1 - node.player)).unsqueeze(0))
                policy = torch.softmax(policy.flatten(1), dim=1).reshape(3, self.game.board_size, self.game.board_size).detach().numpy()
                valid_actions = self.game.get_possible_actions(node.state, 1 - node.player)
                policy *= valid_actions
                policy /= np.sum(policy) + 1e-8
                value = value.item()
                node.expand_node(valid_actions, policy)
            node.backprop(value, is_draw)
        action_probs = np.zeros((3, self.game.board_size, self.game.board_size), dtype=np.float32)
        for child in root.children:
            action_probs[*child.action] = child.value_sum / (child.visits + 1e-8)
        action_probs /= np.sum(action_probs) + 1e-8
        return action_probs


In [17]:

# ---------------------------------------------
# Module 4: Visualization
# ---------------------------------------------
def visualize_board(policy, state, player):
    """Draw ``state`` with an optional ``(3, n, n)`` policy overlay.

    Pawn-move cells use ``policy[0]`` and are tinted from red (low) to blue
    (high); cells with zero probability stay white. Wall slots use
    ``policy[1]``/``policy[2]`` with the same tint, and placed walls
    (``wall_grid == 1``) are black. The pawn of ``player`` is drawn green.
    Pass ``policy=None`` to draw the bare board.
    """

    def prob_at(plane, i, j):
        return policy[plane, i, j] if policy is not None else 0

    def shade(prob):
        return np.array([1 - prob, 0, prob, 1])

    n = state.size
    _, ax = plt.subplots(figsize=(6, 6))
    ax.axis('equal')
    ax.add_patch(plt.Rectangle((-0.5, -0.5), n, n, fc='w', ec='k'))

    # Cells first so wall slots and junctions are drawn on top of them.
    for i in range(n):
        for j in range(n):
            p = prob_at(0, i, j)
            cell_fc = shade(p) if p > 0 else 'w'
            ax.add_patch(plt.Rectangle((i - 0.3, j - 0.3), 0.6, 0.6, fc=cell_fc, ec='k'))

    for i in range(n - 1):
        for j in range(n - 1):
            h_fc = 'k' if state.wall_grid[0, i, j] == 1 else shade(prob_at(1, i, j))
            v_fc = 'k' if state.wall_grid[1, i, j] == 1 else shade(prob_at(2, i, j))
            ax.add_patch(plt.Rectangle((i + 0.4, j - 0.5), 0.2, 1, fc=h_fc, ec='k'))
            ax.add_patch(plt.Rectangle((i - 0.5, j + 0.4), 1, 0.2, fc=v_fc, ec='k'))
            ax.add_patch(plt.Rectangle((i + 0.4, j + 0.4), 0.2, 0.2, fc='k', ec='k'))

    pawn_faces = ('g' if player == 0 else 'w', 'g' if player == 1 else 'k')
    for (row, col), face in zip(state.player_positions, pawn_faces):
        ax.add_patch(plt.Circle((row, col), 0.1, fc=face, ec='k'))

    plt.show()


In [18]:

# ---------------------------------------------
# Module 5: AlphaZero Training
# ---------------------------------------------
class AlphaZeroTrainer:
    """Alternates MCTS self-play data generation with network training.

    ``params`` keys read here:

    * ``'n_iterations'``, ``'n_selfplay_iterations'``, ``'n_epochs'``.
    * Optional ``'temperature'`` (default 1.0; 0 means always take the most
      likely move).
    * Optional ``'batch_size'`` (default 64).
    """

    def __init__(self, network, optimizer, game, params):
        self.network = network
        self.optimizer = optimizer
        self.game = game
        self.params = params
        self.mcts = MCTS(game, params, network)

    def _select_action(self, action_probs):
        """Pick an action index ``(action_type, row, col)`` from the search distribution.

        The index is sampled with probability proportional to
        ``action_probs ** (1 / temperature)``, so repeated self-play games
        differ. A temperature of 0, or a distribution with no mass, falls
        back to argmax.
        """
        temperature = self.params.get('temperature', 1.0)
        flat = np.asarray(action_probs, dtype=np.float64).ravel()
        if temperature > 0:
            weights = flat ** (1.0 / temperature)
            total = weights.sum()
            if total > 0:
                flat_idx = np.random.choice(flat.size, p=weights / total)
                return np.unravel_index(flat_idx, action_probs.shape)
        return np.unravel_index(int(np.argmax(flat)), action_probs.shape)

    def run_self_play(self):
        """Play one game of the network against itself.

        Returns:
            A list with one ``(encoded_state, action_probs, outcome)`` tuple
            per move. ``outcome`` is the final result from the perspective of
            the player who was to move in that state.

        Raises:
            RuntimeError: if ``apply_action`` rejects the selected action.
        """
        history = []
        player = 1
        state = self.game.create_initial_state()
        turn = 0
        while True:
            action_probs = self.mcts.run_search(state, player)
            history.append((state, action_probs, player))
            action_type, row, col = self._select_action(action_probs)
            next_state = self.game.apply_action(state, (action_type, (row, col)), player)
            if next_state is None:
                raise RuntimeError(
                    f'Selected action {(int(action_type), int(row), int(col))} '
                    f'is not legal for player {player} on turn {turn}'
                )
            state = next_state
            is_terminal, is_draw, value = self.game.check_game_status(state, player, turn)
            turn += 1
            if is_terminal:
                # `value` is from the final mover's view; flip it for the other player.
                return [
                    (hist_state.encode_state(hist_player), hist_probs,
                     value if hist_player == player else 1 - value)
                    for hist_state, hist_probs, hist_player in history
                ]
            player = 1 - player

    def train_network(self, data):
        """Run one shuffled pass of mini-batch updates over self-play ``data``.

        The loss is the sum of two terms:

        * Cross-entropy between the search distribution and
          ``softmax(policy.flatten(1))``, the same transform ``MCTS`` applies
          to the policy head.
        * Mean squared error between the value head and the game outcome.

        Args:
            data: list of ``(encoded_state, action_probs, outcome)`` tuples as
                returned by ``run_self_play``.

        Returns:
            The mean loss over mini-batches, or ``None`` if ``data`` is empty.
        """
        if not data:
            return None
        batch_size = self.params.get('batch_size', 64)
        order = np.random.permutation(len(data))
        total_loss, n_batches = 0.0, 0
        for start in range(0, len(data), batch_size):
            batch = [data[i] for i in order[start:start + batch_size]]
            states = torch.tensor(np.stack([item[0] for item in batch]), dtype=torch.float32)
            target_probs = torch.tensor(np.stack([item[1] for item in batch]), dtype=torch.float32).flatten(1)
            target_values = torch.tensor([float(item[2]) for item in batch], dtype=torch.float32)

            policy, value = self.network(states)
            policy_loss = -(target_probs * F.log_softmax(policy.flatten(1), dim=1)).sum(dim=1).mean()
            value_loss = F.mse_loss(value.squeeze(1), target_values)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        return total_loss / n_batches

    def train(self):
        """Run the full loop: self-play (eval mode), training (train mode), checkpoint.

        Model and optimizer state dicts are written to the working directory
        after every iteration.
        """
        for iteration in range(self.params['n_iterations']):
            data_buffer = []
            self.network.eval()
            for _ in range(self.params['n_selfplay_iterations']):
                data_buffer.extend(self.run_self_play())
            self.network.train()
            for _ in range(self.params['n_epochs']):
                self.train_network(data_buffer)
            torch.save(self.network.state_dict(), f'model_iter_{iteration}.pt')
            torch.save(self.optimizer.state_dict(), f'optimizer_iter_{iteration}.pt')


In [20]:

# ---------------------------------------------
# Main Execution
# ---------------------------------------------
if __name__ == "__main__":
    # Small configuration: 5x5 board, 6 walls each, a 3-block/3-channel network.
    game = QuoridorGame(board_size=5, wall_count=6)
    network = GameNetwork(game, num_blocks=3, channels=3)
    optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
    params = dict(
        C=2.0,                     # UCB exploration constant (TreeNode.compute_ucb)
        n_searches=10,             # MCTS simulations per move
        n_iterations=3,            # self-play + training rounds
        n_selfplay_iterations=10,  # self-play games per round
        n_epochs=100,              # train_network calls per round
        draw_discount=0.2,         # scale applied to draw values during backprop
    )

    trainer = AlphaZeroTrainer(network, optimizer, game, params)
    trainer.train()

100%|██████████| 10/10 [00:00<00:00, 74.02it/s]
100%|██████████| 10/10 [00:00<00:00, 185.34it/s]
100%|██████████| 10/10 [00:00<00:00, 451.61it/s]
100%|██████████| 10/10 [00:00<00:00, 484.08it/s]
100%|██████████| 10/10 [00:00<00:00, 518.17it/s]
100%|██████████| 10/10 [00:00<00:00, 996.51it/s]
100%|██████████| 10/10 [00:00<00:00, 1129.96it/s]
100%|██████████| 10/10 [00:00<00:00, 999.05it/s]
100%|██████████| 10/10 [00:00<00:00, 80.01it/s]
100%|██████████| 10/10 [00:00<00:00, 188.70it/s]
100%|██████████| 10/10 [00:00<00:00, 370.20it/s]
100%|██████████| 10/10 [00:00<00:00, 515.17it/s]
100%|██████████| 10/10 [00:00<00:00, 544.92it/s]
100%|██████████| 10/10 [00:00<00:00, 999.69it/s]
100%|██████████| 10/10 [00:00<00:00, 1103.44it/s]
100%|██████████| 10/10 [00:00<00:00, 1075.57it/s]
100%|██████████| 10/10 [00:00<00:00, 81.09it/s]
100%|██████████| 10/10 [00:00<00:00, 185.30it/s]
100%|██████████| 10/10 [00:00<00:00, 411.12it/s]
100%|██████████| 10/10 [00:00<00:00, 509.09it/s]
100%|██████████| 10/