In [34]:
import numpy as np
from typing import Dict, Tuple, List, Union

Pos = Tuple[int, int]  # (row, col)

class Board:
    """
    Peg-Solitaire (7×7 cross) board.

    Cell encoding:
    - Peg  : 1
    - Hole : 0
    - Illegal cell : 0 in the array view only (there is no -1 marker);
      use ``LEGAL_MASK`` to tell an off-board cell from a hole.

    The board itself is stored in ``state`` as ``{(row, col): 0 | 1}`` with
    one entry per legal cell; illegal cells are absent from the dict.
    """
    __slots__ = ("state",)

    # 33 legal board positions (cross shape)
    LEGAL_POSITIONS: List[Pos] = [
        (r, c) for r in range(7) for c in range(7)
        if (2 <= r <= 4) or (2 <= c <= 4)
    ]
    # 7×7 float mask: 1.0 on legal cells, 0.0 elsewhere
    LEGAL_MASK: np.ndarray = np.zeros((7, 7), dtype=np.float32)
    for _r, _c in LEGAL_POSITIONS:
        LEGAL_MASK[_r, _c] = 1.0
    TOTAL_PEGS = 32

    def __init__(self) -> None:
        """Standard start: pegs on all positions, center is empty."""
        self.state: Dict[Pos, int] = {pos: 1 for pos in self.LEGAL_POSITIONS}
        self.state[(3, 3)] = 0

    def reset(self) -> None:
        """Reset to standard board (all pegs, center empty)."""
        for pos in self.LEGAL_POSITIONS:
            self.state[pos] = 1
        self.state[(3, 3)] = 0

    def get(self, pos: Pos) -> Union[int, None]:
        """Return 1 (peg) / 0 (hole) / None (illegal position)."""
        return self.state.get(pos, None)

    def set(self, pos: Pos, value: int) -> None:
        """
        Place peg (1) or hole (0) at pos.

        Raises ValueError if pos is not a legal cell or value is not 0/1.
        """
        if pos not in self.LEGAL_POSITIONS or value not in (0, 1):
            raise ValueError(f"Illegal position/value: {pos} {value}")
        self.state[pos] = value

    def all_pegs(self) -> List[Pos]:
        """Return the positions currently holding a peg."""
        return [p for p, v in self.state.items() if v == 1]

    def all_holes(self) -> List[Pos]:
        """Return the legal positions that are currently empty."""
        return [p for p, v in self.state.items() if v == 0]

    def count_pegs(self) -> int:
        """Return current number of pegs on board."""
        # Cheap and numpy-free: works directly on the dict (values are 0/1)
        return sum(self.state.values())

    def as_array(self) -> np.ndarray:
        """
        Returns a 7×7 float32 array:
          1 = peg, 0 = hole or out-of-board
        """
        arr = np.zeros((7, 7), dtype=np.float32)
        for pos in self.LEGAL_POSITIONS:
            arr[pos] = float(self.state[pos])
        return arr

    def get_state(self) -> np.ndarray:
        """Alias for as_array()."""
        return self.as_array()

    def set_state(self, data: Union[np.ndarray, Dict[Pos, int]]) -> None:
        """
        Set board from ndarray (7×7) or dict {pos: val}.

        Only a value equal to 1 counts as a peg; missing keys, off-board
        entries and any other value are treated as holes (0). Both input
        kinds follow the same rule, so the stored state is always 0/1.

        Raises ValueError if an array input is not of shape (7, 7).
        """
        if isinstance(data, dict):
            for pos in self.LEGAL_POSITIONS:
                self.state[pos] = 1 if data.get(pos, 0) == 1 else 0
        else:
            arr = np.asarray(data)
            if arr.shape != (7, 7):
                raise ValueError("state array must be shape (7,7)")
            for pos in self.LEGAL_POSITIONS:
                self.state[pos] = 1 if arr[pos] == 1 else 0

    @classmethod
    def from_dict(cls, d: Dict[Pos, int]) -> "Board":
        """Build a board whose cells are taken from ``d`` (see set_state)."""
        b = cls()
        b.set_state(d)
        return b

    def encode_observation(self) -> np.ndarray:
        """
        Return obs (7,7,4):
         ch0: pegs (1/0),
         ch1: fraction removed (global),
         ch2: fraction remain (global),
         ch3: legal mask (1/0)
        """
        arr = self.as_array()
        num_pegs = np.sum(arr)
        removed = (self.TOTAL_PEGS - num_pegs) / self.TOTAL_PEGS
        remain = num_pegs / self.TOTAL_PEGS
        obs = np.zeros((7, 7, 4), dtype=np.float32)
        obs[:, :, 0] = arr
        obs[:, :, 1] = removed   # scalar broadcast over the whole plane
        obs[:, :, 2] = remain    # scalar broadcast over the whole plane
        obs[:, :, 3] = self.LEGAL_MASK
        return obs

    @staticmethod
    def augment_observation(obs: np.ndarray, mode: Union[str, None] = "random") -> np.ndarray:
        """
        Return augmented observation (rot/flip).
        mode = 'random' | 'all' | None

        The 8 variants are ordered rotation-major: for each k in 0..3 the
        array rotated by ``np.rot90(k)`` comes first, followed by the same
        rotation flipped along axis 1. 'all' stacks them on a new leading
        axis, 'random' returns one of them, anything else returns ``obs``.
        """
        augs = []
        for rot in range(4):
            r = np.rot90(obs, k=rot, axes=(0, 1))
            for flip in (False, True):
                augs.append(np.flip(r, axis=1) if flip else r)
        if mode == "all":
            return np.stack(augs)
        if mode == "random":
            idx = np.random.randint(8)
            return augs[idx]
        return obs

    def copy(self) -> "Board":
        """Return an independent board with the same cell values."""
        new_b = Board()
        new_b.state = self.state.copy()
        return new_b

    clone = copy  # alias

    def to_dict(self) -> Dict[Pos, int]:
        """Return a shallow copy of the {pos: 0|1} mapping."""
        return self.state.copy()

    def __hash__(self) -> int:
        # Hash of the current cell values; changes whenever the board is mutated
        return hash(tuple(sorted(self.state.items())))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Board) and self.state == other.state

    def __str__(self) -> str:
        """
        Render the board as 7 text rows: '●' peg, '◯' hole, ' ' off-board.

        Reads the dict rather than as_array(): the array view encodes
        off-board cells as 0, which would draw them as holes.
        """
        symbols = {1: '●', 0: '◯'}
        rows = []
        for r in range(7):
            rows.append(' '.join(symbols.get(self.state.get((r, c)), ' ')
                                 for c in range(7)))
        return "\n".join(rows)

In [35]:
from typing import Tuple, List, Dict, Optional, Callable
from copy import deepcopy

Pos = Tuple[int, int]
Action = Tuple[int, int, int]          # (row, col, dir-idx)


class Game:
    """
    Core logic of Peg-Solitaire.
    • Holds a Board instance
    • Generates legal moves / actions
    • Applies moves with full undo/redo support
    • Supplies hooks for RL (reward_fn, clone_state, etc.)

    A move is the triple of positions involved in a jump. Note the two
    orderings used below: get_legal_moves() and the history stacks use
    (src, dst, over), while last_move / move_log use (src, over, dst).
    """

    # Δ-vectors (row, col); index → 0:up  1:down  2:left  3:right
    DIRECTIONS: List[Pos] = [(-2, 0), (2, 0), (0, -2), (0, 2)]

    def __init__(
        self,
        board: Optional["Board"] = None,
        reward_fn: Optional[Callable[["Board", Tuple[Pos, Pos, Pos], bool], float]] = None
    ) -> None:
        """
        Parameters
        ----------
        board     : starting position (copied); standard start when None
        reward_fn : callable(board, last_move, done) -> float; the built-in
                    _default_reward is used when None
        """
        self.board: "Board" = board.copy() if board is not None else Board()
        self.move_history: List[Tuple[Pos, Pos, Pos, "Board"]] = []   # (from,to,over, board_before)
        self.redo_stack: List[Tuple[Pos, Pos, Pos, "Board"]] = []     # (from,to,over, board_after)
        self.last_move: Optional[Tuple[Pos, Pos, Pos]] = None
        self.reward_fn = reward_fn or self._default_reward        # use built-in if None
        self.custom_metadata: Dict[str, object] = {}
        self.move_log: List[Tuple[Pos, Pos, Pos]] = []

    # ------------------------------------------------------------------ #
    #                          move legality                              #
    # ------------------------------------------------------------------ #
    def _middle(self, a: Pos, b: Pos) -> Pos:
        """Return the cell halfway between a and b (the jumped-over cell)."""
        return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)

    def is_legal_move(self, src: Pos, dst: Pos) -> Tuple[bool, Optional[Pos]]:
        """Return (is_legal, middle_pos) for move src→dst."""
        if src not in Board.LEGAL_POSITIONS or dst not in Board.LEGAL_POSITIONS:
            return False, None
        dr, dc = dst[0] - src[0], dst[1] - src[1]
        # Only straight jumps of exactly two cells are allowed
        if (abs(dr), abs(dc)) not in ((2, 0), (0, 2)):
            return False, None
        over = self._middle(src, dst)
        if self.board.get(src) == 1 and self.board.get(over) == 1 and self.board.get(dst) == 0:
            return True, over
        return False, None

    def get_legal_moves(self) -> List[Tuple[Pos, Pos, Pos]]:
        """List of (src, dst, over) tuples for all legal moves."""
        moves = []
        for src in self.board.all_pegs():
            for di, dj in self.DIRECTIONS:
                dst = (src[0] + di, src[1] + dj)
                legal, over = self.is_legal_move(src, dst)
                if legal:
                    moves.append((src, dst, over))
        return moves

    # ------------------------------------------------------------------ #
    #                        RL-friendly actions                          #
    # ------------------------------------------------------------------ #
    def get_legal_actions(self) -> List[Action]:
        """List of (row, col, dir-idx) for every legal jump."""
        acts: List[Action] = []
        for src in self.board.all_pegs():
            for d_idx, (di, dj) in enumerate(self.DIRECTIONS):
                dst = (src[0] + di, src[1] + dj)
                legal, _ = self.is_legal_move(src, dst)
                if legal:
                    acts.append((src[0], src[1], d_idx))
        return acts

    def _action_endpoints(self, a: Action) -> Optional[Tuple[Pos, Pos]]:
        """Translate an action to (src, dst); None if its dir-idx is out of range."""
        r, c, d = a
        # Reject bad indices explicitly: a negative d would otherwise wrap
        # around silently and a large one would raise IndexError.
        if not 0 <= d < len(self.DIRECTIONS):
            return None
        dr, dc = self.DIRECTIONS[d]
        return (r, c), (r + dr, c + dc)

    def is_legal_action(self, a: Action) -> bool:
        """Return True if action a is a legal jump on the current board."""
        ends = self._action_endpoints(a)
        if ends is None:
            return False
        return self.is_legal_move(*ends)[0]

    # ------------------------------------------------------------------ #
    #                              apply                                  #
    # ------------------------------------------------------------------ #
    def _apply(self, src: Pos, dst: Pos, over: Pos) -> None:
        """Mutate board: execute the (already validated) jump."""
        self.board.set(src, 0)
        self.board.set(over, 0)
        self.board.set(dst, 1)

    def apply_move(self, src: Pos, dst: Pos) -> Tuple[bool, float, bool, Dict]:
        """
        Execute the jump src→dst if legal.

        Returns (applied, reward, done, info). An illegal move leaves the
        game untouched and yields (False, 0.0, <game over?>, {"reason": "illegal"}).
        A legal move is recorded for undo and clears the redo stack.
        """
        legal, over = self.is_legal_move(src, dst)
        if not legal:
            return False, 0.0, self.is_game_over(), {"reason": "illegal"}

        before = self.board.copy()
        self.move_history.append((src, dst, over, before))
        self.redo_stack.clear()

        self._apply(src, dst, over)
        self.last_move = (src, over, dst)
        self.move_log.append(self.last_move)

        done = self.is_game_over()
        reward = self.reward_fn(self.board, self.last_move, done)
        return True, reward, done, {"last_move": self.last_move, "done": done}

    def apply_action(self, action: Action) -> Tuple[bool, float, bool, Dict]:
        """Same as apply_move(), addressed by (row, col, dir-idx)."""
        ends = self._action_endpoints(action)
        if ends is None:
            return False, 0.0, self.is_game_over(), {"reason": "illegal"}
        return self.apply_move(*ends)

    # ------------------------------------------------------------------ #
    #                          undo / redo                                #
    # ------------------------------------------------------------------ #
    def undo(self) -> bool:
        """Revert the most recent move; return False if there is none."""
        if not self.move_history:
            return False
        src, dst, over, before = self.move_history.pop()
        # Keep the post-move board so redo() can restore it
        self.redo_stack.append((src, dst, over, self.board.copy()))
        self.board = before
        self.last_move = self.move_history[-1][:3] if self.move_history else None
        if self.move_log:
            self.move_log.pop()
        return True

    def redo(self) -> bool:
        """Re-apply the most recently undone move; return False if there is none."""
        if not self.redo_stack:
            return False
        src, dst, over, after = self.redo_stack.pop()
        # The history entry must hold the board *before* the move (the current
        # one); storing the post-move snapshot would make a later undo() a no-op.
        self.move_history.append((src, dst, over, self.board))
        self.board = after
        self.last_move = (src, over, dst)
        self.move_log.append(self.last_move)
        return True

    # ------------------------------------------------------------------ #
    #                         termination checks                          #
    # ------------------------------------------------------------------ #
    def is_game_over(self) -> bool:
        """True when no legal move remains."""
        return not self.get_legal_moves()

    def is_win(self) -> bool:
        """True when exactly one peg is left and it sits on the center (3, 3)."""
        return self.board.count_pegs() == 1 and self.board.get((3, 3)) == 1

    is_solved = is_win  # alias

    # ------------------------------------------------------------------ #
    #                             reset / clone                           #
    # ------------------------------------------------------------------ #
    def reset(self, board: Optional["Board"] = None) -> None:
        """Start over from board (copied) or the standard start; clears all history and metadata."""
        self.board = board.copy() if board is not None else Board()
        self.move_history.clear()
        self.redo_stack.clear()
        self.last_move = None
        self.move_log.clear()
        self.custom_metadata.clear()

    def clone_state(self) -> "Game":
        """Copy for MCTS: same board and reward_fn; history and metadata are not carried over."""
        return Game(board=self.board.copy(), reward_fn=self.reward_fn)

    # ------------------------------------------------------------------ #
    #                       misc / metadata / io                          #
    # ------------------------------------------------------------------ #
    def export_move_log(self) -> List[Tuple[Pos, Pos, Pos]]:
        """Return a copy of the (src, over, dst) moves played so far."""
        return self.move_log.copy()

    def set_state(self, state: Union["Board", Dict[Pos, int]]) -> None:
        """
        Load a position from a Board (copied) or a {pos: val} dict.

        All move bookkeeping (history, redo stack, last move, move log) is
        dropped because it refers to the previous position.
        Raises TypeError for any other input type.
        """
        if isinstance(state, Board):
            self.board = state.copy()
        elif isinstance(state, dict):
            self.board.set_state(state)
        else:
            raise TypeError("state must be Board or dict")
        self.move_history.clear()
        self.redo_stack.clear()
        self.last_move = None
        self.move_log.clear()   # otherwise stale moves would leak into export_move_log()

    # custom metadata
    def get_custom_metadata(self, key: str, default=None):
        """Return the metadata stored under key, or default."""
        return self.custom_metadata.get(key, default)

    def set_custom_metadata(self, key: str, val):
        """Store val under key in the free-form metadata dict."""
        self.custom_metadata[key] = val

    # hashing / equality
    def __hash__(self) -> int:
        return hash(self.board)

    def __eq__(self, other: object) -> bool:
        # Two games are equal when their boards are; history is ignored
        return isinstance(other, Game) and self.board == other.board

    # ------------------------------------------------------------------ #
    #                          reward default                             #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _default_reward(board: "Board", last_mv: Tuple[Pos, Pos, Pos], done: bool) -> float:
        """+1 solved / 0.1 per jump / 0 on fail."""
        if not done:
            return 0.1
        # Terminal position: a single remaining peg counts as solved,
        # anything else is a failed game and earns nothing.
        return 1.0 if board.count_pegs() == 1 else 0.0

    # ------------------------------------------------------------------ #
    #                           debug print                               #
    # ------------------------------------------------------------------ #
    def __str__(self) -> str:
        txt = [str(self.board)]
        if self.last_move:
            txt.append(f"Last move: {self.last_move}")
        return "\n".join(txt)

In [36]:
from typing import Callable, Optional, Tuple, List
import numpy as np


class PegSolitaireEnv:
    """
    Minimal-Gym-like environment for Peg-Solitaire.

    Observation shape : (7, 7, 4) – see encode_observation().
    Action            : tuple (row, col, dir-idx) — dir-idx∈{0:↑,1:↓,2:←,3:→}

    Parameters
    ----------
    board_cls : type[Board]
    game_cls  : type[Game]
    reward_fn : Optional custom callable overriding Game.reward_fn
    """

    BOARD_SIZE = 7
    TOTAL_PEGS = 32

    # ---------------------- ctor / reset ---------------------- #
    def __init__(
        self,
        board_cls,
        game_cls,
        reward_fn: Optional[Callable[["Board", Tuple, bool], float]] = None
    ) -> None:
        self._board_cls = board_cls
        self._game_cls = game_cls
        self.game = game_cls(board_cls(), reward_fn=reward_fn)
        self.done = False
        self.board_mask = self._generate_board_mask()

    def _generate_board_mask(self) -> np.ndarray:
        """Return the 7×7 float mask of the cross: 1.0 on playable cells, 0.0 elsewhere."""
        mask = np.zeros((self.BOARD_SIZE, self.BOARD_SIZE), dtype=np.float32)
        for r in range(7):
            for c in range(7):
                if (2 <= r <= 4) or (2 <= c <= 4):
                    mask[r, c] = 1.0
        return mask

    # Gym-style API ------------------------------------------------------ #
    def reset(self, state=None) -> Tuple[np.ndarray, dict]:
        """
        Reset env. Optionally load custom board state (dict|ndarray|Board).

        An ndarray must be of shape (7, 7); cells equal to 1 on the playable
        cross become pegs, everything else a hole. Returns (obs, info).
        """
        if state is None:
            self.game.reset()
        elif isinstance(state, np.ndarray):
            if state.shape != (self.BOARD_SIZE, self.BOARD_SIZE):
                raise ValueError("state array must be shape (7,7)")
            # The game only accepts Board/dict, so translate the array here
            cells = {
                (r, c): int(state[r, c] == 1)
                for r in range(self.BOARD_SIZE)
                for c in range(self.BOARD_SIZE)
                if self.board_mask[r, c] == 1
            }
            self.game.set_state(cells)
        else:
            self.game.set_state(state)
        self.done = False
        obs = self.encode_observation()
        info = {"num_pegs": self.game.board.count_pegs()}
        return obs, info

    def step(self, action: Tuple[int, int, int]) -> Tuple[np.ndarray, float, bool, dict]:
        """
        Apply action, return (obs, reward, done, info).

        An illegal action leaves the board unchanged and costs -1.0.
        Raises RuntimeError when called after the episode has ended.
        """
        if self.done:
            raise RuntimeError("Call reset() before further step()s.")

        if self.game.is_legal_action(action):
            _, reward, self.done, _ = self.game.apply_action(action)
        else:
            reward = -1.0  # penalty for illegal
            self.done = self.game.is_game_over()

        obs = self.encode_observation()
        info = {
            "num_pegs": self.game.board.count_pegs(),
            "is_solved": self.game.is_solved(),
        }
        return obs, reward, self.done, info

    # ---------------------- helpers ---------------------------- #
    def get_legal_actions(self) -> List[Tuple[int, int, int]]:
        """Return the legal (row, col, dir-idx) actions of the current position."""
        return self.game.get_legal_actions()

    def get_legal_action_mask(
        self, action_space_size: int, to_idx: Callable[[Tuple[int, int, int]], int]
    ) -> np.ndarray:
        """Return a float vector of length action_space_size with 1.0 at to_idx(a) for each legal a."""
        mask = np.zeros(action_space_size, dtype=np.float32)
        for a in self.get_legal_actions():
            mask[to_idx(a)] = 1.0
        return mask

    # ------------------ observation encoding ------------------- #
    def encode_observation(self) -> np.ndarray:
        """
        Channel mapping
        0 – peg(1)/hole(0)
        1 – fraction removed   (broadcast scalar)
        2 – fraction remaining (broadcast scalar)
        3 – legal mask
        """
        arr = self.game.board.get_state().astype(np.float32)  # 7×7
        num_pegs = arr[arr == 1].size
        removed, remain = (self.TOTAL_PEGS - num_pegs) / self.TOTAL_PEGS, num_pegs / self.TOTAL_PEGS
        obs = np.zeros((self.BOARD_SIZE, self.BOARD_SIZE, 4), dtype=np.float32)
        obs[:, :, 0] = (arr == 1).astype(np.float32)
        obs[:, :, 1] = removed
        obs[:, :, 2] = remain
        obs[:, :, 3] = self.board_mask
        return obs

    # --------------------- render / debug ---------------------- #
    def render(self, mode: str = "human") -> None:
        """Print the board's text rendering (mode is accepted but unused)."""
        print(self.game.board)

    # ------------------- cloning for MCTS ---------------------- #
    def clone_state(self, state=None) -> "PegSolitaireEnv":
        """
        Return a fresh env with the same Board/Game classes and reward function.

        The clone starts from state when given (anything reset() accepts),
        otherwise from a copy of the current board. Move history is not copied.
        """
        clone = PegSolitaireEnv(
            self._board_cls,
            self._game_cls,
            # carry the reward function over; without it the clone would
            # silently fall back to the game's default reward
            reward_fn=getattr(self.game, "reward_fn", None),
        )
        # `is None` rather than truthiness: an ndarray has no truth value and
        # an empty dict is a valid (all-holes) position.
        clone.reset(self.game.board.to_dict() if state is None else state)
        return clone

    # --------------------- data augmentation ------------------- #
    @staticmethod
    def augment_observation(obs: np.ndarray, mode: str = "random") -> np.ndarray:
        """
        Return a rotated/flipped variant of obs.

        The 8 variants are ordered rotation-major: np.rot90 by k=0..3, each
        followed by its flip along axis 1. mode 'all' stacks them, 'random'
        picks one, anything else returns obs unchanged.
        """
        augs = []
        for r in range(4):
            rot = np.rot90(obs, k=r, axes=(0, 1))
            for flip in (False, True):
                augs.append(np.flip(rot, axis=1) if flip else rot)
        if mode == "all":
            return np.stack(augs)
        if mode == "random":
            return augs[np.random.randint(8)]
        return obs

    # ---------- action augmentation (for symmetry RL) ---------- #
    @staticmethod
    def augment_action(
        action: Tuple[int, int, int], rot: int = 0, flip: bool = False
    ) -> Tuple[int, int, int]:
        """
        Map an action through the same symmetry augment_observation applies.

        rot  : number of np.rot90 quarter turns (taken modulo 4)
        flip : mirror the columns afterwards (np.flip along axis 1)

        Returns the (row, col, dir-idx) that describes the same jump on the
        transformed board.
        """
        row, col, d = action
        directions = [(-2, 0), (2, 0), (0, -2), (0, 2)]
        dr, dc = directions[d]
        trg = (row + dr, col + dc)

        # apply rotation(s): np.rot90 moves cell (r, c) to (6 - c, r)
        for _ in range(rot % 4):
            row, col = 6 - col, row
            trg = (6 - trg[1], trg[0])
        # apply mirror
        if flip:
            row, col = row, 6 - col
            trg = (trg[0], 6 - trg[1])

        # recompute dir index
        diff = (trg[0] - row, trg[1] - col)
        d_new = directions.index(diff)
        return row, col, d_new

In [37]:
from __future__ import annotations
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------- #
#                         Neural-Net  (ResNet)                           #
# ---------------------------------------------------------------------- #
class ResidualBlock(nn.Module):
    """Two 3×3 conv + batch-norm stages with an identity skip connection.

    The number of channels and the spatial size are preserved (padding=1).
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Layers are created in forward order (conv1, bn1, conv2, bn2)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(x + bn2(conv2(relu(bn1(conv1(x))))))."""
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = F.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        # skip connection, then the final non-linearity
        return F.relu(x + residual)

class PegSolitaireNet(nn.Module):
    """
    Neural network for the computer player:
      - Input: (batch, 4, 7, 7)
      - Policy head: one un-normalized score per action (n_actions)
      - Value head: a number between -1 and 1 (estimate of the chance to solve)
    """
    def __init__(self, n_actions, n_res_blocks=10, channels=64, device='cpu'):
        super().__init__()
        self.device = device

        # Shared trunk: input convolution followed by the residual tower
        self.conv_in = nn.Conv2d(4, channels, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(channels)
        tower = [ResidualBlock(channels) for _ in range(n_res_blocks)]
        self.res_blocks = nn.Sequential(*tower)

        # Policy head: 1×1 conv down to 4 planes (one per jump direction)
        self.policy_conv = nn.Conv2d(channels, 4, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(4)
        self.policy_fc = nn.Linear(4 * 7 * 7, n_actions)

        # Value head: 1×1 conv down to 2 planes, then two dense layers
        self.value_conv = nn.Conv2d(channels, 2, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(2)
        self.value_fc1 = nn.Linear(2 * 7 * 7, 64)
        self.value_fc2 = nn.Linear(64, 1)

        self.to(device)

    def forward(self, x):
        """Return (policy_scores, value) for a batch x of shape (batch, 4, 7, 7)."""
        trunk = F.relu(self.bn_in(self.conv_in(x)))
        trunk = self.res_blocks(trunk)
        batch = trunk.size(0)

        # Policy: reshape (not view) so non-contiguous inputs are accepted
        policy = self.policy_conv(trunk)
        policy = F.relu(self.policy_bn(policy))
        policy = self.policy_fc(policy.reshape(batch, -1))

        # Value: same flattening, squashed into [-1, 1] by tanh
        value = self.value_conv(trunk)
        value = F.relu(self.value_bn(value))
        value = F.relu(self.value_fc1(value.reshape(batch, -1)))
        value = torch.tanh(self.value_fc2(value))

        return policy, value.squeeze(-1)

# ---------------------------------------------------------------------- #
#                          Replay-Buffer                                 #
# ---------------------------------------------------------------------- #
class ReplayBuffer:
    """Fixed-size replay buffer (FIFO).

    Samples are (observation, policy, value) tuples. Once more than
    ``max_size`` samples are held, the oldest one is dropped on each push.
    """

    def __init__(self, max_size: int = 50_000) -> None:
        # Local import keeps this class self-contained.
        from collections import deque

        # A deque drops the oldest sample in O(1); list.pop(0) shifts every
        # remaining element on each push once the buffer is full.
        self._buf = deque()
        self.max_size = max_size

    def push(self, sample: Tuple[np.ndarray, np.ndarray, float]) -> None:
        """Append sample, evicting the oldest entry if the buffer overflows."""
        self._buf.append(sample)
        if len(self._buf) > self.max_size:
            self._buf.popleft()

    def sample(self, batch: int) -> List[Tuple[np.ndarray, np.ndarray, float]]:
        """Return up to batch distinct samples drawn uniformly without replacement."""
        return random.sample(self._buf, min(len(self._buf), batch))

    def __len__(self) -> int:
        return len(self._buf)


# ---------------------------------------------------------------------- #
#                           MCTS (AlphaZero)                             #
# ---------------------------------------------------------------------- #
class _Node:
    __slots__ = ("prior", "children", "visit", "value_sum")

    def __init__(self, prior: float) -> None:
        self.prior = prior
        self.children: Dict[int, _Node] = {}
        self.visit = 0
        self.value_sum = 0.0

    @property
    def value(self) -> float:
        return self.value_sum / self.visit if self.visit else 0.0


class MCTS:
    """
    Monte-Carlo tree search guided by a policy/value network.

    Parameters
    ----------
    env          : environment providing get_legal_actions(), clone_state(),
                   step() and encode_observation()
    model        : network returning (policy_scores, value) for an observation
    action_space : maps actions ↔ flat indices (to_index / from_index /
                   legal_action_mask / len)
    sims         : number of simulations per run()
    c_puct       : weight of the prior-driven exploration term in _select()
    device       : torch device the observations are moved to
    """

    def __init__(
        self,
        env,
        model: PegSolitaireNet,
        action_space,
        sims: int = 100,
        c_puct: float = 1.5,
        device: str | torch.device = "cpu",
    ) -> None:
        self.env = env
        self.model = model
        self.action_space = action_space
        self.sims = sims
        self.c_puct = c_puct
        self.device = torch.device(device)

    # ------------- core -------------- #
    def run(self, root_obs: np.ndarray) -> np.ndarray:
        """
        Search from the env's current position and return the visit-count policy.

        The result has one entry per action index and is normalized to sum
        to ~1 (all zeros if no root child was ever visited).
        """
        root = self._expand(root_obs, self.env.get_legal_actions())
        for _ in range(self.sims):
            node, env_copy, path = root, self.env.clone_state(), [root]
            # selection
            while node.children:
                act_idx, node = self._select(node)
                env_copy.step(self.action_space.from_index(act_idx))
                path.append(node)
            # expansion + evaluation
            obs = env_copy.encode_observation()
            node.children = self._expand(obs, env_copy.get_legal_actions()).children
            value = self._evaluate(obs)
            # back-prop: the same leaf value is credited to every node on the path
            for n in path:
                n.visit += 1
                n.value_sum += value
        # policy as visit-counts
        pi = np.zeros(len(self.action_space), dtype=np.float32)
        for idx, child in root.children.items():
            pi[idx] = child.visit
        return pi / (pi.sum() + 1e-8)

    # ---------- helpers ------------- #
    def _evaluate(self, obs: np.ndarray) -> float:
        """Return the network's value estimate for obs (H, W, C layout)."""
        with torch.no_grad():
            # (H, W, C) → (1, C, H, W)
            t = torch.tensor(obs).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
            _, v = self.model(t)
        return v.item()

    def _expand(
        self, obs: np.ndarray, legal: List[Tuple[int, int, int]]
    ) -> _Node:
        """
        Return a node with one child per legal action.

        Child priors are the network's softmax restricted to the legal
        actions and renormalized to sum to 1; if all legal actions get zero
        probability a uniform prior is used. With no legal actions a
        childless node is returned without querying the network.
        """
        node = _Node(prior=1.0)
        if not legal:
            # Terminal position: nothing to expand (and no 0/0 below)
            return node
        with torch.no_grad():
            t = torch.tensor(obs).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
            logits, _ = self.model(t)
            probs = torch.softmax(logits, -1).cpu().numpy().flatten()
        mask = self.action_space.legal_action_mask(legal)
        probs = probs * mask
        total = probs.sum()
        if total > 0:
            probs = probs / total
        else:
            probs = mask / mask.sum()
        for a in legal:
            idx = self.action_space.to_index(a)
            node.children[idx] = _Node(prior=float(probs[idx]))
        return node

    def _select(self, node: _Node) -> Tuple[int, _Node]:
        """
        Return (action_index, child) maximizing
        child.value + c_puct * child.prior * sqrt(node.visit) / (1 + child.visit).

        The first child wins ties. node must have at least one child.
        """
        total = np.sqrt(node.visit)
        # Start from -inf: scores can be as low as -1 (tanh value, zero
        # exploration term), which a -1.0 sentinel would never beat.
        best, best_child = -float("inf"), -1
        for idx, child in node.children.items():
            u = self.c_puct * child.prior * total / (1 + child.visit)
            score = child.value + u
            if score > best:
                best, best_child = score, idx
        return best_child, node.children[best_child]


# ---------------------------------------------------------------------- #
#                            Agent                                       #
# ---------------------------------------------------------------------- #
class Agent:
    """
    Self-play and training driver around a network, an MCTS and a replay buffer.

    Parameters
    ----------
    env          : environment used for self-play
    model        : policy/value network
    action_space : maps (row, col, dir-idx) actions ↔ flat indices
    buffer       : replay buffer receiving (observation, policy, value) samples
    sims         : MCTS simulations per move
    device       : torch device used for inference and training
    """

    # Jump vectors by direction index (0:up 1:down 2:left 3:right); used only
    # when the env does not expose ``game.DIRECTIONS``.
    _DEFAULT_DIRECTIONS = ((-2, 0), (2, 0), (0, -2), (0, 2))

    def __init__(
        self,
        env,
        model: PegSolitaireNet,
        action_space,
        buffer: ReplayBuffer,
        sims: int = 100,
        device: str | torch.device = "cpu",
    ) -> None:
        self.env = env
        self.model = model
        self.action_space = action_space
        self.buffer = buffer
        self.device = torch.device(device)
        self.mcts = MCTS(env, model, action_space, sims, device=self.device)
        # board side → list of 8 index permutations (built lazily)
        self._sym_perms: Dict[int, List[np.ndarray]] = {}

    # ------------------ symmetry helpers ----------- #
    def _symmetry_permutations(self, side: int) -> List[np.ndarray]:
        """
        Return, for each of the 8 board symmetries, where every action index moves.

        Symmetries are ordered rotation-major (np.rot90 by k=0..3, each
        followed by its column mirror). ``perm[i]`` is the index of the action
        that describes action ``i`` on the transformed board.
        """
        cache = getattr(self, "_sym_perms", None)
        if cache is None:
            cache = self._sym_perms = {}
        if side in cache:
            return cache[side]

        game = getattr(self.env, "game", None)
        dirs = [tuple(v) for v in getattr(game, "DIRECTIONS", self._DEFAULT_DIRECTIONS)]
        last = side - 1
        n_actions = len(self.action_space)
        perms: List[np.ndarray] = []
        for rot in range(4):
            for flip in (False, True):
                perm = np.empty(n_actions, dtype=np.int64)
                for i in range(n_actions):
                    r, c, d = self.action_space.from_index(i)
                    src = (r, c)
                    dst = (r + dirs[d][0], c + dirs[d][1])
                    # np.rot90 moves cell (r, c) to (last - c, r)
                    for _ in range(rot):
                        src = (last - src[1], src[0])
                        dst = (last - dst[1], dst[0])
                    # np.flip(axis=1) moves cell (r, c) to (r, last - c)
                    if flip:
                        src = (src[0], last - src[1])
                        dst = (dst[0], last - dst[1])
                    d_new = dirs.index((dst[0] - src[0], dst[1] - src[1]))
                    perm[i] = self.action_space.to_index((src[0], src[1], d_new))
                perms.append(perm)
        cache[side] = perms
        return perms

    def _augment_sample(
        self, obs: np.ndarray, pi: np.ndarray
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Return the 8 (observation, policy) pairs related by board symmetry.

        Observation and policy are transformed together, so every policy
        entry still refers to the same jump on its transformed board.
        """
        perms = self._symmetry_permutations(obs.shape[0])
        pairs = []
        k = 0
        for rot in range(4):
            rotated = np.rot90(obs, k=rot, axes=(0, 1))
            for flip in (False, True):
                view = np.flip(rotated, axis=1) if flip else rotated
                pi_aug = np.zeros_like(pi)
                pi_aug[perms[k]] = pi
                # copy: rot90/flip return views onto obs
                pairs.append((np.ascontiguousarray(view), pi_aug))
                k += 1
        return pairs

    # ------------------ self-play ------------------ #
    def self_play_episode(self, augment: bool = True) -> None:
        """
        Play one game with MCTS and push its positions into the buffer.

        Every stored sample uses the reward of the final step as its value
        target. With augment=True each position is stored in all 8
        symmetric variants (observation and policy transformed alike).
        The model is left in eval mode.
        """
        # Inference runs on single observations: BatchNorm must use its
        # running statistics rather than per-sample batch statistics.
        self.model.eval()
        obs, _ = self.env.reset()
        done = False
        states, policies = [], []
        while not done:
            pi = self.mcts.run(obs)
            action_idx = np.random.choice(len(pi), p=pi)
            action = self.action_space.from_index(action_idx)
            states.append(obs)
            policies.append(pi)
            obs, reward, done, _ = self.env.step(action)
        # store
        for s, p in zip(states, policies):
            if augment:
                for aug_s, aug_p in self._augment_sample(s, p):
                    self.buffer.push((aug_s, aug_p, reward))
            else:
                self.buffer.push((s, p, reward))

    # ------------------ training ------------------- #
    def train(
        self,
        batch_size: int = 256,
        epochs: int = 1,
        lr: float = 1e-3,
    ) -> None:
        """
        Run ``epochs`` Adam steps, each on a fresh batch sampled from the buffer.

        Does nothing while the buffer holds fewer than batch_size samples.
        Loss = KL(policy target ‖ softmax(scores)) + MSE(value). The model is
        switched to train mode for the updates and restored afterwards.
        """
        if len(self.buffer) < batch_size:
            return
        was_training = self.model.training
        self.model.train()
        try:
            opt = torch.optim.Adam(self.model.parameters(), lr=lr)
            for _ in range(epochs):
                batch = self.buffer.sample(batch_size)
                s, p, v = zip(*batch)
                # (B, H, W, C) → (B, C, H, W)
                s = torch.tensor(np.stack(s)).float().permute(0, 3, 1, 2).contiguous().to(self.device)
                p = torch.tensor(np.stack(p)).float().to(self.device)
                v = torch.tensor(v).float().to(self.device)

                opt.zero_grad()
                logits, v_pred = self.model(s)
                loss_pol = F.kl_div(
                    torch.log_softmax(logits, -1), p, reduction="batchmean"
                )
                loss_val = F.mse_loss(v_pred, v)
                loss = loss_pol + loss_val
                loss.backward()
                opt.step()
        finally:
            self.model.train(was_training)


# ---------------------------------------------------------------------- #
#                       Action-Space helper                              #
# ---------------------------------------------------------------------- #
class PegSolitaireActionSpace:
    """Maps (row,col,dir) ↔ flat-index and provides legal mask.

    Actions are enumerated cell by cell in row-major order over the cells
    whose mask value is 1, with the four direction indices 0..3 per cell.
    """

    def __init__(self, board_mask: np.ndarray) -> None:
        cells = []
        for row in range(7):
            for col in range(7):
                if board_mask[row, col] == 1:
                    cells.append((row, col))
        self.valid_cells = cells

        # every playable cell paired with every direction index
        self.actions: List[Tuple[int, int, int]] = [
            cell + (direction,) for cell in cells for direction in range(4)
        ]
        self._to_idx: Dict[Tuple[int, int, int], int] = dict(
            zip(self.actions, range(len(self.actions)))
        )

    # mapping
    def to_index(self, a: Tuple[int, int, int]) -> int:
        """Return the flat index of action a (KeyError if a is not in the space)."""
        return self._to_idx[a]

    def from_index(self, idx: int) -> Tuple[int, int, int]:
        """Return the (row, col, dir) action stored at flat index idx."""
        return self.actions[idx]

    # mask
    def legal_action_mask(self, legal: List[Tuple[int, int, int]]) -> np.ndarray:
        """Return a float32 vector with 1.0 at the index of each action in legal."""
        mask = np.zeros(len(self.actions), dtype=np.float32)
        for position in map(self.to_index, legal):
            mask[position] = 1.0
        return mask

    def __len__(self) -> int:
        return len(self.actions)

In [38]:
import os, types, tkinter as tk
import numpy as np
import torch

# ------------------- PATHS / PARAMS ------------------- #
AGENT_PATH      = "peg_agent.pt"
TRAIN_EPISODES  = 800           # self-play episodes to run when no saved agent file exists
DEVICE          = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", DEVICE)

# ------------ Runtime patch – fix _expand ------------- #
def _safe_expand(self, obs, legal):
    """Replacement for MCTS._expand that avoids division by zero.

    Returns a node with one child per legal action. Child priors are the
    network's softmax restricted to the legal actions and renormalized to
    sum to 1 (uniform if the legal actions carry no probability mass).
    With no legal actions a childless leaf node is returned.
    """
    mask = self.action_space.legal_action_mask(legal)
    if mask.sum() == 0:                       # no legal moves
        return _Node(prior=1.0)               # leaf node

    with torch.no_grad():
        t = torch.tensor(obs).float().permute(2,0,1).unsqueeze(0).to(self.device)
        logits, _ = self.model(t)
        probs = torch.softmax(logits, -1).cpu().numpy().flatten()

    probs *= mask
    total = probs.sum()
    if total > 0:
        # Normalize by the remaining probability mass; dividing by the number
        # of legal moves would not make the priors sum to 1.
        probs /= total
    else:
        probs = mask / mask.sum()             # safe: mask.sum() > 0 here

    node = _Node(prior=1.0)
    for a in legal:
        idx = self.action_space.to_index(a)
        node.children[idx] = _Node(prior=float(probs[idx]))
    return node

# Replace the method on the class, so MCTS instances pick up _safe_expand
MCTS._expand = _safe_expand            # type: ignore

# -------------- Helper save / load -------------------- #
def save_agent(agent, path=AGENT_PATH):
    """Write the agent's network weights, action count and MCTS sims to path."""
    checkpoint = {
        "state_dict": agent.model.state_dict(),
        "n_actions": len(agent.action_space),
        "sims": agent.mcts.sims,
    }
    torch.save(checkpoint, path)
    print(f"✓ Agent saved ➜ {path}")

def load_agent(path=AGENT_PATH):
    """Rebuild an Agent (fresh env, action space and empty buffer) from the checkpoint at path.

    The network is returned in eval mode; ``sims`` defaults to 100 when the
    checkpoint does not store it.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=DEVICE)

    environment = PegSolitaireEnv(Board, Game)
    action_space = PegSolitaireActionSpace(environment.board_mask)

    network = PegSolitaireNet(checkpoint["n_actions"], device=DEVICE)
    network.load_state_dict(checkpoint["state_dict"])
    network.eval()

    return Agent(
        environment,
        network,
        action_space,
        ReplayBuffer(),
        sims=checkpoint.get("sims", 100),
        device=DEVICE,
    )

# -------------- Load or Train ------------------------- #
if not os.path.exists(AGENT_PATH):
    # No checkpoint on disk: build a fresh agent and train it by self-play.
    env   = PegSolitaireEnv(Board, Game)
    asp   = PegSolitaireActionSpace(env.board_mask)
    model = PegSolitaireNet(len(asp), device=DEVICE)
    buf   = ReplayBuffer()
    agent = Agent(env, model, asp, buf, sims=100, device=DEVICE)

    print("• Training new agent …")
    for ep in range(1, TRAIN_EPISODES + 1):
        agent.self_play_episode()
        # train every 10th episode, once the buffer holds more than 1024 samples
        if ep % 10 == 0 and len(buf) > 1024:
            agent.train(batch_size=256, epochs=3, lr=1e-3)
        # progress report every 100 episodes
        if ep % 100 == 0:
            print(f"  ↳ episode {ep:4d}/{TRAIN_EPISODES} | buffer={len(buf)}")
    save_agent(agent)
else:
    # A saved agent exists: reuse it instead of training again.
    agent = load_agent()


Using device: mps


In [39]:
import tkinter as tk
from tkinter import messagebox
import numpy as np
import torch   # נדרש ל-value-bar


# ----------------------------------------------------------- #
class PegSolitaireGUI(tk.Frame):
    """Tkinter front-end for Peg-Solitaire with a value bar and agent hints.

    The frame contains the board canvas, a status line with a value bar next
    to it, undo / redo / new-game / hint buttons and a list box showing the
    move log.

    Args:
        master: Parent Tk widget; Ctrl-Z and Ctrl-Y are bound on it.
        game: Game object; this class uses its ``board``, ``apply_move``,
            ``undo``, ``redo``, ``reset``, ``get_legal_moves``, ``is_win``,
            ``is_game_over``, ``move_log``, ``move_history`` and
            ``redo_stack`` members.
        agent: Optional agent exposing ``model``, ``mcts``, ``action_space``
            and ``device``. With ``None`` the value bar stays neutral and the
            hint button only reports that no agent is connected.
    """

    CELL_SIZE, PEG_RADIUS, PADDING = 60, 22, 16
    PEG_COLOR, HOLE_COLOR          = "#FFD600", "#202020"
    OUTLINE_COLOR, HIGHLIGHT_COLOR = "#333",    "#42A5F5"
    SUGGEST_COLOR, BG_COLOR        = "#00C853", "#eeeeee"
    BAR_W, BAR_H                   = 160, 16

    def __init__(self, master, game: "Game", agent=None):
        super().__init__(master, bg=self.BG_COLOR)
        self.game   = game
        self.agent  = agent
        # Currently selected peg (row, col), or None when nothing is selected.
        self.selected_pos: "tuple[int, int] | None" = None
        # (src, dst) of the move suggested by the agent, or None.
        self.highlight_move: "tuple[tuple[int, int], tuple[int, int]] | None" = None

        # --- Canvas -------------------------------------------------- #
        side = 7 * self.CELL_SIZE + 2 * self.PADDING
        self.canvas = tk.Canvas(self, width=side, height=side,
                                bg=self.BG_COLOR, highlightthickness=0)
        self.canvas.pack()

        # --- Status row + Value bar ---------------------------------- #
        stat_frm = tk.Frame(self, bg=self.BG_COLOR)
        stat_frm.pack(pady=4, fill="x")
        self.status_label = tk.Label(stat_frm, font=("Arial", 14),
                                     bg=self.BG_COLOR, anchor="w")
        self.status_label.pack(side="left", padx=4, fill="x", expand=True)
        self.bar_canvas = tk.Canvas(stat_frm, width=self.BAR_W, height=self.BAR_H,
                                    bg=self.BG_COLOR, highlightthickness=0)
        self.bar_canvas.pack(side="right", padx=6)
        self._draw_value_bar(0.0)

        # --- Buttons ------------------------------------------------- #
        btn_frm = tk.Frame(self, bg=self.BG_COLOR)
        btn_frm.pack()
        self.undo_btn  = tk.Button(btn_frm, text="↩️ ביטול",  command=self.on_undo)
        self.redo_btn  = tk.Button(btn_frm, text="↪️ קדימה",  command=self.on_redo)
        self.reset_btn = tk.Button(btn_frm, text="משחק חדש", command=self.on_reset)
        self.hint_btn  = tk.Button(btn_frm, text="🤖 המלצה",  command=self.on_hint)
        for col, b in enumerate((self.undo_btn, self.redo_btn,
                                 self.reset_btn, self.hint_btn)):
            b.grid(row=0, column=col, padx=3)

        # --- Move-log ----------------------------------------------- #
        tk.Label(self, text="מהלכים:", font=("Arial", 12),
                 bg=self.BG_COLOR).pack()
        self.log_list = tk.Listbox(self, height=6, width=42,
                                   font=("Consolas", 11))
        self.log_list.pack(pady=(0, 8))

        # bindings
        self.canvas.bind("<Button-1>", self.on_canvas_click)
        master.bind("<Control-z>", lambda e: self.on_undo())
        master.bind("<Control-y>", lambda e: self.on_redo())

        self.redraw()

    # ========================================================= #
    #                Board drawing and value bar                #
    # ========================================================= #
    def board_to_xy(self, pos):
        """Return the canvas (x, y) of the centre of board cell ``pos`` = (row, col)."""
        r, c = pos
        return (self.PADDING + c * self.CELL_SIZE + self.CELL_SIZE // 2,
                self.PADDING + r * self.CELL_SIZE + self.CELL_SIZE // 2)

    def redraw(self):
        """Repaint the board and refresh status line, buttons, log and value bar."""
        self.canvas.delete("all")
        # pegs / holes
        for pos in Board.LEGAL_POSITIONS:
            x, y = self.board_to_xy(pos)
            fill = self.PEG_COLOR if self.game.board.get(pos) == 1 else self.HOLE_COLOR
            width  = 3 if pos == self.selected_pos else 1
            outline = self.HIGHLIGHT_COLOR if pos == self.selected_pos else self.OUTLINE_COLOR
            self.canvas.create_oval(x-self.PEG_RADIUS, y-self.PEG_RADIUS,
                                    x+self.PEG_RADIUS, y+self.PEG_RADIUS,
                                    fill=fill, outline=outline, width=width)
        # legal destinations of the selected peg
        if self.selected_pos:
            for dst in [d for s, d, _ in self.game.get_legal_moves() if s == self.selected_pos]:
                x, y = self.board_to_xy(dst)
                self.canvas.create_oval(x-self.PEG_RADIUS//2, y-self.PEG_RADIUS//2,
                                        x+self.PEG_RADIUS//2, y+self.PEG_RADIUS//2,
                                        outline=self.HIGHLIGHT_COLOR, width=3)
        # suggestion arrow
        if self.highlight_move:
            src, dst = self.highlight_move
            x1, y1 = self.board_to_xy(src)
            x2, y2 = self.board_to_xy(dst)
            self.canvas.create_line(x1, y1, x2, y2, fill=self.SUGGEST_COLOR,
                                    width=5, arrow=tk.LAST)

        self.update_status()
        self.update_buttons()
        self.update_move_log()
        self.update_value_bar()

    # ------------- Value-Bar helpers ---------------- #
    def _draw_value_bar(self, v):
        """Draw the value bar for ``v``; -1 is an empty bar, +1 a full one.

        Values outside [-1, 1] are clamped so the filled part never extends
        past the bar outline. Colour: red below -0.3, grey below 0.3,
        green otherwise.
        """
        v = max(-1.0, min(1.0, v))
        self.bar_canvas.delete("all")
        frac   = (v + 1) / 2
        length = int(frac * self.BAR_W)
        color  = "#d50000" if v < -0.3 else "#9e9e9e" if v < 0.3 else "#00c853"
        self.bar_canvas.create_rectangle(0, 0, length, self.BAR_H, fill=color, width=0)
        self.bar_canvas.create_rectangle(0, 0, self.BAR_W, self.BAR_H, outline="#555")

    def update_value_bar(self):
        """Show the model's value output for the current board (neutral without an agent)."""
        if self.agent is None:
            self._draw_value_bar(0.0)
            return
        obs = self.game.board.encode_observation()
        with torch.no_grad():
            # permute(2, 0, 1) moves the last axis first; unsqueeze adds a batch axis
            t = torch.tensor(obs).float().permute(2,0,1).unsqueeze(0).to(self.agent.device)
            _, value = self.agent.model(t)
        self._draw_value_bar(float(value))

    # ========================================================= #
    #                      UI event handlers                    #
    # ========================================================= #
    def on_canvas_click(self, e):
        """Select a peg, try to move the selected peg, or clear the selection."""
        pos = ((e.y - self.PADDING)//self.CELL_SIZE,
               (e.x - self.PADDING)//self.CELL_SIZE)
        if pos not in Board.LEGAL_POSITIONS:
            return
        if self.selected_pos is None and self.game.board.get(pos)==1:
            self.selected_pos = pos
        elif self.selected_pos and pos != self.selected_pos:
            ok, *_ = self.game.apply_move(self.selected_pos, pos)
            if ok:
                self.selected_pos = self.highlight_move = None
        else:
            self.selected_pos = None
        self.redraw()

    def _generic(self, fn, msg=""):
        """Run ``fn``, clear selection and hint, redraw; show ``msg`` if ``fn`` was falsy."""
        succeeded = fn()
        self.selected_pos = self.highlight_move = None
        self.redraw()
        # redraw() rewrites the status line, so the failure message has to be
        # written afterwards or it would be overwritten immediately.
        if not succeeded and msg:
            self.status_label.config(text=msg)

    def on_undo(self):  self._generic(self.game.undo, "אין מהלך לבטל.")
    def on_redo(self):  self._generic(self.game.redo, "אין מהלך לשחזר.")
    def on_reset(self): self._generic(self.game.reset)

    def on_hint(self):
        """Ask the agent for a move and highlight it on the board."""
        if self.agent is None:
            messagebox.showinfo("המלצה", "אין סוכן מחובר.")
            return
        # Nothing sensible can be suggested on a finished game; an argmax over
        # the search output would otherwise still pick some arbitrary action.
        if self.game.is_win() or self.game.is_game_over():
            messagebox.showinfo("המלצה", "אין מהלכים חוקיים.")
            return
        obs = self.game.board.encode_observation()
        pi  = self.agent.mcts.run(obs)
        idx = int(np.argmax(pi))
        r,c,d = self.agent.action_space.from_index(idx)
        dr,dc = Game.DIRECTIONS[d]
        src, dst = (r,c), (r+dr, c+dc)
        self.highlight_move = (src, dst)
        self.redraw()
        # redraw() rewrites the status line, so the hint text is set afterwards.
        self.status_label.config(text=f"המלצה: {src} ↠ {dst}")

    # ----------------- UI helpers ------------------ #
    def update_status(self):
        """Set the status line to win / game-over / running-game text."""
        if self.game.is_win():
            text="ניצחון! פיון יחיד במרכז 👑"
        elif self.game.is_game_over():
            text=f"סיום • אין מהלכים ({len(self.game.move_log)})"
        else:
            text=f"פינים: {self.game.board.count_pegs()} | מהלך: {len(self.game.move_log)}"
        self.status_label.config(text=text)

    def update_buttons(self):
        """Enable undo / redo only while their respective stacks are non-empty."""
        self.undo_btn.config(state=tk.NORMAL if self.game.move_history else tk.DISABLED)
        self.redo_btn.config(state=tk.NORMAL if self.game.redo_stack   else tk.DISABLED)

    def update_move_log(self):
        """Rebuild the list box from ``game.move_log`` entries (src, over, dst)."""
        self.log_list.delete(0, tk.END)
        for i,(src,over,dst) in enumerate(self.game.move_log,1):
            self.log_list.insert(tk.END, f"{i:2}: {src} → {dst} (/{over})")


# ---------------- DEMO (uses `agent` if one already exists) --------------- #
if __name__ == "__main__":
    root = tk.Tk()
    root.title("מחשבת – Peg-Solitaire")

    try:         # reuse a trained agent if the main module already defines one
        agent  # type: ignore
    except NameError:
        # PegSolitaireGUI accepts agent=None: the value bar is drawn neutral
        # and the hint button reports that no agent is connected, so no
        # stand-in object is needed.
        agent = None

    gui = PegSolitaireGUI(root, Game(), agent=agent)
    gui.pack()
    root.mainloop()