diff --git a/chess/binary_fen.py b/chess/binary_fen.py new file mode 100644 index 000000000..385b78a2a --- /dev/null +++ b/chess/binary_fen.py @@ -0,0 +1,544 @@ +from __future__ import annotations + +# Almost all code based on: //github.com/lichess-org/scalachess/blob/8c94e2087f83affb9718fd2be19c34866c9a1a22/core/src/main/scala/format/BinaryFen.scala +import typing + +import chess +import chess.variant + +from enum import IntEnum, unique +from typing import Tuple, Optional, List, Union, Iterator, Literal, cast +from dataclasses import dataclass, field +from itertools import zip_longest + +if typing.TYPE_CHECKING: + from typing_extensions import Self, assert_never + +@unique +class ChessHeader(IntEnum): + STANDARD = 0 + CHESS_960 = 2 + FROM_POSITION = 3 + + @classmethod + def from_int_opt(cls, value: int) -> Optional[Self]: + """Convert an integer to a ChessHeader enum member, or return None if invalid.""" + try: + return cls(value) + except ValueError: + return None + +@unique +class VariantHeader(IntEnum): + # chess/std + STANDARD = 0 + CHESS_960 = 2 + FROM_POSITION = 3 + + CRAZYHOUSE = 1 + KING_OF_THE_HILL = 4 + THREE_CHECK = 5 + ANTICHESS = 6 + ATOMIC = 7 + HORDE = 8 + RACING_KINGS = 9 + + def board(self) -> chess.Board: + if self == VariantHeader.CRAZYHOUSE: + return chess.variant.CrazyhouseBoard.empty() + elif self == VariantHeader.KING_OF_THE_HILL: + return chess.variant.KingOfTheHillBoard.empty() + elif self == VariantHeader.THREE_CHECK: + return chess.variant.ThreeCheckBoard.empty() + elif self == VariantHeader.ANTICHESS: + return chess.variant.AntichessBoard.empty() + elif self == VariantHeader.ATOMIC: + return chess.variant.AtomicBoard.empty() + elif self == VariantHeader.HORDE: + return chess.variant.HordeBoard.empty() + elif self == VariantHeader.RACING_KINGS: + return chess.variant.RacingKingsBoard.empty() + # mypy... pyright can use `self in (...)` + elif self == VariantHeader.STANDARD or self == VariantHeader.CHESS_960 or self == VariantHeader.FROM_POSITION: + return chess.Board.empty(chess960=True) + else: + assert_never(self) + + @classmethod + def encode(cls, board: chess.Board) -> VariantHeader: + uci_variant = type(board).uci_variant + if uci_variant == "chess": + # TODO check if this auto mode is OK + root = board.root() + if root in CHESS_960_STARTING_POSITIONS: + return cls.CHESS_960 + elif root == STANDARD_STARTING_POSITION: + return cls.STANDARD + else: + return cls.FROM_POSITION + elif uci_variant == "crazyhouse": + return cls.CRAZYHOUSE + elif uci_variant == "kingofthehill": + return cls.KING_OF_THE_HILL + elif uci_variant == "3check": + return cls.THREE_CHECK + elif uci_variant == "antichess": + return cls.ANTICHESS + elif uci_variant == "atomic": + return cls.ATOMIC + elif uci_variant == "horde": + return cls.HORDE + elif uci_variant == "racingkings": + return cls.RACING_KINGS + else: + raise ValueError(f"Unsupported variant: {uci_variant}") + +CHESS_960_STARTING_POSITIONS = [chess.Board.from_chess960_pos(i) for i in range(960)] +STANDARD_STARTING_POSITION = chess.Board() + +Nibble = Literal[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] + +# TODO FIXME actually implement __eq__ for variant.pocket? +# not using `chess.variant.CrazyhousePocket` because its __eq__ is wrong for our case +# only using `chess.variant.CrazyhousePocket` public API for now +@dataclass(frozen=True) +class CrazyhousePiecePocket: + pawns: Nibble + knights: Nibble + bishops: Nibble + rooks: Nibble + queens: Nibble + + @classmethod + def from_crazyhouse_pocket(cls, pocket: chess.variant.CrazyhousePocket) -> Self: + return cls( + pawns=_to_nibble(pocket.count(chess.PAWN)), + knights=_to_nibble(pocket.count(chess.KNIGHT)), + bishops=_to_nibble(pocket.count(chess.BISHOP)), + rooks=_to_nibble(pocket.count(chess.ROOK)), + queens=_to_nibble(pocket.count(chess.QUEEN)) + ) + + def to_crazyhouse_pocket(self) -> chess.variant.CrazyhousePocket: + return chess.variant.CrazyhousePocket( + "p"*self.pawns + + "n"*self.knights + + "b"*self.bishops + + "r"*self.rooks + + "q"*self.queens + ) + +@dataclass(frozen=True) +class ThreeCheckData: + white_received_checks: Nibble + black_received_checks: Nibble + +@dataclass(frozen=True) +class CrazyhouseData: + white_pocket: CrazyhousePiecePocket + black_pocket: CrazyhousePiecePocket + promoted: chess.Bitboard + +@dataclass(frozen=True) +class BinaryFen: + """ + A simple binary format that encode a position in a compact way, initially used by Stockfish and Lichess + + See https://lichess.org/@/revoof/blog/adapting-nnue-pytorchs-binary-position-format-for-lichess/cpeeAMeY for more information + """ + occupied: chess.Bitboard + nibbles: List[Nibble] + halfmove_clock: Optional[int] + plies: Optional[int] + variant_header: int + variant_data: Optional[Union[ThreeCheckData, CrazyhouseData]] + + def _halfmove_clock_or_zero(self) -> int: + return self.halfmove_clock if self.halfmove_clock is not None else 0 + def _plies_or_zero(self) -> int: + return self.plies if self.plies is not None else 0 + + def to_canonical(self) -> Self: + """ + Multiple binary FEN can correspond to the same position: + + - When a position has multiple black kings with black to move + - When trailing zeros are omitted from halfmove clock or plies + - When its black to move and the ply is even + + The 'canonical' position is then the one with every king with the turn set + And trailing zeros removed + + Return the canonical version of the binary FEN + """ + is_black_to_move = (15 in self.nibbles) or (self._plies_or_zero() % 2 == 1) + + if is_black_to_move: + canon_nibbles: List[Nibble] = [(15 if nibble == 11 else nibble) for nibble in self.nibbles] + else: + canon_nibbles = self.nibbles.copy() + + is_black_to_move_due_to_plies = (15 not in canon_nibbles) and (self._plies_or_zero() % 2 == 1) + + canon_plies = (self._plies_or_zero() + 1) if (is_black_to_move and self._plies_or_zero() % 2 == 0) else self.plies + canon_halfmove_clock = self.halfmove_clock + + if self.variant_header == 0 and not is_black_to_move_due_to_plies: + # with black to move, ply == 1 add no information, it's the same as ply == None which + # equivalent to ply == 0 + if self._plies_or_zero() <= 1: + canon_plies = None + if self._halfmove_clock_or_zero() == 0: + canon_halfmove_clock = None + + return self.__class__(occupied=self.occupied, + nibbles=canon_nibbles, + halfmove_clock=canon_halfmove_clock, + plies=canon_plies, + variant_header=self.variant_header, + variant_data=self.variant_data + ) + + @classmethod + def parse_from_bytes(cls, data: bytes) -> Self: + """ + Read from bytes and return a BinaryFen + + should not error even if data is invalid + """ + reader = iter(data) + return cls.parse_from_iter(reader) + + @classmethod + def parse_from_iter(cls, reader: Iterator[int]) -> Self: + """ + Read from bytes and return a `BinaryFen` + + should not error even if data is invalid + """ + occupied = _read_bitboard(reader) + + nibbles = [] + iter_occupied = chess.scan_forward(occupied) + for (sq1, sq2) in zip_longest(iter_occupied, iter_occupied): + lo, hi = _read_nibbles(reader) + nibbles.append(lo) + if sq2 is not None: + nibbles.append(hi) + + halfmove_clock = _read_leb128(reader) + plies = _read_leb128(reader) + + variant_header = _next0(reader) + + variant_data: Optional[Union[ThreeCheckData, CrazyhouseData]] = None + if variant_header == VariantHeader.THREE_CHECK: + lo, hi = _read_nibbles(reader) + variant_data = ThreeCheckData(white_received_checks=lo, black_received_checks=hi) + elif variant_header == VariantHeader.CRAZYHOUSE: + wp, bp = _read_nibbles(reader) + wn, bn = _read_nibbles(reader) + wb, bb = _read_nibbles(reader) + wr, br = _read_nibbles(reader) + wq, bq = _read_nibbles(reader) + # optimise? + white_pocket = CrazyhousePiecePocket(pawns=wp, knights=wn, bishops=wb, rooks=wr, queens=wq) + black_pocket = CrazyhousePiecePocket(pawns=bp, knights=bn, bishops=bb, rooks=br, queens=bq) + promoted = _read_bitboard(reader) + variant_data = CrazyhouseData(white_pocket=white_pocket, black_pocket=black_pocket, promoted=promoted) + return cls(occupied=occupied, + nibbles=nibbles, + halfmove_clock=halfmove_clock, + plies=plies, + variant_header=variant_header, + variant_data=variant_data) + + + def to_board(self) -> Tuple[chess.Board, Optional[ChessHeader]]: + """ + Return a chess.Board of the proper variant, and std_mode if applicable + + The returned board might be illegal, check with `board.is_valid()` + + Raise `ValueError` if the BinaryFen data is invalid in a way that chess.Board cannot handle: + - Invalid variant header + - Invalid en passant square + - Multiple en passant squares + """ + std_mode: Optional[ChessHeader] = ChessHeader.from_int_opt(self.variant_header) + + board = VariantHeader(self.variant_header).board() + ep_square_set = False + for sq, nibble in zip(chess.scan_forward(self.occupied), self.nibbles): + if not ep_square_set: + ep_square_set = _unpack_piece(board, sq, nibble) + else: + if _unpack_piece(board, sq, nibble): + raise ValueError("At least two passant squares found") + board.halfmove_clock = self._halfmove_clock_or_zero() + board.fullmove_number = self._plies_or_zero()//2 + 1 + # it is important to write it that way + # because default turn can have been already set to black inside `_unpack_piece` + if self._plies_or_zero() % 2 == 1: + board.turn = chess.BLACK + + if isinstance(board, chess.variant.ThreeCheckBoard) and isinstance(self.variant_data, ThreeCheckData): + # remaining check are for the opposite side + board.remaining_checks[chess.WHITE] = 3 - self.variant_data.black_received_checks + board.remaining_checks[chess.BLACK] = 3 - self.variant_data.white_received_checks + elif isinstance(board, chess.variant.CrazyhouseBoard) and isinstance(self.variant_data, CrazyhouseData): + board.pockets[chess.WHITE] = self.variant_data.white_pocket.to_crazyhouse_pocket() + board.pockets[chess.BLACK] = self.variant_data.black_pocket.to_crazyhouse_pocket() + board.promoted = self.variant_data.promoted + return (board, std_mode) + + + @classmethod + def decode(cls, data: bytes) -> Tuple[chess.Board, Optional[ChessHeader]]: + """ + Read from bytes and return a chess.Board of the proper variant + + If it is standard chess position, also return the mode (standard, chess960, from_position) + + raise `ValueError` if data is invalid + """ + binary_fen = cls.parse_from_bytes(data) + return binary_fen.to_board() + + + @classmethod + def parse_from_board(cls, board: chess.Board, std_mode: Optional[ChessHeader]=None) -> Self: + """ + Given a chess.Board, return its binary FEN representation, and std_mode if applicable + + If the board is a standard chess position, `std_mode` can be provided to specify the mode (standard, chess960, from_position) + if not provided, it will be inferred from the root position + """ + if std_mode is not None and type(board).uci_variant != "chess": + raise ValueError("std_mode can only be provided for standard chess positions") + occupied = board.occupied + iter_occupied = chess.scan_forward(occupied) + nibbles: List[Nibble] = [] + for (sq1, sq2) in zip_longest(iter_occupied, iter_occupied): + lo = _pack_piece(board, sq1) + nibbles.append(lo) + if sq2 is not None: + hi = _pack_piece(board, sq2) + nibbles.append(hi) + + plies = board.ply() + binary_ply = None + binary_halfmove_clock = None + + broken_turn = board.king(chess.BLACK) is None and board.turn == chess.BLACK + variant_header = std_mode.value if std_mode is not None else VariantHeader.encode(board).value + + if board.halfmove_clock > 0 or plies > 1 or broken_turn or variant_header != 0: + binary_halfmove_clock = board.halfmove_clock + + if plies > 1 or broken_turn or variant_header != 0: + binary_ply = plies + + variant_data: Optional[Union[ThreeCheckData, CrazyhouseData]] = None + if variant_header != VariantHeader.STANDARD: + if isinstance(board, chess.variant.ThreeCheckBoard): + black_received_checks = _to_nibble(3 - board.remaining_checks[chess.WHITE]) + white_received_checks = _to_nibble(3 - board.remaining_checks[chess.BLACK]) + variant_data = ThreeCheckData(white_received_checks=white_received_checks, black_received_checks=black_received_checks) + elif isinstance(board, chess.variant.CrazyhouseBoard): + variant_data = CrazyhouseData( + white_pocket=CrazyhousePiecePocket.from_crazyhouse_pocket(board.pockets[chess.WHITE]), + black_pocket=CrazyhousePiecePocket.from_crazyhouse_pocket(board.pockets[chess.BLACK]), + promoted=board.promoted + ) + return cls(occupied=occupied, + nibbles=nibbles, + halfmove_clock=binary_halfmove_clock, + plies=binary_ply, + variant_header=variant_header, + variant_data=variant_data) + + + def to_bytes(self) -> bytes: + """ + Write the BinaryFen data as bytes + """ + builder = bytearray() + _write_bitboard(builder, self.occupied) + iter_nibbles: Iterator[Nibble] = iter(self.nibbles) + for (lo, hi) in zip_longest(iter_nibbles, iter_nibbles,fillvalue=0): + _write_nibbles(builder, cast(Nibble, lo), cast(Nibble, hi)) + + + if self.halfmove_clock is not None: + _write_leb128(builder, self.halfmove_clock) + + if self.plies is not None: + _write_leb128(builder, self.plies) + + if self.variant_header != VariantHeader.STANDARD: + builder.append(self.variant_header) + if isinstance(self.variant_data, ThreeCheckData): + _write_nibbles(builder, self.variant_data.white_received_checks, self.variant_data.black_received_checks) + elif isinstance(self.variant_data, CrazyhouseData): + _write_nibbles(builder, self.variant_data.white_pocket.pawns, self.variant_data.black_pocket.pawns) + _write_nibbles(builder, self.variant_data.white_pocket.knights, self.variant_data.black_pocket.knights) + _write_nibbles(builder, self.variant_data.white_pocket.bishops, self.variant_data.black_pocket.bishops) + _write_nibbles(builder, self.variant_data.white_pocket.rooks, self.variant_data.black_pocket.rooks) + _write_nibbles(builder, self.variant_data.white_pocket.queens, self.variant_data.black_pocket.queens) + + if self.variant_data.promoted: + _write_bitboard(builder, self.variant_data.promoted) + return bytes(builder) + + def __bytes__(self) -> bytes: + """ + Write the BinaryFen data as bytes + + Example: bytes(my_binary_fen) + """ + return self.to_bytes() + + + @classmethod + def encode(cls, board: chess.Board, std_mode: Optional[ChessHeader]=None) -> bytes: + """ + Given a chess.Board, return its binary FEN representation, and std_mode if applicable + + If the board is a standard chess position, `std_mode` can be provided to specify the mode (standard, chess960, from_position) + if not provided, it will be inferred from the root position + """ + binary_fen = cls.parse_from_board(board, std_mode) + return binary_fen.to_bytes() + +def _pack_piece(board: chess.Board, sq: chess.Square) -> Nibble: + # Encoding from + # https://github.com/official-stockfish/nnue-pytorch/blob/2db3787d2e36f7142ea4d0e307b502dda4095cd9/lib/nnue_training_data_formats.h#L4607 + piece = board.piece_at(sq) + if piece is None: + raise ValueError(f"Unreachable: no piece at square {sq}, board: {board}") + if piece.piece_type == chess.PAWN: + if board.ep_square is not None: + rank = chess.square_rank(sq) + if (board.ep_square + 8 == sq and piece.color == chess.WHITE and rank == 3) or (board.ep_square - 8 == sq and piece.color == chess.BLACK and rank == 4): + return 12 + return 0 if piece.color == chess.WHITE else 1 + elif piece.piece_type == chess.KNIGHT: + return 2 if piece.color == chess.WHITE else 3 + elif piece.piece_type == chess.BISHOP: + return 4 if piece.color == chess.WHITE else 5 + elif piece.piece_type == chess.ROOK: + if board.castling_rights & chess.BB_SQUARES[sq]: + return 13 if piece.color == chess.WHITE else 14 + return 6 if piece.color == chess.WHITE else 7 + elif piece.piece_type == chess.QUEEN: + return 8 if piece.color == chess.WHITE else 9 + elif piece.piece_type == chess.KING: + if piece.color == chess.BLACK and board.turn == chess.BLACK: + return 15 + return 10 if piece.color == chess.WHITE else 11 + raise ValueError(f"Unreachable: unknown piece {piece} at square {sq}, board: {board}") + +def _unpack_piece(board: chess.Board, sq: chess.Square, nibble: Nibble) -> bool: + """Return true if set the en passant square""" + if nibble == 0: + board.set_piece_at(sq, chess.Piece(chess.PAWN, chess.WHITE)) + elif nibble == 1: + board.set_piece_at(sq, chess.Piece(chess.PAWN, chess.BLACK)) + elif nibble == 2: + board.set_piece_at(sq, chess.Piece(chess.KNIGHT, chess.WHITE)) + elif nibble == 3: + board.set_piece_at(sq, chess.Piece(chess.KNIGHT, chess.BLACK)) + elif nibble == 4: + board.set_piece_at(sq, chess.Piece(chess.BISHOP, chess.WHITE)) + elif nibble == 5: + board.set_piece_at(sq, chess.Piece(chess.BISHOP, chess.BLACK)) + elif nibble == 6: + board.set_piece_at(sq, chess.Piece(chess.ROOK, chess.WHITE)) + elif nibble == 7: + board.set_piece_at(sq, chess.Piece(chess.ROOK, chess.BLACK)) + elif nibble == 8: + board.set_piece_at(sq, chess.Piece(chess.QUEEN, chess.WHITE)) + elif nibble == 9: + board.set_piece_at(sq, chess.Piece(chess.QUEEN, chess.BLACK)) + elif nibble == 10: + board.set_piece_at(sq, chess.Piece(chess.KING, chess.WHITE)) + elif nibble == 11: + board.set_piece_at(sq, chess.Piece(chess.KING, chess.BLACK)) + elif nibble == 12: + # in scalachess rank starts at 1, python-chess 0 + rank = chess.square_rank(sq) + if rank == 3: + color = chess.WHITE + elif rank == 4: + color = chess.BLACK + else: + raise ValueError(f"Pawn at square {chess.square_name(sq)} cannot be an en passant pawn") + board.ep_square = sq - 8 if color else sq + 8 + board.set_piece_at(sq, chess.Piece(chess.PAWN, color)) + return True + elif nibble == 13: + board.castling_rights |= chess.BB_SQUARES[sq] + board.set_piece_at(sq, chess.Piece(chess.ROOK, chess.WHITE)) + elif nibble == 14: + board.castling_rights |= chess.BB_SQUARES[sq] + board.set_piece_at(sq, chess.Piece(chess.ROOK, chess.BLACK)) + elif nibble == 15: + board.turn = chess.BLACK + board.set_piece_at(sq, chess.Piece(chess.KING, chess.BLACK)) + else: + raise ValueError(f"Impossible nibble value: {nibble} at square {chess.square_name(sq)}") + return False + +def _next0(reader: Iterator[int]) -> int: + return next(reader, 0) + +def _read_bitboard(reader: Iterator[int]) -> chess.Bitboard: + bb = chess.BB_EMPTY + for _ in range(8): + bb = (bb << 8) | (_next0(reader) & 0xFF) + return bb + +def _write_bitboard(data: bytearray, bb: chess.Bitboard) -> None: + for shift in range(56, -1, -8): + data.append((bb >> shift) & 0xFF) + +def _read_nibbles(reader: Iterator[int]) -> Tuple[Nibble, Nibble]: + byte = _next0(reader) + return cast(Nibble, byte & 0x0F), cast(Nibble, (byte >> 4) & 0x0F) + +def _write_nibbles(data: bytearray, lo: Nibble, hi: Nibble) -> None: + data.append((hi << 4) | (lo & 0x0F)) + +def _read_leb128(reader: Iterator[int]) -> Optional[int]: + result = 0 + shift = 0 + while True: + byte = next(reader, None) + if byte is None: + return None + result |= (byte & 127) << shift + if (byte & 128) == 0: + break + shift += 7 + # this is useless + return result & 0x7fff_ffff + +def _write_leb128(data: bytearray, value: int) -> None: + while True: + byte = value & 127 + value >>= 7 + if value != 0: + byte |= 128 + data.append(byte) + if value == 0: + break + +def _to_nibble(value: int) -> Nibble: + if 0 <= value <= 15: + return cast(Nibble, value) + else: + raise ValueError(f"Value {value} cannot be represented as a nibble") + + + diff --git a/fuzz/binary_fen.py b/fuzz/binary_fen.py new file mode 100644 index 000000000..76f33e974 --- /dev/null +++ b/fuzz/binary_fen.py @@ -0,0 +1,27 @@ +import chess.binary_fen +from chess.binary_fen import BinaryFen + +from pythonfuzz.main import PythonFuzz + + +@PythonFuzz +def fuzz(buf): + binary_fen = BinaryFen.parse_from_bytes(buf) + try: + board, std_mode = binary_fen.to_board() + except ValueError: + pass + else: + board.status() + list(board.legal_moves) + binary_fen2 = BinaryFen.parse_from_board(board,std_mode=std_mode) + encoded = binary_fen2.to_bytes() + board2, std_mode2 = BinaryFen.decode(encoded) + assert board == board2 + assert binary_fen2 == binary_fen2.to_canonical(), "from_board should be canonical" + assert binary_fen.to_canonical() == binary_fen2.to_canonical() + assert std_mode == std_mode2 + + +if __name__ == "__main__": + fuzz() diff --git a/fuzz/corpus/binary_fen/000-std b/fuzz/corpus/binary_fen/000-std new file mode 100644 index 000000000..80bb1b63e --- /dev/null +++ b/fuzz/corpus/binary_fen/000-std @@ -0,0 +1 @@ +0000000000000000 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/001-std b/fuzz/corpus/binary_fen/001-std new file mode 100644 index 000000000..c1a10a1f5 --- /dev/null +++ b/fuzz/corpus/binary_fen/001-std @@ -0,0 +1 @@ +00000000000000000001 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/002-std b/fuzz/corpus/binary_fen/002-std new file mode 100644 index 000000000..3695a1ca3 --- /dev/null +++ b/fuzz/corpus/binary_fen/002-std @@ -0,0 +1 @@ +000000000000000064df06 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/003-std b/fuzz/corpus/binary_fen/003-std new file mode 100644 index 000000000..d7626d88f --- /dev/null +++ b/fuzz/corpus/binary_fen/003-std @@ -0,0 +1 @@ +ffff00001000efff2d844ad200000000111111113e955fe3 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/004-std b/fuzz/corpus/binary_fen/004-std new file mode 100644 index 000000000..bf3f7c55b --- /dev/null +++ b/fuzz/corpus/binary_fen/004-std @@ -0,0 +1 @@ +20400006400000080ac0b1 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/005-std b/fuzz/corpus/binary_fen/005-std new file mode 100644 index 000000000..3ff911238 --- /dev/null +++ b/fuzz/corpus/binary_fen/005-std @@ -0,0 +1 @@ +10000000180040802ac10f \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/006-std b/fuzz/corpus/binary_fen/006-std new file mode 100644 index 000000000..f42d53eb4 --- /dev/null +++ b/fuzz/corpus/binary_fen/006-std @@ -0,0 +1 @@ +00000002180000308a1c0f030103 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/007-koth b/fuzz/corpus/binary_fen/007-koth new file mode 100644 index 000000000..426bcdf0c --- /dev/null +++ b/fuzz/corpus/binary_fen/007-koth @@ -0,0 +1 @@ +ffff00000000ffff2d844ad200000000111111113e955be3000004 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/008-3c b/fuzz/corpus/binary_fen/008-3c new file mode 100644 index 000000000..94fb5dd91 --- /dev/null +++ b/fuzz/corpus/binary_fen/008-3c @@ -0,0 +1 @@ +ffff00000000ffff2d844ad200000000111111113e955be363000501 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/009-anti b/fuzz/corpus/binary_fen/009-anti new file mode 100644 index 000000000..de63122f3 --- /dev/null +++ b/fuzz/corpus/binary_fen/009-anti @@ -0,0 +1 @@ +00800000000008001a000106 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/010-atom b/fuzz/corpus/binary_fen/010-atom new file mode 100644 index 000000000..ccf03c375 --- /dev/null +++ b/fuzz/corpus/binary_fen/010-atom @@ -0,0 +1 @@ +ffff00000000ffff2d844ad200000000111111113e955be3020407 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/011-hord b/fuzz/corpus/binary_fen/011-hord new file mode 100644 index 000000000..77d7ab3ce --- /dev/null +++ b/fuzz/corpus/binary_fen/011-hord @@ -0,0 +1 @@ +ffff0066ffffffff000000000000000000000000000000000000111111113e955be3000008 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/012-rk b/fuzz/corpus/binary_fen/012-rk new file mode 100644 index 000000000..269d2f1d5 --- /dev/null +++ b/fuzz/corpus/binary_fen/012-rk @@ -0,0 +1 @@ +000000000000ffff793542867b3542a6000009 \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/013-zh b/fuzz/corpus/binary_fen/013-zh new file mode 100644 index 000000000..9dc886f46 --- /dev/null +++ b/fuzz/corpus/binary_fen/013-zh @@ -0,0 +1 @@ +ffff00000000ffff2d844ad200000000111111113e955be300e407010000000000ef0000000000002a \ No newline at end of file diff --git a/fuzz/corpus/binary_fen/014-z b/fuzz/corpus/binary_fen/014-z new file mode 100644 index 000000000..1f7f7f896 --- /dev/null +++ b/fuzz/corpus/binary_fen/014-z @@ -0,0 +1 @@ +ffff00000000ffff2d844ad200000000111111113e955be30000010000000000 \ No newline at end of file diff --git a/test_binary_fen.py b/test_binary_fen.py new file mode 100644 index 000000000..5d582109c --- /dev/null +++ b/test_binary_fen.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 + +# Almost all tests adapted from https://github.com/lichess-org/scalachess/blob/8c94e2087f83affb9718fd2be19c34866c9a1a22/test-kit/src/test/scala/format/BinaryFenTest.scala#L1 + +import asyncio +import copy +import logging +import os +import os.path +import platform +import sys +import tempfile +import textwrap +import unittest +import io + +import chess +import chess.variant +import chess.binary_fen + +from dataclasses import asdict + +from chess import Board +from chess.binary_fen import BinaryFen, ChessHeader, VariantHeader + +KOTH = chess.variant.KingOfTheHillBoard +THREE_CHECKS = chess.variant.ThreeCheckBoard +ANTI = chess.variant.AntichessBoard +ATOMIC = chess.variant.AtomicBoard +HORDE = chess.variant.HordeBoard +RK = chess.variant.RacingKingsBoard +ZH = chess.variant.CrazyhouseBoard + +class BinaryFenTestCase(unittest.TestCase): + + def test_nibble_roundtrip(self): + for lo in range(16): + for hi in range(16): + data = bytearray() + chess.binary_fen._write_nibbles(data, lo, hi) + read_lo, read_hi = chess.binary_fen._read_nibbles(iter(data)) + self.assertEqual(lo, read_lo) + self.assertEqual(hi, read_hi) + + def test_std_mode_eq(self): + self.assertEqual(ChessHeader.STANDARD,ChessHeader.from_int_opt(0)) + + def test_bitboard_roundtrip(self): + test_bitboards = [ + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x1234567890ABCDEF, + 0x0F0F0F0F0F0F0F0F, + 0xF0F0F0F0F0F0F0F0, + 0x8000000000000001, + 0x7FFFFFFFFFFFFFFE, + ] + for bb in test_bitboards: + data = bytearray() + chess.binary_fen._write_bitboard(data, bb) + read_bb = chess.binary_fen._read_bitboard(iter(data)) + self.assertEqual(bb, read_bb) + + def test_leb128_roundtrip(self): + test_values = [ + 0, + 1, + 3, + 127, + 128, + 255, + 16384, + 2097151, + 268435455, + 2147483647, + ] + for value in test_values: + data = bytearray() + chess.binary_fen._write_leb128(data, value) + read_value = chess.binary_fen._read_leb128(iter(data)) + self.assertEqual(value, read_value) + + def test_to_canonical_1(self): + # illegal position, but it should not matter + canon = BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [15, 15, 15], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None, + ) + cases = [BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [11, 15, 11], + halfmove_clock=3, + plies=4, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [15, 15, 11], + halfmove_clock=3, + plies=4, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [11, 15, 15], + halfmove_clock=3, + plies=4, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + ] + for case in cases: + with self.subTest(case=case): + self.assertNotEqual(canon, case) + canon_case = case.to_canonical() + self.assertEqual(canon, canon_case) + + def test_to_canonical_2(self): + # illegal position, but it should not matter + canon = BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [15, 15, 15], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None, + ) + cases = [BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [11, 15, 11], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [15, 15, 11], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [11, 15, 15], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + BinaryFen( + occupied=chess.BB_A1 | chess.BB_B1 | chess.BB_C1, + nibbles = [11, 11, 11], + halfmove_clock=3, + plies=5, + variant_header=ChessHeader.STANDARD.value, + variant_data=None + ), + ] + for case in cases: + with self.subTest(case=case): + self.assertNotEqual(canon, case) + canon_case = case.to_canonical() + self.assertEqual(canon, canon_case) + + def test_binary_fen_roundtrip(self): + cases = [ + Board(fen="8/8/8/8/8/8/8/8 w - - 0 1"), + Board(fen="8/8/8/8/8/8/8/8 b - - 0 1"), + Board(fen="8/8/8/8/8/8/8/7k b - - 0 1"), + Board(fen="8/8/8/8/8/8/8/8 w - - 0 2"), + Board(fen="8/8/8/8/8/8/8/8 b - - 0 2"), + Board(fen="8/8/8/8/8/8/8/8 b - - 100 432"), + Board(fen="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1"), + Board(fen="4nrk1/1pp3pp/p4p2/4P3/2BB1n2/8/PP3P1P/2K3R1 b - - 1 25"), + Board(fen="4nrk1/1pp3pp/p4p2/4P3/2BB1n2/8/PP3P1P/2K3R1 b - - 1 25"), + Board(fen="5k2/6p1/8/1Pp5/6P1/8/8/3K4 w - c6 0 1"), + Board(fen="4k3/8/8/8/3pP3/8/6N1/7K b - e3 0 1"), + Board(fen="r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1"), + Board(fen="r1k1r2q/p1ppp1pp/8/8/8/8/P1PPP1PP/R1K1R2Q w KQkq - 0 1",chess960=True), + Board(fen="r1k2r1q/p1ppp1pp/8/8/8/8/P1PPP1PP/R1K2R1Q w KQkq - 0 1", chess960=True), + Board(fen="8/8/8/4B2b/6nN/8/5P2/2R1K2k w Q - 1 1", chess960=True), + Board(fen="2r5/8/8/8/8/8/6PP/k2KR3 w K - 0 2", chess960=True), + Board(fen="4r3/3k4/8/8/8/8/6PP/qR1K1R2 w KQ - 2 1", chess960=True), + Board(fen="4rrk1/pbbp2p1/1ppnp3/3n1pqp/3N1PQP/1PPNP3/PBBP2P1/4RRK1 w Ff - 0 3", chess960=True), + Board(fen="8/8/8/1k6/3Pp3/8/8/4KQ2 b - d3 3 1"), + Board(fen="r2r3k/p7/3p4/8/8/P6P/8/R3K2R b KQq - 0 4",chess960=True), + Board(fen="rn2k1r1/ppp1pp1p/3p2p1/5bn1/P7/2N2B2/1PPPPP2/2BNK1RR w Gkq - 4 11", chess960=True), + Board(fen="8/8/8/8/8/8/2Rk4/1K6 w - - 0 1"), + + HORDE(fen="rn1qkb1r/3bn1p1/2p3P1/pPP2P2/P1PPP1P1/P1PP1PPP/PPPPPPPP/PPPPPPPP w kq a6 0 12"), + + ANTI("8/2p1p2p/2Q1N2B/8/p7/N7/PPP1P1PP/R4B1R b - - 0 13"), + ANTI("8/p6p/4p3/1P4P1/Pp4p1/3P4/7P/8 b - a3 0 1"), + ANTI("8/p6p/4p3/1P4P1/1p4pP/3P4/P7/8 b - h3 0 1"), + ANTI("8/7p/4p3/pP4P1/1p1P2p1/8/P6P/8 w - a6 0 2"), + ANTI("8/p7/4p3/1P4Pp/1p1P2p1/8/P6P/8 w - h6 0 2"), + + ATOMIC(fen="rnbq3r/ppp1p1pp/5p2/3p4/8/8/PPPPPPPP/RNBQKB1R b KQ - 0 4"), + ATOMIC(fen="8/6pp/2p2p1n/3p4/4P3/B6P/3P1PP1/1r2K2R b K - 0 17"), + + RK(fen="8/8/8/8/8/8/krbnNBRK/qrbnNBRQ w - - 0 1"), + RK(fen="8/8/8/8/8/6K1/krbnNBR1/qrbnNBRQ b - - 1 1"), + + KOTH(fen="rnbq1bnr/ppp2ppp/3k4/4p2Q/3PK3/8/PPP2PPP/RNB2BNR b - - 0 7"), + + THREE_CHECKS(fen="1r3rk1/pbp1N1pp/3p1q2/1p2bp2/7P/2PBB1P1/PP3Q1R/R5K1 b - - 3 21 +2+1"), + + ZH(fen="1r3Q1n/p1kp3p/1p2ppq1/2p2b2/8/3P2P1/PPP1PPBP/R4RK1/NRpnnbb w - - 2 28"), + ZH(fen="b2nkbnQ~/p1pppp1p/pP1q2p1/r7/8/R5PR/P1PP1P1P/1NBQ1BNK/R w - - 1 2"), + ZH(fen="8/8/8/8/8/8/8/8/ w - - 0 1"), + ZH(fen="r~n~b~q~kb~n~r~/pppppppp/8/8/8/8/PPPPPPPP/RN~BQ~KB~NR/ w KQkq - 0 1"), + ] + for case in cases: + case_fen = case.fen() + with self.subTest(fen=case_fen): + bin_fen = BinaryFen.parse_from_board(case) + bin_fen2 = BinaryFen.parse_from_bytes(bin_fen.to_bytes()) + self.assertEqual(bin_fen2, bin_fen2.to_canonical(), "from_bytes should produce canonical value") + self.assertEqual(bin_fen.to_canonical(), bin_fen2.to_canonical()) + decoded, _ = bin_fen2.to_board() + self.assertEqual(case, decoded) + + + # tests that failed the fuzzer at some point + def test_fuzzer_fail(self): + fuzz_fails = [ + "23d7", + "e17f11efd84522d34878ffffffa600000000ce1b23ffff000943", + "20f7076f1718f99824a5020724b3cfc1020146ae00004f85ae28aebc", + "edf9b3c5cb7fa5008000004081c83e4092a7e63dd95a", + "f7cef6e64ed47a4ede172a100000009b004c909b", + "bb7cb00cc3f31dc3f325b8", + "4584aced8100da50a20bd7251705a15b108000251705", + "77ff05111f77111f4214e803647fff6429f0a2f65933310185016400000045bf1e8be6b013ed02", + "55d648e9a20fd600400000e9a29c0010043b26fb41d50a50", + "d8805347e76003102228687fffff41b19e2bff00000100020220c6" + ] + for fuzz_fail in fuzz_fails: + with self.subTest(fuzz_fail=fuzz_fail): + data = bytes.fromhex(fuzz_fail) + binary_fen = BinaryFen.parse_from_bytes(data) + try: + board, std_mode = binary_fen.to_board() + except ValueError: + continue + # print("binary_fen", binary_fen) + # print("ep square", board.ep_square) + # print("fullmove", board.fullmove_number) + # print("halfmove_clock", board.halfmove_clock) + # print("fen", board.fen()) + # print() + # should not error + board.status() + list(board.legal_moves) + binary_fen2 = BinaryFen.parse_from_board(board,std_mode=std_mode) + # print("encoded", binary_fen2.to_bytes().hex()) + # print("binary_fen2", binary_fen2) + # dbg(binary_fen, binary_fen2) + # print("CANONICAL") + # dbg(binary_fen.to_canonical(), binary_fen) + self.assertEqual(binary_fen2, binary_fen2.to_canonical(), "from board should produce canonical value") + self.assertEqual(binary_fen.to_canonical(), binary_fen2.to_canonical()) + board2, std_mode2 = binary_fen2.to_board() + self.assertEqual(board, board2) + self.assertEqual(std_mode, std_mode2) + + + def test_read_binary_fen_std(self): + test_cases = [ + ("0000000000000000", "8/8/8/8/8/8/8/8 w - - 0 1"), + ("00000000000000000001", "8/8/8/8/8/8/8/8 b - - 0 1"), + ("000000000000000064df06", "8/8/8/8/8/8/8/8 b - - 100 432"), + ("ffff00001000efff2d844ad200000000111111113e955fe3", "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1"), + ("20400006400000080ac0b1", "5k2/6p1/8/1Pp5/6P1/8/8/3K4 w - c6 0 1"), + ("10000000180040802ac10f", "4k3/8/8/8/3pP3/8/6N1/7K b - e3 0 1"), + # TODO FIXME, this is encoded with `standard` variant but with chess960 castling + # should this be accepted? for now basing on scalachess behavior + ("8901080000810091ad0d10e1f70007", "r2r3k/p7/3p4/8/8/P6P/8/R3K2R b KQq - 0 4"), + + ("95dd00000000dd95ad8d000000111111be9e", "r1k1r2q/p1ppp1pp/8/8/8/8/P1PPP1PP/R1K1R2Q w KQkq - 0 1"), + ("00000002180000308a1c0f030103", "8/8/8/1k6/3Pp3/8/8/4KQ2 b - d3 3 1") + ] + for binary_fen, expected_fen in test_cases: + with self.subTest(binary_fen=binary_fen, expected_fen=expected_fen): + self.check_binary(binary_fen, expected_fen) + + + # for python-chess, 960 is handled the same as std + def test_read_binary_fen_960(self): + test_cases = [("704f1ee8e81e4f70d60a44000002020813191113511571be000402", "4rrk1/pbbp2p1/1ppnp3/3n1pqp/3N1PQP/1PPNP3/PBBP2P1/4RRK1 w Ff - 0 3")] + for binary_fen, expected_fen in test_cases: + with self.subTest(binary_fen=binary_fen, expected_fen=expected_fen): + self.check_binary(binary_fen, expected_fen) + + def test_read_binary_fen_variants(self): + test_cases = [("ffff00000000ffff2d844ad200000000111111113e955be3000004", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", KOTH), + ("ffff00000000ffff2d844ad200000000111111113e955be363000501", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 99 1 +0+1", THREE_CHECKS), + ("00800000000008001a000106", "8/7p/8/8/8/8/3K4/8 b - - 0 1", ANTI), + ("ffff00000000ffff2d844ad200000000111111113e955be3020407", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 2 3", ATOMIC), + ("ffff0066ffffffff000000000000000000000000000000000000111111113e955be3000008", "rnbqkbnr/pppppppp/8/1PP2PP1/PPPPPPPP/PPPPPPPP/PPPPPPPP/PPPPPPPP w kq - 0 1", HORDE), + ("000000000000ffff793542867b3542a6000009", "8/8/8/8/8/8/krbnNBRK/qrbnNBRQ w - - 0 1", RK), + ("ffff00000000ffff2d844ad200000000111111113e955be300e407010000000000ef0000000000002a", "r~n~b~q~kb~n~r~/pppppppp/8/8/8/8/PPPPPPPP/RN~BQ~KB~NR/ w KQkq - 0 499", ZH), + ("ffff00000000ffff2d844ad200000000111111113e955be30000010000000000", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR/ w KQkq - 0 1", ZH) + ] + for binary_fen_str, expected_fen, variant in test_cases: + with self.subTest(binary_fen=binary_fen_str, expected_fen=expected_fen, variant=variant): + self.check_binary(binary_fen_str, expected_fen, variant) + + + def check_binary(self, binary_fen_str, expected_fen, variant = None): + compressed = bytes.fromhex(binary_fen_str) + board, std_mode = BinaryFen.decode(compressed) + binary_fen1 = BinaryFen.parse_from_bytes(compressed) + from_fen = chess.Board(fen=expected_fen, chess960=True) if variant is None else variant(fen=expected_fen) + encoded = BinaryFen.encode(board,std_mode=std_mode) + binary_fen2 = BinaryFen.parse_from_board(board,std_mode=std_mode) + self.maxDiff = None + self.assertEqual(binary_fen2, binary_fen2.to_canonical(), "from board should produce canonical value") + self.assertEqual(binary_fen1.to_canonical(), binary_fen2.to_canonical()) + self.assertEqual(board, from_fen) + self.assertEqual(encoded.hex(), compressed.hex()) + +def dbg(a, b): + from pprint import pprint + from deepdiff import DeepDiff + pprint(DeepDiff(a, b),indent=2) + +if __name__ == "__main__": + print("#"*80) + unittest.main() \ No newline at end of file