In [1]:
import unittest
import chess
import base64

In [2]:


class ChessEncoder:
    def __init__(self):
        self.alive_bits = {
            chess.PAWN: 4,
            chess.ROOK: 2,
            chess.BISHOP: 2,
            chess.KNIGHT: 2,
            chess.QUEEN: 1,
            chess.KING: 1
        }

    def encode_dynamic(self, fen):
        board = chess.Board(fen)
        active_color = format(board.turn, '01b')
        castling_rights = ''
        castling_rights += '1' if board.has_kingside_castling_rights(chess.WHITE) else '0'
        castling_rights += '1' if board.has_queenside_castling_rights(chess.WHITE) else '0'
        castling_rights += '1' if board.has_kingside_castling_rights(chess.BLACK) else '0'
        castling_rights += '1' if board.has_queenside_castling_rights(chess.BLACK) else '0'

        if board.ep_square is not None:
            en_passant = '1'+format(board.ep_square, '06b')
        else:
            en_passant = '0'

        positions = ''
        for piece_type in chess.PIECE_TYPES:
            for color in chess.COLORS:
                pieces = [sq for sq in chess.SQUARES if board.piece_at(sq) is not None and board.piece_at(sq).piece_type == piece_type and board.piece_at(sq).color == color]
                positions += format(len(pieces), '0' + str(self.alive_bits[piece_type]) + 'b')
                for square in pieces:
                    positions += format(square, '06b')

        bit_string = active_color + castling_rights + en_passant + positions
        return bit_string
    
    def decode_dynamic(self, bit_string):
        active_color = int(bit_string[0], 2)
        castling_rights = bit_string[1:5]

        if bit_string[5] == '1':
            en_passant_square = int(bit_string[6:12], 2)  # Corrected to decode 6 bits
            positions = bit_string[12:]  # Corrected to start from 12th bit
        else:
            en_passant_square = None
            positions = bit_string[6:]  # Corrected to start from 7th bit

        board = chess.Board()
        board.clear()

        for piece_type in chess.PIECE_TYPES:
            for color in chess.COLORS:
                num_pieces = int(positions[:self.alive_bits[piece_type]], 2)
                positions = positions[self.alive_bits[piece_type]:]
                for _ in range(num_pieces):
                    square = int(positions[:6], 2)
                    positions = positions[6:]
                    board.set_piece_at(square, chess.Piece(piece_type, color))

        board.ep_square = en_passant_square
        
        board.turn = bool(active_color)
        board.castling_rights = 0
        if castling_rights[0] == '1':
            board.castling_rights |= chess.BB_H1
        if castling_rights[1] == '1':
            board.castling_rights |= chess.BB_A1
        if castling_rights[2] == '1':
            board.castling_rights |= chess.BB_H8
        if castling_rights[3] == '1':
            board.castling_rights |= chess.BB_A8

        return board.fen()
    
    def encode_base64(self, bit_string):
        # Convert the bit string to bytes
        bit_string_bytes = int(bit_string, 2).to_bytes((len(bit_string) + 7) // 8, 'big')

        # Encode the bytes using base64
        base64_string = base64.b64encode(bit_string_bytes).decode()

        return base64_string

    def decode_base64(self, base64_string):
        # Decode the base64 string to bytes
        bit_string_bytes = base64.b64decode(base64_string)

        # Convert the bytes to a bit string
        bit_string = bin(int.from_bytes(bit_string_bytes, 'big'))[2:]

        return bit_string

In [3]:
class TestChessEncoder(unittest.TestCase):
    def setUp(self):
        self.encoder = ChessEncoder()
        self.fen_notations = [
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",  # Initial position
            "rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2",  # After 1.e4 e5 2.Nf3
            "8/4k3/8/8/8/8/8/4K3 w - - 0 1",  # Only kings on the board
            "rnbqkb1r/pppp1ppp/2n5/4P3/8/8/PPPP1PPP/RNBQKBNR w KQkq - 0 3",  # After 1.e4 e5 2.Nf3 Nc6 3.e5
            "rnbqkbnr/pppp1ppp/8/4Pp2/8/8/PPPP1PPP/RNBQKBNR w KQkq f6 0 3",  # Enpassant capture possible
        ]

    def test_encode_decode_dynamic(self):
        for fen in self.fen_notations:
            with self.subTest(fen=fen):
                encoded = self.encoder.encode_dynamic(fen)
                print(f"FEN: {fen}")
                print(f"FEN notation size: {len(fen)}")
                print(f"Encoded bit size: {len(encoded)}")
                print(f"base64 size: {len(self.encoder.encode_base64(encoded))}")
                print(f"base64: {self.encoder.encode_base64(encoded)}")
                print()
                decoded = self.encoder.decode_dynamic(encoded)
                # Split the FEN strings into sections and compare only the first four
                fen_sections = fen.split(' ')
                decoded_sections = decoded.split(' ')
                self.assertEqual(fen_sections[:4], decoded_sections[:4])

if __name__ == "__main__":
    unittest.main(argv=[''], exit=False)

.
----------------------------------------------------------------------
Ran 1 test in 0.002s

OK


FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
FEN notation size: 56
Encoded bit size: 222
base64 size: 40
base64: PoIJKLMNOPjDHLPTXbeBGufoIW69gB7j+H7ifA==

FEN: rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2
FEN notation size: 62
Encoded bit size: 222
base64 size: 40
base64: HoIJKLNOPciTDHLPXbeBVufoIW69gB7j+H7ifA==

FEN: 8/4k3/8/8/8/8/8/4K3 w - - 0 1
FEN notation size: 29
Encoded bit size: 42
base64 size: 8
base64: AgAAACJ0

FEN: rnbqkb1r/pppp1ppp/2n5/4P3/8/8/PPPP1PPP/RNBQKBNR w KQkq - 0 3
FEN notation size: 60
Encoded bit size: 216
base64 size: 36
base64: +ggkos04+R8Mcs9dt4Eaq5ghbr2AHuP4fuJ8

FEN: rnbqkbnr/pppp1ppp/8/4Pp2/8/8/PPPP1PPP/RNBQKBNR w KQkq f6 0 3
FEN notation size: 60
Encoded bit size: 228
base64 size: 40
base64: D+2CCSizTj5Ilwxyz123gRrn6CFuvYAe4/h+4nw=

