In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
from tqdm import tqdm
from itertools import product
from torch.utils.data import Dataset, DataLoader

In [2]:
PIECES = "pPnNbBrRqQkK"
_THEMES = ["advancedPawn", "anastasiaMate", "arabianMate", "attackingF2F7", "attraction", "backRankMate", "bishopEndgame", "bodenMate", "capturingDefender", "castling", "clearance", "defensiveMove", "deflection", "discoveredAttack", "doubleBishopMate", "doubleCheck", "dovetailMate", "enPassant", "equality", "exposedKing", "fork", "hangingPiece", "hookMate", "interference", "intermezzo", "kingsideAttack", "knightEndgame", "master", "masterVsMaster", "mate", "mateIn1", "mateIn2", "mateIn3", "mateIn4", "mateIn5", "oneMove", "opening", "pawnEndgame", "pin", "promotion", "queenEndgame", "queenRookEndgame", "queensideAttack", "quietMove", "rookEndgame", "sacrifice", "skewer", "smotheredMate", "superGM", "trappedPiece", "underPromotion", "xRayAttack", "zugzwang"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def bitboards_to_fen(bbs: torch.Tensor) -> str:
    """Given a 12x8x8 tensor of pPnNbBrRqQkK bitboards, convert it to FEN"""
    assert bbs.shape == (12, 8, 8)
    board = chess.BaseBoard.empty()

    for p_idx, sq_rank, sq_file in (bbs == 1).nonzero():
        piece = chess.Piece.from_symbol(PIECES[p_idx])
        square = chess.square(sq_file, sq_rank)
        board.set_piece_at(square, piece)
    
    return board.board_fen()


        
def fen_to_bitboards(fen: str) -> torch.Tensor:
    """Given a FEN string, convert it to a 12x8x8 tensor of pPnNbBrRqQkK bitboards."""
    # Defensively split fen string, just in case
    board = chess.BaseBoard(fen)
    bbs = torch.zeros((12, 8, 8))

    for sq, piece in board.piece_map().items():
        sq_rank, sq_file = chess.square_rank(sq), chess.square_file(sq)
        # p -> 0, P -> 1, n -> 2, N -> 3, ... k -> 10, K -> 11
        p_idx = (piece.piece_type - 1) * 2 + piece.color
        
        bbs[p_idx, sq_rank, sq_file] = 1
    
    return bbs



class ChessPuzzleDataset(Dataset):
    def __init__(self, filename="lichess_db_puzzle_processed.csv", themes=_THEMES):
        self.df = pd.read_csv(filename)
        self.themes = themes
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        puzzle = self.df.iloc[idx]
        bitboards = fen_to_bitboards(puzzle["white_to_play_FEN"]) 
        themes = torch.tensor(puzzle[self.themes].values.astype(int))
        return bitboards, themes


In [3]:
train_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_train.csv"), batch_size=128)
test_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_test.csv"), batch_size=128, shuffle=True)

In [5]:
class ChessNN(nn.Module):
    def __init__(self, prep_channels: list, prep_activation, label_num=len(_THEMES)):
        super().__init__()
        # Define preprocessing with 8x8 conv2ds
        # 12 is the number of chess pieces magic number
        channels = zip([12] + prep_channels, prep_channels)
        self.prep = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=8, padding="same", padding_mode="reflect"),
                prep_activation)
                for c_in, c_out in channels]
        )
        # Define transformer stack
        H = prep_channels[-1]
        enc_layer = nn.TransformerEncoderLayer(
            d_model=H,
            nhead=128,
            dim_feedforward=8192,
            dropout=0.1,
            activation=nn.ReLU(),
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, 5)
        
        # Transform to labels now
        self.fc = nn.Linear(H*64, label_num)

        # Tracking losses
        self.batch_losses = []
        self.epoch_train_losses = []
        self.epoch_test_losses = []

    def forward(self, x):
        # x.shape = B,12,8,8
        h = self.prep(x)
        # h.shape = B,H,8,8
        h1 = torch.flatten(h, 2, 3).mT
        # h1.shape = B,64,H
        h2 = self.transformer(h1)
        # h2.shape = B,64,H
        h3 = torch.flatten(h2, 1, 2)
        # h3.shape = B,64H
        h4 = self.fc(h3)
        return h4

    @staticmethod
    def to_labels(y):
        return [
            (_THEMES[idx.item()], v.item()) 
            for v, idx in zip(*torch.topk(torch.squeeze(F.sigmoid(y)), 5))
        ]

    
    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [21]:
CUSTOM_FENS = [
    ("6k1/5ppp/Q4q2/8/5b2/5N2/1rP2PPP/2R3K1", "backrank, M2"),
    ("8/3brpr1/4pRpk/pp1pP1Rp/2pP3P/P1P5/2P1B1P1/6K1", "pin, M1"),
    ("r3r1n1/bp6/p2p2kp/3N4/2P3n1/1PQ3Pq/P4P2/4RRK1", "fork")
]
def train_loop(model: nn.Module, epochs: int):
    print("Total params:", model.total_params())
    optim = torch.optim.Adam(model.parameters(), lr=1e-5)

    model.to(DEVICE)
    
    for t in range(epochs):
        model.train()
        batch_losses = []
        # Iterate over training dataset
        for b, (data, label) in enumerate(tbatch := tqdm(train_dataloader, unit="batch")):
            data = data.to(DEVICE)
            label = label.to(DEVICE)

            y = model(data)
            loss = F.binary_cross_entropy_with_logits(y, label.float())

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_item = loss.item()
            batch_losses.append(loss_item)
            model.batch_losses.append(loss_item)
            if b % 10 == 0:
                # Get prediction for custom FENs too
                model.eval()

                fens = {}
                for f, theme in CUSTOM_FENS:
                    bb = fen_to_bitboards(f)[None]
                    pred = model.to_labels(model(bb.to(DEVICE)))
                    fens[theme] = [f"{k}:{v:.3f}" for k, v in pred]
                model.train()

                tbatch.set_postfix(loss=loss_item, **fens)
            
        model.eval()
        batch_test_losses = []
        # Iterate over test dataset
        for b, (data, label) in enumerate(tbatch := tqdm(test_dataloader, unit="batch")):
            data = data.to(DEVICE)
            label = label.to(DEVICE)

            y = model(data)
            loss = F.binary_cross_entropy_with_logits(y, label.float())
            loss_item = loss.item()
            batch_test_losses.append(loss_item)
            if b % 10 == 0:
                tbatch.set_postfix(loss=loss_item)

        # Save the model
        torch.save({
            "epoch": t,
            "model_state_dict": model.state_dict(),
            "optimiser_state_dict": optim.state_dict(),
        }, f"model_{t}.pth")

        # Average loss in epoch
        avg_loss = sum(batch_losses) / len(batch_losses) 
        avg_test_loss = sum(batch_test_losses) / len(batch_test_losses)
        model.epoch_train_losses.append(avg_loss)
        model.epoch_test_losses.append(avg_test_loss)

model = ChessNN([64, 128, 256, 512], nn.ReLU())
train_loop(model, 10)

Total params: 60046837


  2%|▏         | 421/23951 [05:01<4:42:17,  1.39batch/s, backrank, M2=['mate:0.310', 'mateIn2:0.218', 'sacrifice:0.094', 'fork:0.082', 'oneMove:0.080'], fork=['mate:0.236', 'mateIn2:0.150', 'fork:0.118', 'oneMove:0.094', 'sacrifice:0.092'], loss=0.103, pin, M1=['mate:0.447', 'mateIn2:0.274', 'kingsideAttack:0.244', 'oneMove:0.123', 'sacrifice:0.116']]                                                                             

In [6]:
model = ChessNN([64, 128, 256, 512], nn.ReLU())
checkpoint = torch.load("model_8.pth")
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)


<All keys matched successfully>

In [23]:
fens = [
    "6k1/5ppp/Q4q2/8/5b2/5N2/1rP2PPP/2R3K1", # backrank, M2
    "8/3brpr1/4pRpk/pp1pP1Rp/2pP3P/P1P5/2P1B1P1/6K1", # pin, M1
    "r3r1n1/bp6/p2p2kp/3N4/2P3n1/1PQ3Pq/P4P2/4RRK1", # fork
    "r1b1k1nr/pppp2pp/4Pq2/2b5/4p3/8/PPPP2PP/RNBQKB1R", # queen fork
    "1k6/pp3p2/b4Npp/4r3/4P3/5P1K/Pr5P/2R5", # fork AND backrank M2
    "2R5/4bppk/1p1p4/5R1P/4PQ2/5P2/r4q1P/7K",
    "5R2/bp4pk/2n3p1/P7/P1q3bP/6P1/3Q3K/1R6"
]
for f in fens:
    bb = fen_to_bitboards(f)[None]
    pred = model.to_labels(model(bb.to(DEVICE)))
    print(pred)

[('mate', 0.9988625049591064), ('mateIn2', 0.9982642531394958), ('backRankMate', 0.9273577332496643), ('xRayAttack', 0.20983801782131195), ('sacrifice', 0.08945082128047943)]
[('mate', 0.17066384851932526), ('fork', 0.15374240279197693), ('capturingDefender', 0.10465613752603531), ('master', 0.07743911445140839), ('deflection', 0.05064989626407623)]
[('fork', 0.7485860586166382), ('pin', 0.5237708687782288), ('master', 0.08492258191108704), ('masterVsMaster', 0.034428391605615616), ('hangingPiece', 0.032398976385593414)]
[('opening', 0.9999942779541016), ('fork', 0.8114109039306641), ('advancedPawn', 0.15881244838237762), ('castling', 0.05270930752158165), ('defensiveMove', 0.04499756917357445)]
[('mate', 0.9943329095840454), ('backRankMate', 0.929938018321991), ('mateIn2', 0.7910739779472351), ('mateIn3', 0.30065977573394775), ('fork', 0.2982390820980072)]
[('sacrifice', 0.7257033586502075), ('mate', 0.5319075584411621), ('attraction', 0.10258003324270248), ('oneMove', 0.0625035837292