In [1]:
import chess
import numpy as np
import torch

def moves_to_mate(fen, solution_moves):
    """Replay a mating line and return how many moves the mating side makes.

    ``solution_moves`` is played from ``fen`` in order. If every move parses,
    every move is legal, and the final position is checkmate, the result is
    the number of moves made by the side that delivered mate (the "N" in
    "mate in N"). The count is the same whether the line starts with the
    mating side's move or with an opponent's setup move first.

    Args:
        fen: starting position in Forsyth-Edwards Notation.
        solution_moves: iterable of UCI move strings (e.g. ``["e2e4", ...]``)
            or a single space-separated string of UCI moves.

    Returns:
        Number of moves made by the mating side, or 0 if any move is
        malformed or illegal, or the line does not end in checkmate.
    """
    board = chess.Board(fen)
    if isinstance(solution_moves, str):
        # Accept "e2e4 e7e5 ..." as well as a list; iterating a str directly
        # would yield single characters.
        solution_moves = solution_moves.split()

    played = 0
    for uci in solution_moves:
        try:
            move = chess.Move.from_uci(uci)
        except ValueError:
            # Malformed UCI text (python-chess's InvalidMoveError subclasses
            # ValueError): treat like an illegal move rather than crash.
            return 0
        if move not in board.legal_moves:
            return 0  # illegal move, not a valid mate sequence
        board.push(move)
        played += 1

    if not board.is_checkmate():
        return 0
    # Sides alternate and the mating side made the last move, so it made
    # every other move counting back from the end: ceil(played / 2).
    # (Counting only Black's moves, as before, undercounted by one whenever
    # White both moved first and delivered mate.)
    return (played + 1) // 2

In [2]:
# Utility: convert FEN to tensor with positional features
def _square_to_row_col(sq):
    """Return the (row, col) plane index for python-chess square ``sq``.

    Row 0 is rank 8 and column 0 is the a-file, so each plane reads like a
    board diagram seen from White's side.
    """
    return 7 - chess.square_rank(sq), chess.square_file(sq)


def fen_to_tensor(fen):
    """Encode a FEN position as a float32 tensor of shape ``(20, 8, 8)``.

    Every plane uses the layout of ``_square_to_row_col`` (row 0 = rank 8,
    column 0 = a-file).

    Channels:
        0-5   white pawn, knight, bishop, rook, queen, king (one-hot)
        6-11  black pawn, knight, bishop, rook, queen, king (one-hot)
        12    side to move (all 1.0 when White is to move, else all 0.0)
        13    squares attacked by White
        14    squares attacked by Black
        15    destination squares of the side to move's legal moves
        16    king-step (Chebyshev) distance to the white king; all 0 if absent
        17    king-step (Chebyshev) distance to the black king; all 0 if absent
        18    1.0 everywhere if the side to move is in check, else 0.0
        19    pieces of either colour pinned to their own king

    Distance planes hold raw integer distances (0-7), not 0/1 values.

    Args:
        fen: position in Forsyth-Edwards Notation.

    Returns:
        ``torch.Tensor`` of dtype float32 and shape ``(20, 8, 8)``.
    """
    board = chess.Board(fen)

    # Base piece planes: 12 channels. python-chess piece types run
    # PAWN=1 .. KING=6, which is exactly the P, N, B, R, Q, K channel order.
    planes = np.zeros((12, 8, 8), dtype=np.float32)
    for sq, piece in board.piece_map().items():
        color_offset = 0 if piece.color == chess.WHITE else 6
        r, c = _square_to_row_col(sq)
        planes[piece.piece_type - 1 + color_offset, r, c] = 1

    # Side-to-move plane: chess.WHITE is True, so float(turn) is 1.0 for White.
    stm_plane = np.full((1, 8, 8), float(board.turn), dtype=np.float32)

    # Per-square features, filled in a single pass over the board.
    attack_w = np.zeros((8, 8), dtype=np.float32)
    attack_b = np.zeros((8, 8), dtype=np.float32)
    dist_wk = np.zeros((8, 8), dtype=np.float32)
    dist_bk = np.zeros((8, 8), dtype=np.float32)
    pinned = np.zeros((8, 8), dtype=np.float32)
    wksq = board.king(chess.WHITE)
    bksq = board.king(chess.BLACK)
    for sq in chess.SQUARES:
        r, c = _square_to_row_col(sq)
        if board.attackers(chess.WHITE, sq):
            attack_w[r, c] = 1
        if board.attackers(chess.BLACK, sq):
            attack_b[r, c] = 1
        # board.king() returns None when that king is missing (e.g. in
        # composed positions); its distance plane then stays all zeros.
        if wksq is not None:
            dist_wk[r, c] = chess.square_distance(sq, wksq)
        if bksq is not None:
            dist_bk[r, c] = chess.square_distance(sq, bksq)
        piece = board.piece_at(sq)
        if piece is not None and board.is_pinned(piece.color, sq):
            pinned[r, c] = 1

    # Legal-move destination mask for the side to move.
    legal_mask = np.zeros((8, 8), dtype=np.float32)
    for mv in board.legal_moves:
        r, c = _square_to_row_col(mv.to_square)
        legal_mask[r, c] = 1

    # Check status as a constant plane.
    check_pl = np.full((8, 8), float(board.is_check()), dtype=np.float32)

    # 12 piece + 1 side-to-move + 7 feature planes = 20 channels.
    feature_planes = np.stack(
        [attack_w, attack_b, legal_mask, dist_wk, dist_bk, check_pl, pinned],
        axis=0,
    )  # shape (7, 8, 8)
    all_planes = np.concatenate([planes, stm_plane, feature_planes], axis=0)

    return torch.from_numpy(all_planes)
