In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
from tqdm import tqdm
from itertools import product
from torch.utils.data import Dataset, DataLoader

In [None]:
# Piece symbols in bitboard-plane order: lowercase = black, uppercase = white.
PIECES = "pPnNbBrRqQkK"

# Puzzle theme names; the position in this list is the label index used by the
# dataset columns and by the model's output layer.
_THEMES = [
    "advancedPawn", "advantage", "anastasiaMate", "arabianMate",
    "attackingF2F7", "attraction", "backRankMate", "bishopEndgame",
    "bodenMate", "capturingDefender", "castling", "clearance",
    "crushing", "defensiveMove", "deflection", "discoveredAttack",
    "doubleBishopMate", "doubleCheck", "dovetailMate", "enPassant",
    "endgame", "equality", "exposedKing", "fork",
    "hangingPiece", "hookMate", "interference", "intermezzo",
    "kingsideAttack", "knightEndgame", "long", "master",
    "masterVsMaster", "mate", "mateIn1", "mateIn2",
    "mateIn3", "mateIn4", "mateIn5", "middlegame",
    "oneMove", "opening", "pawnEndgame", "pin",
    "promotion", "queenEndgame", "queenRookEndgame", "queensideAttack",
    "quietMove", "rookEndgame", "sacrifice", "short",
    "skewer", "smotheredMate", "superGM", "trappedPiece",
    "underPromotion", "veryLong", "xRayAttack", "zugzwang",
]

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def bitboards_to_fen(bbs: torch.Tensor) -> str:
    """Build the piece-placement part of a FEN from a 12x8x8 stack of bitboards.

    The first axis selects the piece plane in ``PIECES`` order (pPnNbBrRqQkK),
    the second axis is the rank index and the third the file index. A cell
    equal to 1 marks an occupied square.
    """
    assert bbs.shape == (12, 8, 8)

    result = chess.BaseBoard.empty()

    # One [plane, rank, file] triple of plain ints per occupied cell.
    occupied_cells = (bbs == 1).nonzero().tolist()
    for plane, rank_idx, file_idx in occupied_cells:
        symbol = PIECES[plane]
        result.set_piece_at(
            chess.square(file_idx, rank_idx),
            chess.Piece.from_symbol(symbol),
        )

    return result.board_fen()


        
def fen_to_bitboards(fen: str) -> torch.Tensor:
    """Given a FEN string, convert it to a 12x8x8 tensor of pPnNbBrRqQkK bitboards.

    Args:
        fen: Either a bare piece-placement string or a full FEN. Only the first
            whitespace-separated field (the piece placement) is used.

    Returns:
        A float tensor of shape (12, 8, 8) indexed as [piece, rank, file], with
        1 where that piece stands and 0 elsewhere.

    Raises:
        ValueError: Propagated from ``chess.BaseBoard`` when the piece-placement
            field is malformed.
    """
    # chess.BaseBoard accepts only the piece-placement field, so drop any
    # trailing fields (side to move, castling rights, ...) before parsing.
    # Empty or non-string input is passed through untouched.
    fields = fen.split() if isinstance(fen, str) else []
    board = chess.BaseBoard(fields[0] if fields else fen)
    bbs = torch.zeros((12, 8, 8))

    for sq, piece in board.piece_map().items():
        sq_rank, sq_file = chess.square_rank(sq), chess.square_file(sq)
        # p -> 0, P -> 1, n -> 2, N -> 3, ... k -> 10, K -> 11
        p_idx = (piece.piece_type - 1) * 2 + int(piece.color)

        bbs[p_idx, sq_rank, sq_file] = 1

    return bbs



class ChessPuzzleDataset(Dataset):
    """Puzzle rows from a CSV file, served as (bitboards, theme labels) pairs."""

    def __init__(self, filename="lichess_db_puzzle_processed.csv", themes=_THEMES):
        """Read the CSV at ``filename``; ``themes`` names the label columns."""
        self.df = pd.read_csv(filename)
        self.themes = themes

    def __len__(self):
        """Number of rows in the loaded CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(bitboards, themes)`` for the row at position ``idx``.

        ``bitboards`` is built from the row's ``white_to_play_FEN`` column and
        ``themes`` is an integer tensor holding the selected theme columns.
        """
        row = self.df.iloc[idx]
        position = fen_to_bitboards(row["white_to_play_FEN"])
        label_values = row[self.themes].values.astype(int)
        return position, torch.tensor(label_values)


In [None]:
# Shuffle the training set so each epoch sees the batches in a new order; the
# test set is only used to average a loss, so it is iterated in file order.
train_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_train.csv"), batch_size=128, shuffle=True)
test_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_test.csv"), batch_size=128, shuffle=False)

In [None]:
class ChessNN(nn.Module):
    """Convolutional front-end + transformer encoder that predicts puzzle themes.

    Input is a batch of 12x8x8 bitboards. The conv stack lifts the 12 piece
    planes to ``prep_channels[-1]`` features per square, the 64 squares are then
    treated as a sequence for the transformer encoder, and a final linear layer
    maps the flattened sequence to one logit per theme.
    """

    def __init__(self,
        prep_channels: list,
        prep_kernel_sizes: list,
        prep_activation,
        transformer_nhead,
        transformer_dim_ff,
        transformer_layer_count,
        label_num=len(_THEMES)):
        """
        Args:
            prep_channels: Output channel count of each conv stage, in order.
            prep_kernel_sizes: Kernel sizes of the extra convolutions appended
                to every stage after its 1x1 convolution.
            prep_activation: Activation module placed after every convolution.
                The same instance is reused everywhere.
            transformer_nhead: Attention heads per encoder layer.
            transformer_dim_ff: Hidden width of each encoder layer's
                feed-forward sub-network.
            transformer_layer_count: Number of stacked encoder layers.
            label_num: Number of output logits.
        """
        super().__init__()

        # Each stage: 1x1 conv to change the channel count, then one
        # "same"-padded circular conv per entry of prep_kernel_sizes.
        # The first stage consumes the 12 piece planes.
        stage_channels = zip([12] + list(prep_channels), prep_channels)
        self.prep = nn.Sequential(
                *[nn.Sequential(
                    nn.Conv2d(c_in, c_out, kernel_size=1),
                    prep_activation,
                    nn.BatchNorm2d(c_out),
                    *[nn.Sequential(
                        nn.Conv2d(c_out, c_out, kernel_size=k, padding="same", padding_mode="circular"),
                        prep_activation,
                        nn.BatchNorm2d(c_out)
                        ) for k in prep_kernel_sizes
                    ],
                )
                for c_in, c_out in stage_channels]
        )

        # Transformer stack over the 64 squares; H is the per-square feature size.
        H = prep_channels[-1]
        enc_layer = nn.TransformerEncoderLayer(
            d_model=H,
            nhead=transformer_nhead,
            # Previously this was (mistakenly) transformer_nhead, so the
            # requested feed-forward width was ignored. Checkpoints saved with
            # the old width do not match this layer's shapes.
            dim_feedforward=transformer_dim_ff,
            dropout=0.1,
            activation=nn.ReLU(),
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, transformer_layer_count)

        # Flattened (64 squares x H features) -> one logit per label.
        self.fc = nn.Linear(H*64, label_num)

        # Loss history, filled in by the training loop.
        self.batch_losses = []
        self.epoch_train_losses = []
        self.epoch_test_losses = []

    def forward(self, x):
        """Map bitboards of shape (B, 12, 8, 8) to theme logits of shape (B, label_num)."""
        # x.shape = B,12,8,8
        h = self.prep(x)
        # h.shape = B,H,8,8
        h1 = torch.flatten(h, 2, 3).mT
        # h1.shape = B,64,H  (one token per square)
        h2 = self.transformer(h1)
        # h2.shape = B,64,H
        h3 = torch.flatten(h2, 1, 2)
        # h3.shape = B,64H
        themes = self.fc(h3)

        return themes

    @staticmethod
    def to_labels(y):
        """Return the five most likely themes for one position.

        Args:
            y: Logits for a single position; singleton dimensions are squeezed
                away, so shapes (label_num,) and (1, label_num) both work.

        Returns:
            A list of ``(theme_name, probability)`` tuples, highest first.
        """
        probs = torch.squeeze(torch.sigmoid(y))
        top_values, top_indices = torch.topk(probs, 5)
        return [
            (_THEMES[idx.item()], v.item())
            for v, idx in zip(top_values, top_indices)
        ]

    def total_params(self):
        """Count the trainable parameters of the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [None]:
# Hand-picked probe positions as (piece-placement FEN, expected motif) pairs.
# train_loop predicts on them periodically and shows the result in the
# progress bar, keyed by the motif string.
CUSTOM_FENS = [
    (
        "6k1/5ppp/Q4q2/8/5b2/5N2/1rP2PPP/2R3K1",
        "backrank, M2",
    ),
    (
        "8/3brpr1/4pRpk/pp1pP1Rp/2pP3P/P1P5/2P1B1P1/6K1",
        "pin, M1",
    ),
    (
        "r3r1n1/bp6/p2p2kp/3N4/2P3n1/1PQ3Pq/P4P2/4RRK1",
        "fork",
    ),
]
def train_loop(model: nn.Module, epochs: int, param_dict, lr: float = 1e-5):
    """Train ``model`` for ``epochs`` epochs, evaluating and checkpointing each one.

    Reads the module-level ``train_dataloader``, ``test_dataloader``,
    ``CUSTOM_FENS`` and ``DEVICE``. After every epoch a checkpoint is written to
    ``model_{epoch}.pth`` and the mean train/test losses are appended to
    ``model.epoch_train_losses`` / ``model.epoch_test_losses``; every batch loss
    is also appended to ``model.batch_losses``.

    Args:
        model: Network to train; must provide ``total_params`` and ``to_labels``.
        epochs: Number of passes over the training set.
        param_dict: Constructor arguments of ``model``, stored in each
            checkpoint so the network can be rebuilt later.
        lr: Learning rate for the Adam optimizer.
    """
    print("Total params:", model.total_params())

    # Move the model first so the optimizer is built over the on-device parameters.
    model.to(DEVICE)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for t in range(epochs):
        model.train()
        batch_losses = []
        # Iterate over training dataset
        for b, (data, themes) in enumerate(tbatch := tqdm(train_dataloader, unit="batch")):
            data = data.to(DEVICE)
            themes = themes.to(DEVICE)

            pred_themes = model(data)
            loss = F.binary_cross_entropy_with_logits(pred_themes, themes.float())

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_item = loss.item()
            batch_losses.append(loss_item)
            model.batch_losses.append(loss_item)
            if b % 10 == 0:
                # Get prediction for custom FENs too. Eval mode turns off
                # dropout / batch-stat normalisation, and no_grad avoids
                # building an autograd graph for these throwaway forward passes.
                model.eval()

                fens = {}
                with torch.no_grad():
                    for f, theme in CUSTOM_FENS:
                        bb = fen_to_bitboards(f)[None]
                        p_themes = model.to_labels(model(bb.to(DEVICE)))
                        fens[theme] = [f"{k}:{v:.3f}" for k, v in p_themes]
                model.train()

                tbatch.set_postfix(loss=loss_item, **fens)

        model.eval()
        batch_test_losses = []
        # Iterate over test dataset; nothing is back-propagated here, so
        # gradient tracking is switched off.
        with torch.no_grad():
            for b, (data, themes) in enumerate(tbatch := tqdm(test_dataloader, unit="batch")):
                data = data.to(DEVICE)
                themes = themes.to(DEVICE)

                pred_themes = model(data)
                loss = F.binary_cross_entropy_with_logits(pred_themes, themes.float())

                loss_item = loss.item()
                batch_test_losses.append(loss_item)
                if b % 10 == 0:
                    tbatch.set_postfix(loss=loss_item)

        # Save the model
        torch.save({
            "epoch": t,
            "model_state_dict": model.state_dict(),
            "optimiser_state_dict": optim.state_dict(),
            "model_param_dict": param_dict,
        }, f"model_{t}.pth")

        # Average loss in epoch (NaN when a loader yielded no batches)
        avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        avg_test_loss = sum(batch_test_losses) / len(batch_test_losses) if batch_test_losses else float("nan")
        model.epoch_train_losses.append(avg_loss)
        model.epoch_test_losses.append(avg_test_loss)


# Constructor arguments for ChessNN; also handed to train_loop so they are
# stored alongside the weights in every checkpoint.
param_dict = dict(
    prep_channels=[64, 128, 256, 512],
    prep_kernel_sizes=[8],
    prep_activation=nn.ReLU(),
    transformer_nhead=128,
    transformer_dim_ff=8192,
    transformer_layer_count=8,
)

model = ChessNN(**param_dict)
train_loop(model, epochs=10, param_dict=param_dict)

In [None]:
# Rebuild the network from the hyperparameters that train_loop stored in the
# checkpoint, so the architecture always matches the saved weights.
# weights_only=False is required because the checkpoint also contains an
# nn.Module (the activation inside "model_param_dict"); this unpickles arbitrary
# objects, so only load checkpoint files you created yourself.
# NOTE(review): the weights_only argument needs a reasonably recent torch - confirm.
checkpoint = torch.load("model_8.pth", map_location=DEVICE, weights_only=False)
model = ChessNN(**checkpoint["model_param_dict"])
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)
# Inference only from here on: disable dropout and use BatchNorm running stats.
model.eval()


In [None]:
fens = [
    "6k1/5ppp/Q4q2/8/5b2/5N2/1rP2PPP/2R3K1", # backrank, M2
    "8/3brpr1/4pRpk/pp1pP1Rp/2pP3P/P1P5/2P1B1P1/6K1", # pin, M1
    "r3r1n1/bp6/p2p2kp/3N4/2P3n1/1PQ3Pq/P4P2/4RRK1", # fork
    "r1b1k1nr/pppp2pp/4Pq2/2b5/4p3/8/PPPP2PP/RNBQKB1R", # queen fork
    "1k6/pp3p2/b4Npp/4r3/4P3/5P1K/Pr5P/2R5", # fork AND backrank M2
    "2R5/4bppk/1p1p4/5R1P/4PQ2/5P2/r4q1P/7K",
    "5R2/bp4pk/2n3p1/P7/P1q3bP/6P1/3Q3K/1R6"
]

# A freshly constructed module starts in training mode; switch to eval so
# dropout is off and BatchNorm uses its running statistics for these
# single-position batches, and skip autograd bookkeeping while predicting.
model.eval()
with torch.no_grad():
    for f in fens:
        # [None] adds the batch dimension: (12, 8, 8) -> (1, 12, 8, 8)
        bb = fen_to_bitboards(f)[None]
        pred = model.to_labels(model(bb.to(DEVICE)))
        print(pred)