# Extract Evaluation Positions for Next-Move Prediction

Dataset format: `(FEN, last_move, next_move_candidates, correct_outputs)`

Each record contains **two separate groups** of illegal moves:
- `illegal_category_uci` -- category-specific illegals (pin-breaking, king-to-attacked, castling-while-in-check, etc.)
- `illegal_general_uci` -- general distractors (backward pawn, friendly fire, blocked sliding, wrong geometry, etc.)

Both are merged into `next_move_candidates_uci`; `correct_outputs_uci` = all legal moves.

**Position categories:**
1. **En passant** -- en passant is legal; distractors are pawn diagonal moves to non-ep empty squares
2. **Check evasion (single)** -- in check by one piece; distractors: king-to-attacked + castling-while-in-check
3. **Double check** -- two checkers; only king moves legal; distractors: non-king captures/blocks + king-to-attacked + castling-while-in-check
4. **Illegal king moves** -- not in check; king-to-attacked + castling through attacked squares
5. **Pin** -- own piece pinned to own king; distractors: pseudo-legal pin-breaking moves
6. **Promotion** -- pawn can promote; distractors: promotion-captures to empty squares, promotion push onto piece
7. **Vanilla** -- random positions with no special tag; only general distractors

In [1]:
import chess
import chess.pgn
import json
import random
from pathlib import Path
from collections import defaultdict
from typing import Optional, List, Set, Dict, Tuple
from dataclasses import dataclass, field

# ── Config knobs ─────────────────────────────────────────────────────────────
# PGN_PATH is the PGN file to scan. OUT_PATH is presumably where the extracted
# rows are written as JSONL -- the write itself is not in this part of the
# file, so confirm against the export cell.
PGN_PATH = Path('data/lichess_db_standard_rated_2013-01.pgn')
OUT_PATH = Path('data/eval_positions_preview.jsonl')
MAX_GAMES = 50                  # games to scan
NUM_GENERAL_DISTRACTORS = 5     # general illegal moves per position
NUM_VANILLA_POSITIONS = 100     # how many vanilla (no-tag) positions to sample
# Seed for the single shared random generator below.
SEED = 50

# Module-level RNG; every call that receives `rng` advances the same stream,
# so the order in which cells run affects what later calls draw.
rng = random.Random(SEED)

In [2]:
# ── Helpers ──────────────────────────────────────────────────────────────────

def get_phase(board: chess.Board) -> str:
    """Label a position as "opening", "middlegame" or "endgame".

    The first 12 full moves always count as the opening. After that the
    position is a middlegame while at least 6 pieces other than pawns and
    kings remain AND at least 16 pieces of any kind are on the board;
    everything else is an endgame.
    """
    if board.fullmove_number <= 12:
        return "opening"

    pieces = board.piece_map().values()
    officers = len([
        p for p in pieces
        if p.piece_type != chess.PAWN and p.piece_type != chess.KING
    ])
    still_crowded = officers >= 6 and len(pieces) >= 16
    return "middlegame" if still_crowded else "endgame"


def iter_games(pgn_path: str, max_games: Optional[int]):
    """Yield games from a PGN file, one at a time, up to ``max_games``.

    Args:
        pgn_path: Path of the PGN file; passed straight to ``open()``.
        max_games: Maximum number of games to yield. ``None`` means no
            limit. Zero or a negative value yields nothing and does not
            open the file.

    Yields:
        Whatever ``chess.pgn.read_game`` returns for each game, until it
        returns ``None`` or the limit is reached.
    """
    # Honour the limit before doing any work: checking it only after the
    # first yield would hand out one game even when max_games is 0.
    if max_games is not None and max_games <= 0:
        return

    # errors="ignore" drops undecodable bytes instead of aborting the scan.
    with open(pgn_path, "r", encoding="utf-8", errors="ignore") as f:
        count = 0
        while max_games is None or count < max_games:
            game = chess.pgn.read_game(f)
            if game is None:
                break
            yield game
            count += 1


# ── Castling geometry table ─────────────────────────────────────────────────
# Used by check, double-check, and illegal-king detectors.
#
# Maps a colour to a list of two entries (kingside, then queenside). Each
# entry is a dict with these keys:
#   rights_fn  -- name of the Board method, looked up with getattr() and
#                 called with the colour, that reports the castling right
#   king_from  -- king's starting square (source of the castling Move)
#   king_to    -- king's destination square (target of the castling Move)
#   rook_sq    -- square that must hold a rook of the castling side
#   clear_sqs  -- squares between king and rook that must be empty
#   safe_sqs   -- squares the king stands on, crosses, or lands on; the
#                 detectors test these for enemy attacks
_CASTLE_INFO = {
    chess.WHITE: [
        # kingside O-O
        {"rights_fn": "has_kingside_castling_rights",
         "king_from": chess.E1, "king_to": chess.G1,
         "rook_sq": chess.H1,
         "clear_sqs": [chess.F1, chess.G1],
         "safe_sqs":  [chess.E1, chess.F1, chess.G1]},
        # queenside O-O-O
        # (B1 must be empty but is not in safe_sqs: the king never crosses it)
        {"rights_fn": "has_queenside_castling_rights",
         "king_from": chess.E1, "king_to": chess.C1,
         "rook_sq": chess.A1,
         "clear_sqs": [chess.B1, chess.C1, chess.D1],
         "safe_sqs":  [chess.E1, chess.D1, chess.C1]},
    ],
    chess.BLACK: [
        # kingside O-O
        {"rights_fn": "has_kingside_castling_rights",
         "king_from": chess.E8, "king_to": chess.G8,
         "rook_sq": chess.H8,
         "clear_sqs": [chess.F8, chess.G8],
         "safe_sqs":  [chess.E8, chess.F8, chess.G8]},
        # queenside O-O-O
        {"rights_fn": "has_queenside_castling_rights",
         "king_from": chess.E8, "king_to": chess.C8,
         "rook_sq": chess.A8,
         "clear_sqs": [chess.B8, chess.C8, chess.D8],
         "safe_sqs":  [chess.E8, chess.D8, chess.C8]},
    ],
}

print("Helpers loaded.")

Helpers loaded.


In [3]:
# ── Category 1: En Passant ───────────────────────────────────────────────────
# Detect positions where en passant is legal.
# Distractors: pawn diagonal moves to non-ep empty squares (looks like a
# capture but there's nothing there and it's not en passant).

def detect_en_passant(board: chess.Board) -> Optional[dict]:
    """Report the legal en passant captures available in this position.

    Returns ``None`` when the board has no ep square or when no legal move
    is an en passant capture; otherwise a dict with ``"ep_moves"`` (the
    legal ep captures) and ``"ep_square"`` (the board's ep square).
    """
    ep_target = board.ep_square
    if ep_target is None:
        return None

    captures = [mv for mv in board.legal_moves if board.is_en_passant(mv)]
    return {"ep_moves": captures, "ep_square": ep_target} if captures else None


def build_en_passant_illegals(board: chess.Board, info: dict) -> List[str]:
    """Build the en passant distractors as UCI strings.

    For every pawn of the side to move, each square it attacks diagonally
    that is empty and is not the ep square becomes a "capture of nothing"
    candidate. Candidates whose UCI matches a legal move are dropped.
    ``info`` is accepted for symmetry with the other builders and is not
    read.
    """
    legal = {mv.uci() for mv in board.legal_moves}
    ep_target = board.ep_square

    phantom_captures = (
        chess.Move(pawn_sq, dest).uci()
        for pawn_sq in board.pieces(chess.PAWN, board.turn)
        for dest in board.attacks(pawn_sq)
        if dest != ep_target and board.piece_at(dest) is None
    )
    return [uci for uci in phantom_captures if uci not in legal]

print("En passant detector loaded.")

En passant detector loaded.


In [4]:
# ── Category 2: Check Evasion (single check) ────────────────────────────────
# Position is in check by ONE piece. Legal responses: king move, capture attacker, block.
# Illegal distractors: king moves to attacked squares + castling while in check.

@dataclass
class CheckInfo:
    """Moves available (and tempting-but-illegal) in a position with check."""

    king_moves: List[chess.Move]       # legal king moves (to safe squares)
    captures: List[chess.Move]         # legal captures of the checking piece
    blocks: List[chess.Move]           # legal interpositions
    illegal_king_moves: List[chess.Move]  # king moves to attacked squares
    illegal_castling: List[chess.Move]    # castling while in check
    num_checkers: int

    @property
    def evasion_types(self) -> int:
        """Count how many of the three evasion kinds have at least one move."""
        kinds = (self.king_moves, self.captures, self.blocks)
        return sum(1 for moves in kinds if len(moves) > 0)


def _find_castling_in_check(board: chess.Board) -> List[chess.Move]:
    """Collect castling moves that look available apart from the check.

    A castle is listed when the side to move still has the right, its rook
    stands on the rook square, and every square between king and rook is
    empty. The function does not test for check itself; callers invoke it
    only on positions that are in check.
    """
    side = board.turn
    tempting = []
    for spec in _CASTLE_INFO[side]:
        has_right = getattr(board, spec["rights_fn"])(side)
        if not has_right:
            continue

        corner = board.piece_at(spec["rook_sq"])
        rook_home = (
            corner is not None
            and corner.piece_type == chess.ROOK
            and corner.color == side
        )
        lane_empty = all(board.piece_at(sq) is None for sq in spec["clear_sqs"])
        if rook_home and lane_empty:
            tempting.append(chess.Move(spec["king_from"], spec["king_to"]))
    return tempting


def analyze_check(board: chess.Board) -> Optional[CheckInfo]:
    """Classify the legal replies and tempting illegal moves when in check.

    Returns ``None`` if the side to move is not in check. Otherwise returns
    a ``CheckInfo`` whose lists partition the legal moves into king moves,
    captures of a checking piece, and interpositions, plus the illegal king
    steps and illegal castling attempts used as distractors.

    The function does not reject double check itself: ``num_checkers``
    reports how many pieces give check and callers decide what to do.
    """
    if not board.is_check():
        return None

    king_sq = board.king(board.turn)
    checkers = list(board.checkers())
    checker_squares: Set[int] = set(checkers)

    # Compute block squares (squares between king and sliding checkers)
    block_squares: Set[int] = set()
    for att_sq in checkers:
        piece = board.piece_at(att_sq)
        if piece and piece.piece_type in (chess.BISHOP, chess.ROOK, chess.QUEEN):
            block_squares |= set(chess.SquareSet.between(king_sq, att_sq))

    king_moves, captures, blocks = [], [], []
    for move in board.legal_moves:
        mover = board.piece_at(move.from_square)
        if mover and mover.piece_type == chess.KING:
            # A king capturing the checker is filed as a king move.
            king_moves.append(move)
            continue

        # En passant removes the pawn that sits beside the capturing pawn
        # (destination file, origin rank), not the piece on the destination
        # square. Without this, an ep capture of a checking pawn would be
        # missed as a capture.
        if board.is_en_passant(move):
            captured_sq = chess.square(
                chess.square_file(move.to_square),
                chess.square_rank(move.from_square),
            )
        else:
            captured_sq = move.to_square

        if captured_sq in checker_squares:
            captures.append(move)
        elif move.to_square in block_squares:
            blocks.append(move)

    # Illegal king moves: king to attacked squares.
    # NOTE(review): is_attacked_by() looks at the current position, where the
    # king still shields squares directly behind it on a checking ray, so
    # retreats along that ray are illegal but not collected here -- confirm
    # whether that omission is intended.
    illegal_king_moves = []
    for dest in chess.SquareSet(chess.BB_KING_ATTACKS[king_sq]):
        own_piece = board.piece_at(dest)
        if own_piece and own_piece.color == board.turn:
            continue
        m = chess.Move(king_sq, dest)
        if not board.is_pseudo_legal(m):
            continue
        if m not in board.legal_moves and board.is_attacked_by(not board.turn, dest):
            illegal_king_moves.append(m)

    # Illegal castling: can't castle while in check
    illegal_castling = _find_castling_in_check(board)

    return CheckInfo(
        king_moves=king_moves,
        captures=captures,
        blocks=blocks,
        illegal_king_moves=illegal_king_moves,
        illegal_castling=illegal_castling,
        num_checkers=len(checkers),
    )


def build_check_candidates(board: chess.Board, info: CheckInfo) -> List[str]:
    """List the single-check distractors as UCI strings.

    Illegal king steps come first, followed by illegal castling attempts.
    ``board`` is part of the shared builder signature and is not read.
    """
    forbidden = list(info.illegal_king_moves) + list(info.illegal_castling)
    return [mv.uci() for mv in forbidden]

print("Check evasion detector loaded.")

Check evasion detector loaded.


In [5]:
# ── Category 3: Double Check ─────────────────────────────────────────────────
# Two pieces give check simultaneously. Can't capture one checker and can't
# block both rays, so the ONLY legal moves are king moves.
#
# Distractors:
#   1. Non-king pseudo-legal moves that capture/block ONE checker (illegal)
#   2. King moves to attacked squares
#   3. Castling while in check

@dataclass
class DoubleCheckInfo:
    """Result of detect_double_check() for a position with 2+ checkers."""
    # Squares of the pieces currently giving check.
    checker_squares: List[int]
    # Every legal move in the position (filled from board.legal_moves).
    legal_king_moves: List[chess.Move]
    # King steps onto squares attacked by the opponent.
    illegal_king_moves: List[chess.Move]
    # Non-king pseudo-legal moves that capture or block one checker.
    illegal_non_king: List[chess.Move]
    # Castling attempts that fail only because the king is in check.
    illegal_castling: List[chess.Move]


def detect_double_check(board: chess.Board) -> Optional[DoubleCheckInfo]:
    """Describe a position where two or more pieces give check at once.

    Returns ``None`` unless at least two pieces check the side to move.
    Otherwise returns a ``DoubleCheckInfo`` holding the legal moves and
    three groups of distractors: king steps onto attacked squares, non-king
    moves that would capture or block a single checker, and castling.
    """
    checker_sqs = list(board.checkers()) if board.is_check() else []
    if len(checker_sqs) < 2:
        return None

    side = board.turn
    king_sq = board.king(side)

    # With two checkers every legal move is a king move.
    escapes = list(board.legal_moves)

    # King steps that are pseudo-legal, not legal, and land on a square the
    # opponent currently attacks.
    bad_king_steps = []
    for dest in chess.SquareSet(chess.BB_KING_ATTACKS[king_sq]):
        occupant = board.piece_at(dest)
        if occupant and occupant.color == side:
            continue
        step = chess.Move(king_sq, dest)
        if (
            board.is_pseudo_legal(step)
            and step not in board.legal_moves
            and board.is_attacked_by(not side, dest)
        ):
            bad_king_steps.append(step)

    # Squares where a non-king piece would capture a checker or interpose
    # on a sliding checker's ray.
    target_squares = set(checker_sqs)
    for sq in checker_sqs:
        attacker = board.piece_at(sq)
        if attacker and attacker.piece_type in (chess.BISHOP, chess.ROOK, chess.QUEEN):
            target_squares.update(chess.SquareSet.between(king_sq, sq))

    def moved_by_king(mv: chess.Move) -> bool:
        origin_piece = board.piece_at(mv.from_square)
        return bool(origin_piece) and origin_piece.piece_type == chess.KING

    bad_other_moves = [
        mv for mv in board.pseudo_legal_moves
        if not moved_by_king(mv)
        and mv.to_square in target_squares
        and mv not in board.legal_moves
    ]

    return DoubleCheckInfo(
        checker_squares=checker_sqs,
        legal_king_moves=escapes,
        illegal_king_moves=bad_king_steps,
        illegal_non_king=bad_other_moves,
        illegal_castling=_find_castling_in_check(board),
    )


def build_double_check_illegals(board: chess.Board, info: DoubleCheckInfo) -> List[str]:
    """List the double-check distractors as UCI strings.

    Order: illegal king steps, then illegal non-king captures/blocks, then
    castling attempts. ``board`` is not read.
    """
    groups = (info.illegal_king_moves, info.illegal_non_king, info.illegal_castling)
    return [mv.uci() for group in groups for mv in group]

print("Double check detector loaded.")

Double check detector loaded.


In [6]:
# ── Category 4: Illegal King Moves + Illegal Castling ────────────────────────
# Not in check, but the king has adjacent squares controlled by the opponent.
# Sub-category: castling rights exist and the path is clear, but the king
# passes through or lands on an attacked square, making the castle illegal.
#
# Distractors = king moves to attacked squares + illegal castling moves.
# We require at least 1 legal king move AND at least 1 illegal king/castle move.
#
# NOTE: _CASTLE_INFO is defined in the helpers cell above.

def _find_illegal_castling(board: chess.Board) -> List[chess.Move]:
    """Collect castling moves refused because of an attacked king square.

    A castle is listed when the side to move has the right, its rook is on
    the rook square, the squares between king and rook are empty, the move
    is nevertheless not legal, and the opponent attacks at least one of the
    squares the king starts on, crosses, or lands on.
    """
    side = board.turn
    enemy = not side
    refused = []

    for spec in _CASTLE_INFO[side]:
        corner = board.piece_at(spec["rook_sq"])
        setup_ok = (
            getattr(board, spec["rights_fn"])(side)
            and corner is not None
            and corner.piece_type == chess.ROOK
            and corner.color == side
            and all(board.piece_at(sq) is None for sq in spec["clear_sqs"])
        )
        if not setup_ok:
            continue

        attempt = chess.Move(spec["king_from"], spec["king_to"])
        if attempt in board.legal_moves:
            continue
        for sq in spec["safe_sqs"]:
            if board.is_attacked_by(enemy, sq):
                refused.append(attempt)
                break

    return refused


def detect_illegal_king_moves(board: chess.Board) -> Optional[dict]:
    """Find quiet positions whose king has both legal and illegal options.

    Returns ``None`` when the side to move is in check, has no king, has no
    illegal king step or illegal castle, or has no legal king move at all
    (a legal castle counts). Otherwise returns a dict with the keys
    ``"legal_king"``, ``"illegal_king"`` and ``"illegal_castling"``.
    """
    if board.is_check():
        return None

    side = board.turn
    king_sq = board.king(side)
    if king_sq is None:
        return None

    safe_steps = []
    unsafe_steps = []
    for dest in chess.SquareSet(chess.BB_KING_ATTACKS[king_sq]):
        occupant = board.piece_at(dest)
        if occupant and occupant.color == side:
            continue
        step = chess.Move(king_sq, dest)
        if not board.is_pseudo_legal(step):
            continue
        if step in board.legal_moves:
            safe_steps.append(step)
        elif board.is_attacked_by(not side, dest):
            unsafe_steps.append(step)

    refused_castles = _find_illegal_castling(board)
    if not unsafe_steps and not refused_castles:
        return None

    # The king must also have something legal to do, or the position is not
    # an interesting contrast.
    king_can_act = bool(safe_steps) or any(
        board.is_castling(mv) for mv in board.legal_moves
    )
    if not king_can_act:
        return None

    return {
        "legal_king": safe_steps,
        "illegal_king": unsafe_steps,
        "illegal_castling": refused_castles,
    }


def build_illegal_king_illegals(board: chess.Board, info: dict) -> List[str]:
    """List the illegal-king distractors as UCI strings.

    Illegal king steps come first, then illegal castling attempts.
    ``board`` is not read.
    """
    forbidden = list(info["illegal_king"]) + list(info["illegal_castling"])
    return [mv.uci() for mv in forbidden]

print("Illegal king move + castling detector loaded.")

Illegal king move + castling detector loaded.


In [7]:
# ── Category 5: Pin Against King ─────────────────────────────────────────────
# The side to move has one of its own pieces pinned to its own king by an
# opponent's sliding piece. The pinned piece has pseudo-legal moves that
# leave the pin ray -- those are illegal because they'd expose the king.
#
# Distractors = pseudo-legal moves of the pinned piece off the pin ray.

def detect_pin(board: chess.Board) -> Optional[dict]:
    """Find pieces of the side to move that are pinned to their own king.

    A pinned piece is reported only if it has at least one pseudo-legal
    move that leaves the pin ray and is not legal. Returns ``None`` when no
    such piece exists; otherwise a dict with ``"pinned_pieces"`` (one entry
    per piece: square, piece, its illegal moves) and ``"illegal_moves"``
    (all those moves in one flat list).
    """
    side = board.turn
    pseudo_moves = list(board.pseudo_legal_moves)

    reports = []
    flat_illegal = []
    for sq, piece in board.piece_map().items():
        is_candidate = (
            piece.color == side
            and piece.piece_type != chess.KING
            and board.is_pinned(side, sq)
        )
        if not is_candidate:
            continue

        ray = board.pin(side, sq)
        off_ray = [
            mv for mv in pseudo_moves
            if mv.from_square == sq
            and mv.to_square not in ray
            and mv not in board.legal_moves
        ]
        if not off_ray:
            continue

        reports.append({
            "square": sq,
            "piece": piece,
            "illegal_moves": off_ray,
        })
        flat_illegal.extend(off_ray)

    if not reports:
        return None
    return {"pinned_pieces": reports, "illegal_moves": flat_illegal}


def build_pin_illegals(board: chess.Board, info: dict) -> List[str]:
    """List the pin-breaking distractors as UCI strings (``board`` is not read)."""
    pin_breakers = info["illegal_moves"]
    return [pin_breaker.uci() for pin_breaker in pin_breakers]

print("Pin detector loaded.")

Pin detector loaded.


In [8]:
# ── Category 6: Promotion ────────────────────────────────────────────────────
# A pawn is on the 7th rank (2nd for black) and can promote.
# Legal promotion moves have a suffix like e7e8q.
#
# Distractors:
#   - Promotion push onto an occupied square (pawn blocked)
#   - Promotion capture to an empty square (no piece to capture)
#   - Under-promotion alternatives (if the game move promotes to queen,
#     promoting to rook/bishop/knight are legal but "wrong" -- but these
#     ARE legal, so we don't list them as illegal; they go in correct)

# Piece types tried as the promotion suffix when building promotion
# distractors; the list order sets the order of the generated moves.
PROMO_PIECES = [chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT]


def detect_promotion(board: chess.Board) -> Optional[dict]:
    """Report the legal promotion moves, or ``None`` if there are none.

    The result is a dict with the single key ``"promo_moves"``.
    """
    promotions = []
    for mv in board.legal_moves:
        if mv.promotion is not None:
            promotions.append(mv)
    return {"promo_moves": promotions} if promotions else None


def build_promotion_illegals(board: chess.Board, info: dict) -> List[str]:
    """Build the promotion distractors as UCI strings.

    For each pawn of the side to move that stands one step from promotion,
    two kinds of bogus promotion are generated, once per piece in
    ``PROMO_PIECES``:

    1. a straight push onto an occupied back-rank square, and
    2. a diagonal "capture" onto an empty back-rank square.

    Anything that matches a legal move is dropped. ``info`` is not read.
    """
    legal = {mv.uci() for mv in board.legal_moves}
    side = board.turn
    # (rank the pawn stands on, rank it would promote on), 0-based.
    if side == chess.WHITE:
        pawn_rank, back_rank = 6, 7
    else:
        pawn_rank, back_rank = 1, 0

    distractors = []
    for pawn_sq in board.pieces(chess.PAWN, side):
        if chess.square_rank(pawn_sq) != pawn_rank:
            continue
        col = chess.square_file(pawn_sq)

        # Gather bogus destinations: blocked push first, then left and
        # right diagonals that hold nothing to capture.
        bogus_targets = []
        ahead = chess.square(col, back_rank)
        if board.piece_at(ahead) is not None:
            bogus_targets.append(ahead)
        for side_col in (col - 1, col + 1):
            if side_col < 0 or side_col > 7:
                continue
            diagonal = chess.square(side_col, back_rank)
            if board.piece_at(diagonal) is None:
                bogus_targets.append(diagonal)

        for dest in bogus_targets:
            for piece_type in PROMO_PIECES:
                uci = chess.Move(pawn_sq, dest, promotion=piece_type).uci()
                if uci not in legal:
                    distractors.append(uci)

    return distractors

print("Promotion detector loaded.")

Promotion detector loaded.


In [9]:
# ── General Illegal Move Generator ───────────────────────────────────────────
# These get added to EVERY position (tagged or vanilla) to pad the candidate
# list with plausible-looking but fundamentally illegal moves.
#
# Types:
#   1. Backward pawn       -- pawn moves in wrong direction
#   2. Friendly fire       -- piece "captures" own piece
#   3. Blocked sliding     -- rook/bishop/queen tries to move through a piece
#   4. Pawn double-push    -- pawn pushes two squares from non-starting rank
#   5. Pawn push onto piece -- pawn pushes forward into a square occupied
#                             by an enemy piece
#   6. Wrong geometry      -- piece moves in a way its type doesn't allow
#                             (knight diagonal, bishop straight, etc.)

# One-letter label for each piece type. Not referenced in the part of the
# file reviewed here -- check for users elsewhere before removing.
PIECE_NAMES = {chess.PAWN: "P", chess.KNIGHT: "N", chess.BISHOP: "B",
               chess.ROOK: "R", chess.QUEEN: "Q", chess.KING: "K"}


def _on_board(sq: int) -> bool:
    return 0 <= sq <= 63


def _gen_backward_pawn(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate one-square pawn retreats onto empty squares.

    Each result is a ``(move, "backward_pawn")`` pair.
    """
    # Square indices grow by 8 per rank, so "backward" is -8 for White and
    # +8 for Black.
    step_back = -8 if turn == chess.WHITE else 8
    retreats = []
    for pawn_sq in board.pieces(chess.PAWN, turn):
        behind = pawn_sq + step_back
        if not _on_board(behind) or board.piece_at(behind) is not None:
            continue
        retreats.append((chess.Move(pawn_sq, behind), "backward_pawn"))
    return retreats


def _gen_friendly_fire(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate "captures" of the mover's own pieces.

    For every non-pawn piece of ``turn``, each attacked square that holds
    another piece of ``turn`` yields a ``(move, "friendly_fire")`` pair.
    """
    hits = []
    for origin, piece in board.piece_map().items():
        if piece.color != turn:
            continue
        if piece.piece_type == chess.PAWN:
            continue
        for dest in board.attacks(origin):
            occupant = board.piece_at(dest)
            if not occupant or occupant.color != turn:
                continue
            hits.append((chess.Move(origin, dest), "friendly_fire"))
    return hits


def _gen_blocked_sliding(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate slider moves that would have to jump over a piece.

    For each rook, bishop or queen of ``turn``, the squares it could reach
    on an otherwise empty board but does not attack on the real board are
    the blocked ones. Squares holding a friendly piece are skipped (those
    belong to the friendly-fire generator). Each result is a
    ``(move, "blocked_sliding")`` pair.
    """
    slider_types = (chess.ROOK, chess.BISHOP, chess.QUEEN)
    jumps = []
    for origin, piece in board.piece_map().items():
        if piece.color != turn or piece.piece_type not in slider_types:
            continue

        # Same piece alone on a bare board gives its unobstructed reach.
        bare = chess.Board.empty()
        bare.set_piece_at(origin, piece)
        unobstructed = set(bare.attacks(origin))
        reachable = set(board.attacks(origin))

        for dest in unobstructed - reachable:
            occupant = board.piece_at(dest)
            if occupant and occupant.color == turn:
                continue
            jumps.append((chess.Move(origin, dest), "blocked_sliding"))
    return jumps


def _gen_pawn_double_push_wrong_rank(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate two-square pawn pushes from anywhere but the starting rank.

    Both the square jumped over and the landing square must be on the
    board and empty, so the move looks like a real double push. Each
    result is a ``(move, "pawn_double_wrong_rank")`` pair.
    """
    if turn == chess.WHITE:
        one_step, home_rank = 8, 1
    else:
        one_step, home_rank = -8, 6

    pushes = []
    for pawn_sq in board.pieces(chess.PAWN, turn):
        # From the home rank this would be a genuine double push.
        if chess.square_rank(pawn_sq) == home_rank:
            continue
        landing = pawn_sq + 2 * one_step
        if not (_on_board(landing) and board.piece_at(landing) is None):
            continue
        skipped = pawn_sq + one_step
        if _on_board(skipped) and board.piece_at(skipped) is None:
            pushes.append((chess.Move(pawn_sq, landing), "pawn_double_wrong_rank"))
    return pushes


def _gen_pawn_push_onto_piece(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate one-square pawn pushes into an enemy piece.

    Pawns capture diagonally, so stepping straight onto an enemy-occupied
    square is illegal. Squares holding a friendly piece are not emitted.
    Each result is a ``(move, "pawn_push_onto_piece")`` pair.
    """
    one_step = 8 if turn == chess.WHITE else -8
    bumps = []
    for pawn_sq in board.pieces(chess.PAWN, turn):
        ahead = pawn_sq + one_step
        if not _on_board(ahead):
            continue
        blocker = board.piece_at(ahead)
        if blocker and blocker.color != turn:
            bumps.append((chess.Move(pawn_sq, ahead), "pawn_push_onto_piece"))
    return bumps


def _gen_wrong_geometry(board: chess.Board, turn: chess.Color) -> List[Tuple[chess.Move, str]]:
    """Generate moves whose shape the moving piece type cannot make.

    Knights step one square diagonally, bishops slide two squares straight,
    and rooks slide two squares diagonally. A destination is kept when it
    is on the board and not occupied by a friendly piece. Each result is a
    ``(move, label)`` pair, with the label naming the piece type.
    """
    # piece type -> (label, (rank delta, file delta) offsets to try)
    bogus_shapes = {
        chess.KNIGHT: ("wrong_geometry_knight",
                       [(1, 1), (1, -1), (-1, 1), (-1, -1)]),
        chess.BISHOP: ("wrong_geometry_bishop",
                       [(2, 0), (-2, 0), (0, 2), (0, -2)]),
        chess.ROOK: ("wrong_geometry_rook",
                     [(2, 2), (2, -2), (-2, 2), (-2, -2)]),
    }

    found = []
    for origin, piece in board.piece_map().items():
        if piece.color != turn or piece.piece_type not in bogus_shapes:
            continue
        label, offsets = bogus_shapes[piece.piece_type]
        rank0 = chess.square_rank(origin)
        file0 = chess.square_file(origin)

        for d_rank, d_file in offsets:
            new_rank = rank0 + d_rank
            new_file = file0 + d_file
            if not (0 <= new_rank <= 7 and 0 <= new_file <= 7):
                continue
            dest = chess.square(new_file, new_rank)
            occupant = board.piece_at(dest)
            if occupant is None or occupant.color != turn:
                found.append((chess.Move(origin, dest), label))

    return found


def generate_general_distractors(
    board: chess.Board,
    legal_ucis: Set[str],
    rng: random.Random,
    num_target: int = 5,
) -> List[Tuple[str, str]]:
    """Pick up to ``num_target`` varied illegal moves for the side to move.

    Candidates come from all six general generators. Anything whose UCI is
    in ``legal_ucis`` or was already seen is dropped. Selection takes one
    random move per distractor type (types visited in random order) and
    then tops up from a shuffle of whatever is left.

    Returns a list of ``(uci, type)`` pairs; it is shorter than
    ``num_target`` when too few candidates exist.
    """
    side = board.turn
    generators = (
        _gen_backward_pawn,
        _gen_friendly_fire,
        _gen_blocked_sliding,
        _gen_pawn_double_push_wrong_rank,
        _gen_pawn_push_onto_piece,
        _gen_wrong_geometry,
    )

    # Group the surviving UCIs by type, keeping first-seen order.
    pools = {}
    taken = set(legal_ucis)
    for generate in generators:
        for move, kind in generate(board, side):
            uci = move.uci()
            if uci in taken:
                continue
            taken.add(uci)
            pools.setdefault(kind, []).append(uci)

    kinds = list(pools)
    rng.shuffle(kinds)

    # First pass: at most one move from each type, for variety.
    picked = []
    for kind in kinds:
        if len(picked) >= num_target:
            break
        pool = pools[kind]
        uci = rng.choice(pool)
        pool.remove(uci)
        picked.append((uci, kind))

    # Second pass: fill the remaining slots from everything left over.
    leftovers = [(uci, kind) for kind, pool in pools.items() for uci in pool]
    rng.shuffle(leftovers)
    still_needed = max(0, num_target - len(picked))
    picked.extend(leftovers[:still_needed])

    return picked


# Quick test
# Smoke check of the general distractor generator on the position after
# 1.e4 e5 2.Nf3 (Black to move).
# NOTE: this passes the shared module-level `rng`, and the generator calls
# rng.shuffle / rng.choice, so running this cell advances the random stream
# that any later use of `rng` will see.
b = chess.Board()  # starting position
b.push_san("e4")
b.push_san("e5")
b.push_san("Nf3")
distractors = generate_general_distractors(b, set(m.uci() for m in b.legal_moves), rng, num_target=8)
print(f"Test: generated {len(distractors)} distractors from position after 1.e4 e5 2.Nf3:")
for uci, dtype in distractors:
    print(f"  {uci:6s}  ({dtype})")

Test: generated 8 distractors from position after 1.e4 e5 2.Nf3:
  a8b8    (friendly_fire)
  e5e6    (backward_pawn)
  f8f6    (wrong_geometry_bishop)
  a8c6    (wrong_geometry_rook)
  d8a5    (blocked_sliding)
  e5e4    (pawn_push_onto_piece)
  h8h6    (blocked_sliding)
  h8h1    (blocked_sliding)


In [10]:
# ── Main extraction loop ─────────────────────────────────────────────────────

def make_row(
    board: chess.Board,
    last_move: chess.Move,
    game_move: chess.Move,
    tags: List[str],
    legal_uci: List[str],
    cat_illegal_uci: List[str],
    gen_illegal_uci: List[str],
    game,
    ply: int,
    extra: Optional[dict] = None,
) -> dict:
    """Build one output record.

    Args:
        board: Position the record describes (before ``game_move``).
        last_move: Move that led to ``board``.
        game_move: Move actually played from ``board`` in the game.
        tags: Category tags assigned to the position.
        legal_uci: All legal moves as UCI strings (the correct outputs).
        cat_illegal_uci: Category-specific illegal moves.
        gen_illegal_uci: General illegal distractors.
        game: Source game; only ``game.headers.get("Site", "")`` is read.
        ply: Ply index stored in the record as-is.
        extra: Optional additional fields merged into the record last, so
            they override a built-in key of the same name.

    Returns:
        A dict holding the position, the candidate list, the three move
        groups, and their counts. ``num_candidates`` counts the
        de-duplicated union, so it can be smaller than the sum of the
        three group counts.
    """
    # Sort the de-duplicated union. Iterating a bare set of strings follows
    # hash order, which Python randomises per process, so the candidate
    # order would differ between runs despite the fixed SEED. Sorted UCI
    # order is reproducible and does not group legal moves ahead of
    # illegal ones.
    candidates = sorted(set(legal_uci + cat_illegal_uci + gen_illegal_uci))
    row = {
        "fen": board.fen(),
        "last_move_uci": last_move.uci(),
        "game_move_uci": game_move.uci(),
        "next_move_candidates_uci": candidates,
        "correct_outputs_uci": legal_uci,
        "illegal_category_uci": cat_illegal_uci,
        "illegal_general_uci": gen_illegal_uci,
        "tags": tags,
        "phase": get_phase(board),
        "game_id": game.headers.get("Site", ""),
        "ply": ply,
        "num_candidates": len(candidates),
        "num_correct": len(legal_uci),
        "num_illegal_category": len(cat_illegal_uci),
        "num_illegal_general": len(gen_illegal_uci),
    }
    if extra:
        row.update(extra)
    return row


def extract_all(pgn_path: str, max_games: int) -> List[dict]:
    """Scan games and extract tagged + vanilla positions with distractors.

    Every position that has a preceding move (i.e. everything except the
    initial position of each game) is run through the category detectors.
    A position that earns at least one tag becomes a row immediately, with
    its category illegals plus general distractors. Untagged positions are
    pooled and up to NUM_VANILLA_POSITIONS of them are sampled at the end
    using the module-level ``rng``.

    Args:
        pgn_path: PGN file handed to ``iter_games``.
        max_games: Game limit handed to ``iter_games``.

    Returns:
        Tagged rows in scan order followed by the sampled "vanilla" rows.
    """
    rows = []
    tag_counts = defaultdict(int)

    # For vanilla sampling: collect candidate (fen, last_move, game_move, game_id, ply)
    # from positions with no special tag, then sample NUM_VANILLA_POSITIONS at the end.
    vanilla_candidates = []

    # Tracked separately from the loop variable: when no game is yielded the
    # loop variable is never bound and the final summary used to crash.
    games_seen = 0

    for gi, game in enumerate(iter_games(pgn_path, max_games)):
        games_seen = gi + 1
        board = game.board()
        last_move = None
        ply = 0

        for game_move in game.mainline_moves():
            # Skip the initial position: it has no last move to record.
            if last_move is not None:
                tags = []
                cat_illegal: List[str] = []
                extra = {}

                # ── Cat 1: En passant ──
                ep_info = detect_en_passant(board)
                if ep_info:
                    tags.append("en_passant")
                    cat_illegal += build_en_passant_illegals(board, ep_info)
                    extra["ep_moves_uci"] = [m.uci() for m in ep_info["ep_moves"]]

                # ── Cat 3: Double check (before single check) ──
                dc_info = detect_double_check(board)
                if dc_info:
                    tags.append("double_check")
                    cat_illegal += build_double_check_illegals(board, dc_info)
                    extra["num_checkers"] = len(dc_info.checker_squares)
                    extra["checker_squares"] = [chess.square_name(s) for s in dc_info.checker_squares]
                    extra["checker_pieces"] = [board.piece_at(s).symbol() for s in dc_info.checker_squares]
                    extra["num_legal_king_moves"] = len(dc_info.legal_king_moves)

                # ── Cat 2: Single check (only if NOT double check) ──
                elif board.is_check():
                    check_info = analyze_check(board)
                    # Only tagged when at least two kinds of evasion exist.
                    if check_info and check_info.evasion_types >= 2:
                        tags.append("check")
                        cat_illegal += build_check_candidates(board, check_info)
                        extra["check_king_moves"] = len(check_info.king_moves)
                        extra["check_captures"] = len(check_info.captures)
                        extra["check_blocks"] = len(check_info.blocks)
                        extra["check_illegal_king"] = len(check_info.illegal_king_moves)
                        extra["check_illegal_castling"] = len(check_info.illegal_castling)

                # ── Cat 4: Illegal king moves + castling (not in check) ──
                if not board.is_check():
                    ik_info = detect_illegal_king_moves(board)
                    if ik_info:
                        tags.append("illegal_king")
                        cat_illegal += build_illegal_king_illegals(board, ik_info)
                        extra["num_legal_king_moves"] = len(ik_info["legal_king"])
                        extra["num_illegal_king_moves"] = len(ik_info["illegal_king"])
                        extra["num_illegal_castling"] = len(ik_info["illegal_castling"])
                        if ik_info["illegal_castling"]:
                            extra["illegal_castling_uci"] = [
                                m.uci() for m in ik_info["illegal_castling"]
                            ]

                # ── Cat 5: Pin ──
                pin_info = detect_pin(board)
                if pin_info:
                    tags.append("pin")
                    cat_illegal += build_pin_illegals(board, pin_info)
                    extra["num_pinned_pieces"] = len(pin_info["pinned_pieces"])
                    extra["num_pin_illegal_moves"] = len(pin_info["illegal_moves"])
                    extra["pinned_details"] = [
                        {
                            "square": chess.square_name(p["square"]),
                            "piece": p["piece"].symbol(),
                            "num_illegal": len(p["illegal_moves"]),
                        }
                        for p in pin_info["pinned_pieces"]
                    ]

                # ── Cat 6: Promotion ──
                promo_info = detect_promotion(board)
                if promo_info:
                    tags.append("promotion")
                    cat_illegal += build_promotion_illegals(board, promo_info)
                    extra["promo_moves_uci"] = [m.uci() for m in promo_info["promo_moves"]]

                legal_uci = [m.uci() for m in board.legal_moves]

                if tags:
                    # Deduplicate category illegals & remove any that are actually legal.
                    # Sorted so the stored order does not depend on set iteration
                    # order, which varies between interpreter runs for strings.
                    legal_set = set(legal_uci)
                    cat_illegal = sorted(set(cat_illegal) - legal_set)

                    # General distractors (must not repeat a legal or category move)
                    existing = legal_set | set(cat_illegal)
                    gen_distractors = generate_general_distractors(
                        board, legal_ucis=existing, rng=rng,
                        num_target=NUM_GENERAL_DISTRACTORS,
                    )
                    gen_illegal = [u for u, _ in gen_distractors]
                    extra["general_distractor_types"] = [t for _, t in gen_distractors]

                    row = make_row(
                        board, last_move, game_move, tags,
                        legal_uci, cat_illegal, gen_illegal,
                        game, ply, extra,
                    )
                    rows.append(row)
                    for t in tags:
                        tag_counts[t] += 1
                else:
                    # No special tag → vanilla candidate. Stored as plain
                    # strings/ints because `board` keeps mutating.
                    # NOTE: a single-check position that fails the
                    # evasion_types filter and earns no other tag lands here too.
                    vanilla_candidates.append(
                        (board.fen(), last_move.uci(), game_move.uci(),
                         game.headers.get("Site", ""), ply)
                    )

            board.push(game_move)
            last_move = game_move
            ply += 1

        if games_seen % 10 == 0:
            print(f"  Processed {games_seen} games, {len(rows)} tagged positions so far...")

    # ── Vanilla positions ──
    rng.shuffle(vanilla_candidates)
    num_vanilla = min(NUM_VANILLA_POSITIONS, len(vanilla_candidates))
    print(f"\nSampling {num_vanilla} vanilla positions from {len(vanilla_candidates)} candidates...")

    for fen, lm_uci, gm_uci, game_id, v_ply in vanilla_candidates[:num_vanilla]:
        vboard = chess.Board(fen)
        legal_uci = [m.uci() for m in vboard.legal_moves]
        gen_distractors = generate_general_distractors(
            vboard, legal_ucis=set(legal_uci), rng=rng,
            num_target=NUM_GENERAL_DISTRACTORS,
        )
        gen_illegal = [u for u, _ in gen_distractors]
        gen_types = [t for _, t in gen_distractors]

        # Sorted for a run-independent order (see the note on cat_illegal above).
        candidates = sorted(set(legal_uci + gen_illegal))
        row = {
            "fen": fen,
            "last_move_uci": lm_uci,
            "game_move_uci": gm_uci,
            "next_move_candidates_uci": candidates,
            "correct_outputs_uci": legal_uci,
            "illegal_category_uci": [],
            "illegal_general_uci": gen_illegal,
            "tags": ["vanilla"],
            "phase": get_phase(vboard),
            "game_id": game_id,
            "ply": v_ply,
            "num_candidates": len(candidates),
            "num_correct": len(legal_uci),
            "num_illegal_category": 0,
            "num_illegal_general": len(gen_illegal),
            "general_distractor_types": gen_types,
        }
        rows.append(row)
        tag_counts["vanilla"] += 1

    print(f"\nDone: {games_seen} games, {len(rows)} total positions")
    print(f"Tag counts: {dict(tag_counts)}")
    return rows

print("Extraction function ready.")

Extraction function ready.


In [11]:
# ── Run extraction ───────────────────────────────────────────────────────────
# Runs the full scan with the config knobs from the first cell. extract_all is
# annotated to take a str path, hence the str() around the Path. The resulting
# `rows` list is what the cells below summarize, inspect, save and verify.
rows = extract_all(str(PGN_PATH), MAX_GAMES)

  Processed 10 games, 164 tagged positions so far...
  Processed 20 games, 450 tagged positions so far...
  Processed 30 games, 593 tagged positions so far...
  Processed 40 games, 795 tagged positions so far...
  Processed 50 games, 1076 tagged positions so far...

Sampling 100 vanilla positions from 2208 candidates...

Done: 50 games, 1176 total positions
Tag counts: {'illegal_king': 844, 'pin': 138, 'check': 114, 'en_passant': 6, 'promotion': 20, 'double_check': 3, 'vanilla': 100}


In [12]:
# ── Summary statistics ───────────────────────────────────────────────────────
# Tag / phase distributions and candidate-count statistics over `rows`.
from collections import Counter

# A row can carry several tags, so tag counts may sum to more than len(rows).
tag_counter = Counter(t for r in rows for t in r["tags"])
phase_counter = Counter(r["phase"] for r in rows)
# Rows without a "num_illegal_castling" field count as having none.
castling_count = sum(1 for r in rows if r.get("num_illegal_castling", 0) > 0)

print("=== Tag distribution ===")
for tag, cnt in tag_counter.most_common():
    print(f"  {tag:20s}: {cnt}")

print("\n=== Phase distribution ===")
for phase, cnt in phase_counter.most_common():
    print(f"  {phase:20s}: {cnt}")

print("\n=== Candidate stats ===")
num_cands = [r["num_candidates"] for r in rows]
num_cat = [r["num_illegal_category"] for r in rows]
num_gen = [r["num_illegal_general"] for r in rows]
if rows:
    print(f"  Avg candidates:            {sum(num_cands)/len(num_cands):.1f}")
    print(f"  Avg category illegals:     {sum(num_cat)/len(num_cat):.1f}")
    print(f"  Avg general illegals:      {sum(num_gen)/len(num_gen):.1f}")
    print(f"  Max category illegals:     {max(num_cat)}")
    print(f"  Max general illegals:      {max(num_gen)}")
else:
    # Averages and maxima are undefined for an empty extraction; without this
    # guard the cell stops on ZeroDivisionError / ValueError.
    print("  (no positions extracted)")
print(f"  Positions w/ cat illegals: {sum(1 for x in num_cat if x > 0)} / {len(rows)}")
print(f"  Positions w/ illegal castling: {castling_count}")

=== Tag distribution ===
  illegal_king        : 844
  pin                 : 138
  check               : 114
  vanilla             : 100
  promotion           : 20
  en_passant          : 6
  double_check        : 3

=== Phase distribution ===
  middlegame          : 508
  endgame             : 504
  opening             : 164

=== Candidate stats ===
  Avg candidates:            32.4
  Avg category illegals:     2.3
  Avg general illegals:      4.9
  Max category illegals:     14
  Max general illegals:      5
  Positions w/ cat illegals: 1071 / 1176
  Positions w/ illegal castling: 28


In [13]:
# ── Inspect examples per category ────────────────────────────────────────────

def show_example(row, idx=None):
    """Print a plain-text summary of one extracted record.

    The header (tags, phase, ply), FEN, last/game move, candidate counts and
    the sorted legal moves are always printed. The illegal-move lists and the
    distractor types are printed only when non-empty, and the per-category
    details only for tags present on the row. Output ends with a blank line.

    Args:
        row: One record dict as produced by the extraction.
        idx: Optional index shown in brackets in front of the header line.
    """
    tags = row["tags"]
    prefix = "" if idx is None else f"[{idx}] "

    # Fixed part of the report
    print(f"{prefix}Tags: {row['tags']}  |  Phase: {row['phase']}  |  Ply: {row['ply']}")
    print(f"  FEN:        {row['fen']}")
    print(f"  Last move:  {row['last_move_uci']}")
    print(f"  Game move:  {row['game_move_uci']}")
    print(
        f"  Candidates: {row['num_candidates']}  (legal={row['num_correct']}, "
        f"cat_illegal={row['num_illegal_category']}, "
        f"gen_illegal={row['num_illegal_general']})"
    )
    print(f"  Legal moves: {sorted(row['correct_outputs_uci'])}")

    # Optional move lists (skipped when missing or empty)
    category_moves = row.get("illegal_category_uci", [])
    if category_moves:
        print(f"  Category illegals: {sorted(category_moves)}")

    general_moves = row.get("illegal_general_uci", [])
    if general_moves:
        print(f"  General illegals: {sorted(general_moves)}")

    distractor_types = row.get("general_distractor_types", [])
    if distractor_types:
        print(f"  General distractor types: {distractor_types}")

    # Tag-dependent details
    if "en_passant" in tags:
        print(f"  EP moves: {row.get('ep_moves_uci', [])}")

    if "check" in tags:
        print(
            f"  Check evasions: king_moves={row.get('check_king_moves')}, "
            f"captures={row.get('check_captures')}, blocks={row.get('check_blocks')}, "
            f"illegal_king={row.get('check_illegal_king')}, "
            f"illegal_castling={row.get('check_illegal_castling', 0)}"
        )

    if "double_check" in tags:
        print(
            f"  Double check: checkers={row.get('checker_squares')} "
            f"({row.get('checker_pieces')}), "
            f"legal_king={row.get('num_legal_king_moves')}"
        )

    # The castling line appears only when at least one illegal castling was recorded
    if "illegal_king" in tags and row.get("num_illegal_castling", 0):
        print(f"  Illegal castling: {row.get('illegal_castling_uci', [])}")

    if "pin" in tags:
        print(f"  Pinned: {row.get('pinned_details', [])}")

    if "promotion" in tags:
        print(f"  Promotion moves: {row.get('promo_moves_uci', [])}")

    print()

# For every category tag: a banner with the number of matching rows, then the
# text report for at most the first two of them.
rule = "=" * 60
for tag in ("en_passant", "check", "double_check", "illegal_king", "pin", "promotion", "vanilla"):
    examples = [row for row in rows if tag in row["tags"]]
    print(rule)
    print(f"  {tag.upper()} -- {len(examples)} positions")
    print(rule)
    for sample in examples[:2]:
        show_example(sample)

  EN_PASSANT -- 6 positions
Tags: ['en_passant', 'illegal_king']  |  Phase: endgame  |  Ply: 69
  FEN:        8/p3r2k/5R2/6p1/1PpP2Pp/2Pb3P/P4P2/2r1B1K1 b - g3 0 35
  Last move:  g2g4
  Game move:  e7e1
  Candidates: 41  (legal=31, cat_illegal=5, gen_illegal=5)
  Legal moves: ['a7a5', 'a7a6', 'c1a1', 'c1b1', 'c1c2', 'c1c3', 'c1d1', 'c1e1', 'd3b1', 'd3c2', 'd3e2', 'd3e4', 'd3f1', 'd3f5', 'd3g6', 'e7b7', 'e7c7', 'e7d7', 'e7e1', 'e7e2', 'e7e3', 'e7e4', 'e7e5', 'e7e6', 'e7e8', 'e7f7', 'e7g7', 'h4g3', 'h7g7', 'h7g8', 'h7h8']
  Category illegals: ['a7b6', 'c4b3', 'g5f4', 'h7g6', 'h7h6']
  General illegals: ['c1g1', 'd3d5', 'e7h7', 'g5g6', 'h4h3']
  General distractor types: ['friendly_fire', 'pawn_push_onto_piece', 'blocked_sliding', 'backward_pawn', 'wrong_geometry_bishop']
  EP moves: ['h4g3']

Tags: ['en_passant']  |  Phase: opening  |  Ply: 16
  FEN:        rnbq1rk1/ppp1b1pp/4p3/3pPp2/3P2Q1/2NB4/PPP2PPP/R3K1NR w KQ f6 0 9
  Last move:  f7f5
  Game move:  g4g3
  Candidates: 63  (legal=48,

In [14]:
# ── Visualize a few boards (SVG) ─────────────────────────────────────────────
import chess.svg
from IPython.display import SVG, display, HTML

def show_board(row, title=""):
    """Render one record as an SVG board with move arrows and a caption.

    Draws the last move plus up to four category illegals and up to three
    general illegals as arrows, then displays the board under a caption with
    the title, tags, phase and the number of distinct illegal moves per group.

    Args:
        row: One record dict as produced by the extraction.
        title: Bold heading shown above the board.
    """
    board = chess.Board(row["fen"])
    last = chess.Move.from_uci(row["last_move_uci"])

    # De-duplicate while keeping the record's own order. The previous
    # set-then-slice picked arbitrary moves, so the arrows drawn for the same
    # record could change from one interpreter run to the next.
    cat_ill = list(dict.fromkeys(row.get("illegal_category_uci", [])))
    gen_ill = list(dict.fromkeys(row.get("illegal_general_uci", [])))

    # Arrows: yellow=last move, red=category illegals, orange=general illegals
    arrows = [chess.svg.Arrow(last.from_square, last.to_square, color="#888800")]
    for u in cat_ill[:4]:
        m = chess.Move.from_uci(u)
        arrows.append(chess.svg.Arrow(m.from_square, m.to_square, color="#cc0000"))
    for u in gen_ill[:3]:
        m = chess.Move.from_uci(u)
        arrows.append(chess.svg.Arrow(m.from_square, m.to_square, color="#cc8800"))

    svg = chess.svg.board(board, arrows=arrows, size=350)
    # The counts are over distinct moves, not over the arrows actually drawn.
    label = (f"<b>{title}</b><br/>Tags: {row['tags']} | Phase: {row['phase']}<br/>"
             f"Cat illegals (red): {len(cat_ill)} | Gen illegals (orange): {len(gen_ill)}")
    display(HTML(f"<div style='display:inline-block; margin:10px'>{label}<br/>{svg}</div>"))

# Render the first matching record of each category; tags with no match are skipped
for tag in ("en_passant", "check", "double_check", "illegal_king", "pin", "promotion", "vanilla"):
    examples = [row for row in rows if tag in row["tags"]]
    if not examples:
        continue
    show_board(examples[0], title=tag.upper())

In [15]:
# ── Save to JSONL ────────────────────────────────────────────────────────────
# One JSON object per line, in the order of `rows`.
with OUT_PATH.open("w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

print(f"Saved {len(rows)} positions to {OUT_PATH}")
# Lists the keys of the first record only; records with other tags may carry
# additional fields. Guarded so an empty extraction prints [] instead of
# raising IndexError on rows[0].
print(f"Fields per record: {sorted(rows[0].keys()) if rows else []}")

Saved 1176 positions to data/eval_positions_preview.jsonl
Fields per record: ['correct_outputs_uci', 'fen', 'game_id', 'game_move_uci', 'general_distractor_types', 'illegal_category_uci', 'illegal_general_uci', 'last_move_uci', 'next_move_candidates_uci', 'num_candidates', 'num_correct', 'num_illegal_castling', 'num_illegal_category', 'num_illegal_general', 'num_illegal_king_moves', 'num_legal_king_moves', 'phase', 'ply', 'tags']


## Sanity checks

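Verify the extracted data by recomputing the legal moves from each stored FEN:
1. `correct_outputs_uci` is exactly the set of legal moves in the FEN
2. All illegal distractors (category and general) are actually illegal
3. En passant moves are among the legal moves
4. Pinned piece moves in distractors are truly illegal (covered by the category-illegal check in item 2; there is no separate pin-specific check)
5. Category and general illegals do not overlap
6. `next_move_candidates_uci` equals legal moves + category illegals + general illegals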

In [16]:
# ── Sanity checks ────────────────────────────────────────────────────────────
# Re-derive the legal moves from every stored FEN and record each inconsistency
# as a human-readable message in `errors`.

errors = []
for i, row in enumerate(rows):
    board = chess.Board(row["fen"])
    legal_ucis = {mv.uci() for mv in board.legal_moves}
    correct_list = row["correct_outputs_uci"]
    correct_set = set(correct_list)
    cat_list = row.get("illegal_category_uci", [])
    gen_list = row.get("illegal_general_uci", [])
    cat_set = set(cat_list)
    gen_set = set(gen_list)

    # Check 1: all correct outputs are legal
    errors.extend(
        f"Row {i}: correct move {m} is NOT legal in {row['fen']}"
        for m in correct_list
        if m not in legal_ucis
    )

    # Check 2: correct = legal (should be the same set); both differences are
    # empty when the sets agree, so nothing is reported in that case
    missing = legal_ucis - correct_set
    unexpected = correct_set - legal_ucis
    if missing:
        errors.append(f"Row {i}: legal moves missing from correct: {missing}")
    if unexpected:
        errors.append(f"Row {i}: non-legal moves in correct: {unexpected}")

    # Check 3: all category illegals are actually illegal
    errors.extend(
        f"Row {i}: category illegal {m} is actually LEGAL"
        for m in cat_list
        if m in legal_ucis
    )

    # Check 4: all general illegals are actually illegal
    errors.extend(
        f"Row {i}: general illegal {m} is actually LEGAL"
        for m in gen_list
        if m in legal_ucis
    )

    # Check 5: no overlap between category and general illegals
    overlap = cat_set & gen_set
    if overlap:
        errors.append(f"Row {i}: overlap between cat and gen illegals: {overlap}")

    # Check 6: candidates = legal + cat_illegal + gen_illegal
    expected = correct_set | cat_set | gen_set
    actual = set(row["next_move_candidates_uci"])
    if expected != actual:
        errors.append(f"Row {i}: candidates mismatch (expected {len(expected)}, got {len(actual)})")

    # Check 7: en passant specific -- every recorded ep move must be legal
    if "en_passant" in row["tags"]:
        errors.extend(
            f"Row {i}: ep move {ep_uci} not legal"
            for ep_uci in row.get("ep_moves_uci", [])
            if ep_uci not in legal_ucis
        )

if not errors:
    print(f"All {len(rows)} positions passed sanity checks!")
else:
    # Only the first 20 messages are shown; the header gives the full count
    print(f"ERRORS ({len(errors)}):")
    for e in errors[:20]:
        print(f"  {e}")

All 1176 positions passed sanity checks!


In [17]:
# ── Tag co-occurrence ────────────────────────────────────────────────────────
# Count how many rows carry each distinct combination of tags; a combination
# is keyed by its tags in sorted order joined with " + ".
all_tags = ["en_passant", "check", "double_check", "illegal_king", "pin", "promotion", "vanilla"]
cooc = defaultdict(int)
for r in rows:
    unique_tags = sorted(set(r["tags"]))
    if len(unique_tags) > 1:
        combo_key = " + ".join(unique_tags)
    else:
        combo_key = unique_tags[0]
    cooc[combo_key] += 1

print("Tag combinations:")
# Most frequent combination first
ranked = sorted(cooc.items(), key=lambda item: -item[1])
for combo, cnt in ranked:
    print(f"  {combo:45s}: {cnt}")

Tag combinations:
  illegal_king                                 : 800
  check                                        : 110
  pin                                          : 103
  vanilla                                      : 100
  illegal_king + pin                           : 31
  illegal_king + promotion                     : 12
  promotion                                    : 7
  en_passant                                   : 4
  check + pin                                  : 3
  double_check                                 : 3
  en_passant + illegal_king                    : 1
  check + promotion                            : 1
  en_passant + pin                             : 1


In [18]:
# ── Example JSON record ──────────────────────────────────────────────────────
import pprint

# Use the first row that has both category and general illegals; fall back to
# the very first row when no such row exists.
ex = rows[0]
for candidate_row in rows:
    if candidate_row["num_illegal_category"] > 0 and candidate_row["num_illegal_general"] > 0:
        ex = candidate_row
        break

# Long move lists are replaced by a size placeholder (or truncated) for display
clean = {
    "fen": ex["fen"],
    "last_move_uci": ex["last_move_uci"],
    "next_move_candidates_uci": f"[{ex['num_candidates']} items]",
    "correct_outputs_uci": f"[{ex['num_correct']} items]",
    "illegal_category_uci": ex["illegal_category_uci"][:5],
    "illegal_general_uci": ex["illegal_general_uci"],
}
# Remaining fields are copied through unchanged
clean.update({
    field_name: ex[field_name]
    for field_name in (
        "tags",
        "phase",
        "num_candidates",
        "num_correct",
        "num_illegal_category",
        "num_illegal_general",
    )
})
print("Example record (lists truncated for display):")
pprint.pprint(clean)

Example record (lists truncated for display):
{'correct_outputs_uci': '[34 items]',
 'fen': 'rn2kb1r/pbpp1p1p/1p2p2p/6q1/3PP3/P1N5/1PP1BPPP/R2QK1NR w KQkq - 2 7',
 'illegal_category_uci': ['e1d2'],
 'illegal_general_uci': ['d1f3', 'd4d3', 'a1a3', 'd4d6', 'c3b4'],
 'last_move_uci': 'd8g5',
 'next_move_candidates_uci': '[40 items]',
 'num_candidates': 40,
 'num_correct': 34,
 'num_illegal_category': 1,
 'num_illegal_general': 5,
 'phase': 'opening',
 'tags': ['illegal_king']}


## Interactive Browser

`browse(rows)` -- step through examples with Enter.  
Commands: `Enter`=next, `p`=prev, `q`=quit, number=jump, `t TAG`=filter by tag.

Arrow colors: **yellow**=last move, **red**=category illegals, **orange**=general illegals.

In [19]:
import chess.svg
from IPython.display import SVG, display, HTML, clear_output


def show_full(row, idx, total):
    """Clear the cell output, then show one record as board SVG plus text.

    Args:
        row: Record dict to display.
        idx: Position of the record in the browsed list (shown in the header).
        total: Length of the browsed list; the header shows ``total - 1`` as
            the last index.
    """
    clear_output(wait=True)
    board = chess.Board(row["fen"])
    last = chess.Move.from_uci(row["last_move_uci"])

    # (moves to draw, arrow color): red = first five category illegals,
    # orange = first three general illegals
    arrow_groups = (
        (row.get("illegal_category_uci", [])[:5], "#cc0000"),
        (row.get("illegal_general_uci", [])[:3], "#cc8800"),
    )

    # The last move is always drawn first, in yellow
    arrows = [chess.svg.Arrow(last.from_square, last.to_square, color="#888800")]
    for ucis, color in arrow_groups:
        for u in ucis:
            mv = chess.Move.from_uci(u)
            arrows.append(chess.svg.Arrow(mv.from_square, mv.to_square, color=color))

    display(HTML(chess.svg.board(board, arrows=arrows, size=400)))

    print(f"── Example {idx}/{total-1} ──")
    show_example(row, idx)

    # Per-type tally of the general distractors, most frequent first
    gen_types = row.get("general_distractor_types", [])
    if not gen_types:
        return
    from collections import Counter
    print("  General distractor breakdown:")
    for dtype, cnt in Counter(gen_types).most_common():
        print(f"    {dtype}: {cnt}")


def browse(data, tag_filter=None):
    """Interactive browser for extracted positions.

    Args:
        data: List of record dicts to step through.
        tag_filter: If given, only records whose "tags" contain it are shown.

    Commands:
        Enter  = next example
        p      = previous example
        q      = quit
        NUMBER = jump to that index
        t TAG  = re-filter by tag (e.g. 't pin')
    """
    # Keep the unfiltered input so that 't TAG' can re-filter from it.
    all_data = data
    if tag_filter:
        data = [r for r in all_data if tag_filter in r["tags"]]
        print(f"Filtered to {len(data)} examples with tag '{tag_filter}'")

    if not data:
        print("No examples to show.")
        return

    idx = 0
    while True:
        show_full(data[idx], idx, len(data))
        try:
            cmd = input(f"\n[{idx}/{len(data)-1}] Enter=next  p=prev  q=quit  NUMBER=jump  t TAG=filter: ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input / interrupted: leave the browser quietly
            break

        if cmd == "q":
            break
        elif cmd == "p":
            idx = max(0, idx - 1)
        elif cmd == "":
            idx = min(len(data) - 1, idx + 1)
            if idx == len(data) - 1:
                print("(last example)")
        elif cmd.startswith("t "):
            new_tag = cmd[2:].strip()
            # Re-filter the list this call was given. This used to reach for
            # the notebook-global `rows`, which silently switched data sets
            # when browse() was called on any other list.
            browse(all_data, tag_filter=new_tag)
            return
        elif cmd.isdecimal():
            # isdecimal() rather than isdigit(): the latter also accepts
            # characters such as '²' that int() rejects with ValueError.
            idx = min(int(cmd), len(data) - 1)
        else:
            print(f"Unknown command: {cmd}")

print("Browser ready. Run: browse(rows)  or  browse(rows, tag_filter='pin')")

Browser ready. Run: browse(rows)  or  browse(rows, tag_filter='pin')


In [20]:
# Run this cell to start browsing (uncomment one):
# browse(rows)                            # all examples
# browse(rows, tag_filter="en_passant")   # en passant
# browse(rows, tag_filter="check")        # single check evasion
# browse(rows, tag_filter="double_check") # double check
# browse(rows, tag_filter="illegal_king") # king to attacked sq / illegal castling
# browse(rows, tag_filter="pin")          # pin
# browse(rows, tag_filter="promotion")    # promotion
# browse(rows, tag_filter="vanilla")      # vanilla (no special tag)