In [94]:
import pandas as pd
import torch
import chess
from itertools import product
from torch.utils.data import Dataset, DataLoader

In [98]:
# Bitboard channel order: one pair of channels per piece type (pawn, knight,
# bishop, rook, queen, king), lowercase/black symbol first, then uppercase/white.
PIECES = "pPnNbBrRqQkK"

# Puzzle theme names, in label order; these are the CSV columns ChessPuzzleDataset
# turns into the label vector by default. Must stay a list (pandas reads a list as
# a multi-column selection) and must keep this order (it fixes label positions).
_THEMES = [
    "advancedPawn", "advantage", "anastasiaMate", "arabianMate", "attackingF2F7",
    "attraction", "backRankMate", "bishopEndgame", "bodenMate", "capturingDefender",
    "castling", "clearance", "crushing", "defensiveMove", "deflection",
    "discoveredAttack", "doubleBishopMate", "doubleCheck", "dovetailMate", "enPassant",
    "endgame", "equality", "exposedKing", "fork", "hangingPiece",
    "hookMate", "interference", "intermezzo", "kingsideAttack", "knightEndgame",
    "long", "master", "masterVsMaster", "mate", "mateIn1",
    "mateIn2", "mateIn3", "mateIn4", "mateIn5", "middlegame",
    "oneMove", "opening", "pawnEndgame", "pin", "promotion",
    "queenEndgame", "queenRookEndgame", "queensideAttack", "quietMove", "rookEndgame",
    "sacrifice", "short", "skewer", "smotheredMate", "superGM",
    "trappedPiece", "underPromotion", "veryLong", "xRayAttack", "zugzwang",
]

def bitboards_to_fen(bbs: torch.Tensor) -> str:
    """Convert a 12x8x8 stack of pPnNbBrRqQkK bitboards to a FEN board field.

    Channel ``c`` holds the pieces whose symbol is ``PIECES[c]``. Each channel is
    indexed ``[rank, file]`` in python-chess numbering (the values returned by
    ``chess.square_rank`` / ``chess.square_file``), matching ``fen_to_bitboards``.
    Only entries exactly equal to 1 are treated as occupied squares.

    Args:
        bbs: Tensor of shape (12, 8, 8).

    Returns:
        Only the piece-placement field of a FEN (no side to move, castling, etc.).

    Raises:
        AssertionError: If ``bbs`` does not have shape (12, 8, 8).
    """
    assert bbs.shape == (12, 8, 8), f"expected shape (12, 8, 8), got {tuple(bbs.shape)}"
    board = chess.BaseBoard.empty()

    # .tolist() turns each (channel, rank, file) row into plain Python ints, which is
    # what python-chess's Square type is. Iterating the tensor directly would instead
    # yield 0-d tensors that only work here via Tensor.__index__.
    for p_idx, sq_rank, sq_file in (bbs == 1).nonzero().tolist():
        piece = chess.Piece.from_symbol(PIECES[p_idx])
        board.set_piece_at(chess.square(sq_file, sq_rank), piece)

    return board.board_fen()


        
def fen_to_bitboards(fen: str) -> torch.Tensor:
    """Encode the piece placement of a FEN string as a 12x8x8 bitboard tensor.

    Channels follow the pPnNbBrRqQkK order; each channel is indexed
    ``[rank, file]`` using ``chess.square_rank`` / ``chess.square_file``, and an
    occupied square is set to 1 (all other entries stay 0).

    Args:
        fen: A full FEN string or just its piece-placement field; any fields after
            the first whitespace-separated one are ignored.

    Returns:
        A float tensor of shape (12, 8, 8).
    """
    # Only the board field matters; tolerate full FEN strings too.
    placement = fen.split()[0]
    position = chess.BaseBoard(placement)

    planes = torch.zeros((12, 8, 8))
    for square, piece in position.piece_map().items():
        # Two channels per piece type, black (False) before white (True):
        # p -> 0, P -> 1, n -> 2, N -> 3, ... k -> 10, K -> 11
        channel = 2 * (piece.piece_type - 1) + int(piece.color)
        planes[channel, chess.square_rank(square), chess.square_file(square)] = 1

    return planes



class ChessPuzzleDataset(Dataset):
    """Map-style dataset of chess puzzles read from a CSV file.

    Each item is a ``(bitboards, themes)`` pair:

    * ``bitboards`` -- the tensor ``fen_to_bitboards`` builds from the row's
      ``FEN`` column;
    * ``themes`` -- a 1-D integer tensor holding the row's values for the
      requested theme columns, in the order given by ``themes``.

    Args:
        filename: Path (or buffer) passed to ``pd.read_csv``.
        themes: Iterable of theme column names to use as labels. It is copied into
            a new list, so neither the caller's object nor the shared default is
            aliased, and a tuple works as well as a list.

    Raises:
        KeyError: At construction, if the CSV lacks the ``FEN`` column or any of
            the requested theme columns.
    """

    def __init__(self, filename="lichess_db_puzzle_processed.csv", themes=_THEMES):
        self.df = pd.read_csv(filename)
        # pandas reads a tuple as a single (MultiIndex) key, so normalise to a list.
        # Copying also keeps later edits to self.themes from leaking into the default.
        self.themes = list(themes)

        # Fail fast here rather than on the first __getitem__ inside a DataLoader loop.
        missing = [col for col in ("FEN", *self.themes) if col not in self.df.columns]
        if missing:
            raise KeyError(f"{filename!r} is missing columns: {missing}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        puzzle = self.df.iloc[idx]
        bitboards = fen_to_bitboards(puzzle["FEN"])
        # A row Series from a frame mixing string and numeric columns is object-dtype,
        # so cast the label values to int explicitly before building the tensor.
        themes = torch.tensor(puzzle[self.themes].values.astype(int))
        return bitboards, themes


In [99]:
# Shuffle the training set so every epoch sees puzzles in a fresh order; a seeded
# generator keeps that order reproducible under Restart & Run All. The test set is
# evaluated in file order: shuffling it adds nothing and makes runs harder to compare.
SEED = 0
BATCH_SIZE = 1

train_dataloader = DataLoader(
    ChessPuzzleDataset("lichess_db_puzzle_train.csv"),
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_dataloader = DataLoader(
    ChessPuzzleDataset("lichess_db_puzzle_test.csv"),
    batch_size=BATCH_SIZE,
    shuffle=False,
)