In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
from tqdm.notebook import tqdm
from itertools import product
from torch.utils.data import Dataset, DataLoader

In [2]:
# FEN piece symbols in bitboard-channel order: the character at index i is the
# piece stored in channel i of the 12x8x8 tensors used below
# (p -> 0, P -> 1, n -> 2, N -> 3, ... k -> 10, K -> 11).
PIECES = "pPnNbBrRqQkK"
# Puzzle theme names. Used as the default label columns read from the CSV by
# ChessPuzzleDataset, to size the default output layer of ChessNN, and to name
# predictions in ChessNN.to_labels.
# NOTE(review): presumably these match the theme tags of the Lichess puzzle
# database the CSV files are derived from — confirm against the preprocessing step.
_THEMES = ["advancedPawn", "advantage", "anastasiaMate", "arabianMate", "attackingF2F7", "attraction", "backRankMate", "bishopEndgame", "bodenMate", "capturingDefender", "castling", "clearance", "crushing", "defensiveMove", "deflection", "discoveredAttack", "doubleBishopMate", "doubleCheck", "dovetailMate", "enPassant", "endgame", "equality", "exposedKing", "fork", "hangingPiece", "hookMate", "interference", "intermezzo", "kingsideAttack", "knightEndgame", "long", "master", "masterVsMaster", "mate", "mateIn1", "mateIn2", "mateIn3", "mateIn4", "mateIn5", "middlegame", "oneMove", "opening", "pawnEndgame", "pin", "promotion", "queenEndgame", "queenRookEndgame", "queensideAttack", "quietMove", "rookEndgame", "sacrifice", "short", "skewer", "smotheredMate", "superGM", "trappedPiece", "underPromotion", "veryLong", "xRayAttack", "zugzwang"]

def bitboards_to_fen(bbs: torch.Tensor) -> str:
    """Convert a 12x8x8 tensor of pPnNbBrRqQkK bitboards to a board FEN.

    Args:
        bbs: Tensor of shape (12, 8, 8) indexed as [piece, rank, file], with
            the piece channels ordered as in ``PIECES``. A square holds a
            piece wherever the entry compares equal to 1.

    Returns:
        The piece-placement field of a FEN string, as produced by
        ``chess.BaseBoard.board_fen()``.

    Raises:
        AssertionError: if ``bbs`` does not have shape (12, 8, 8).
    """
    assert bbs.shape == (12, 8, 8)
    board = chess.BaseBoard.empty()

    # .tolist() turns the index rows into plain Python ints. Iterating the
    # tensor directly would hand 0-dim tensors to str indexing and to
    # python-chess, which only works through implicit __index__ coercion.
    for p_idx, sq_rank, sq_file in (bbs == 1).nonzero().tolist():
        piece = chess.Piece.from_symbol(PIECES[p_idx])
        square = chess.square(sq_file, sq_rank)
        board.set_piece_at(square, piece)

    return board.board_fen()


        
def fen_to_bitboards(fen: str) -> torch.Tensor:
    """Encode the piece placement of a FEN string as 12x8x8 bitboards.

    Args:
        fen: A FEN string. Only its first whitespace-separated field is used,
            so a full FEN is accepted as well as a bare board FEN.

    Returns:
        A float tensor of shape (12, 8, 8) indexed as [piece, rank, file],
        with channels in pPnNbBrRqQkK order and 1 on every occupied square.
    """
    # Keep just the board field in case a complete FEN was passed in.
    placement = fen.split()[0]
    occupied = chess.BaseBoard(placement).piece_map()

    planes = torch.zeros((12, 8, 8))
    for square, piece in occupied.items():
        # Channel order: p -> 0, P -> 1, n -> 2, N -> 3, ... k -> 10, K -> 11
        channel = 2 * (piece.piece_type - 1) + piece.color
        planes[channel, chess.square_rank(square), chess.square_file(square)] = 1

    return planes



class ChessPuzzleDataset(Dataset):
    """Map-style dataset of chess puzzles backed by a CSV file.

    Each item is a ``(bitboards, themes)`` pair: the 12x8x8 encoding of the
    row's "FEN" column and an integer tensor built from the row's theme
    columns, in the order given by ``themes``.
    """

    def __init__(self, filename="lichess_db_puzzle_processed.csv", themes=_THEMES):
        """Read ``filename`` with pandas and remember the label column names."""
        self.themes = themes
        self.df = pd.read_csv(filename)

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``(bitboards, themes)`` pair for row position ``idx``."""
        row = self.df.iloc[idx]
        board_planes = fen_to_bitboards(row["FEN"])
        theme_flags = row[self.themes].values.astype(int)
        return board_planes, torch.tensor(theme_flags)


In [3]:
# Shuffle the training set so every epoch visits the puzzles in a new order.
# The test loader is only used to average a loss, so its order is irrelevant
# and it is left unshuffled for reproducible evaluation.
train_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_train.csv"), batch_size=128, shuffle=True)
test_dataloader = DataLoader(ChessPuzzleDataset("lichess_db_puzzle_test.csv"), batch_size=128)

In [4]:
class ChessNN(nn.Module):
    """Multi-label puzzle-theme classifier over 12x8x8 bitboards.

    Pipeline: a stack of 1x1 convolutions lifts the 12 piece channels to
    ``prep_channels[-1]`` features per square, the 64 squares are treated as a
    sequence for a transformer encoder, and a linear layer maps the flattened
    encoder output to one logit per label.

    Args:
        prep_channels: Output channel count of each 1x1 convolution, in order.
            Must be non-empty; the last entry is the transformer ``d_model``.
        prep_activation: Activation module applied after every 1x1
            convolution. The same instance is reused in every stage.
        label_num: Number of output logits.
        nhead: Number of attention heads in each encoder layer.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.
        dropout: Dropout probability inside the encoder layers.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(
        self,
        prep_channels: list,
        prep_activation,
        label_num=len(_THEMES),
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        num_layers=3,
    ):
        super().__init__()
        # Accept any sequence (list, tuple, ...) of channel sizes.
        prep_channels = list(prep_channels)
        # Define preprocessing with 1x1 conv2ds.
        # 12 input channels: one bitboard per piece symbol.
        channels = zip([12] + prep_channels, prep_channels)
        self.prep = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1),
                prep_activation)
                for c_in, c_out in channels]
        )
        # Define transformer stack
        H = prep_channels[-1]
        enc_layer = nn.TransformerEncoderLayer(
            d_model=H,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=nn.ReLU(),
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers)

        # Transform to labels now: 64 squares x H features -> label logits
        self.final_cast = nn.Linear(H*64, label_num)

        # Tracking losses (appended to by the training loop)
        self.batch_losses = []
        self.epoch_train_losses = []
        self.epoch_test_losses = []

    def forward(self, x):
        """Map a batch of bitboards (B,12,8,8) to label logits (B,label_num)."""
        # x.shape = B,12,8,8
        h = self.prep(x)
        # h.shape = B,H,8,8
        h1 = torch.flatten(h, 2, 3).mT
        # h1.shape = B,64,H
        h2 = self.transformer(h1)
        # h2.shape = B,64,H
        h3 = torch.flatten(h2, 1, 2)
        # h3.shape = B,64H
        h4 = self.final_cast(h3)
        return h4

    @staticmethod
    def to_labels(y, k=5):
        """Return the ``k`` most probable themes for a single position.

        Args:
            y: Logits for one position; singleton dimensions are squeezed
                away, so both (label_num,) and (1, label_num) work.
            k: How many (theme, probability) pairs to return.

        Returns:
            A list of ``(theme_name, probability)`` tuples ordered by
            decreasing probability, with names looked up in ``_THEMES``.
        """
        probs = torch.sigmoid(y).squeeze()
        values, indices = torch.topk(probs, k)
        return [
            (_THEMES[idx], v)
            for v, idx in zip(values.tolist(), indices.tolist())
        ]


    def total_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [5]:
def train_loop(model: nn.Module, epochs: int):
    """Train ``model`` with Adam and evaluate it after every epoch.

    Batches come from the module-level ``train_dataloader`` and
    ``test_dataloader``. The loss is binary cross-entropy on logits, so the
    integer theme labels are cast to float before use.

    Side effects: prints the trainable parameter count, appends every
    training batch loss to ``model.batch_losses``, and appends the per-epoch
    mean train/test losses to ``model.epoch_train_losses`` and
    ``model.epoch_test_losses``.

    Args:
        model: Network exposing ``total_params()`` and the three loss lists
            named above.
        epochs: Number of passes over the training data.
    """
    print("Total params:", model.total_params())
    optim = torch.optim.Adam(model.parameters())

    for _ in (tepoch := tqdm(range(epochs), "epoch")):
        model.train()
        batch_losses = []
        # Iterate over training dataset
        for b, (data, label) in enumerate(tbatch := tqdm(train_dataloader, leave=False, unit="batch")):
            y = model(data)
            loss = F.binary_cross_entropy_with_logits(y, label.float())

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_item = loss.item()
            batch_losses.append(loss_item)
            model.batch_losses.append(loss_item)
            if b % 100 == 0:
                tbatch.set_postfix(loss=loss_item)

        model.eval()
        batch_test_losses = []
        # Iterate over test dataset; no gradients are needed for evaluation,
        # so skip building the autograd graph.
        with torch.no_grad():
            for b, (data, label) in enumerate(tbatch := tqdm(test_dataloader, leave=False, unit="batch")):
                y = model(data)
                # The dataset yields integer labels; the loss needs float targets.
                loss = F.binary_cross_entropy_with_logits(y, label.float())
                loss_item = loss.item()
                batch_test_losses.append(loss_item)
                if b % 100 == 0:
                    tbatch.set_postfix(loss=loss_item)

        # Average loss in epoch (NaN when a loader produced no batches)
        avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        avg_test_loss = (
            sum(batch_test_losses) / len(batch_test_losses)
            if batch_test_losses else float("nan")
        )
        model.epoch_train_losses.append(avg_loss)
        model.epoch_test_losses.append(avg_test_loss)
        tepoch.set_postfix(train_loss=avg_loss, test_loss=avg_test_loss)

# Bind the model to a name so its weights and recorded loss histories stay
# reachable after training finishes or is interrupted.
model = ChessNN([16, 32], nn.ReLU())
train_loop(model, 100)

Total params: 536204


epoch:   0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/23951 [00:00<?, ?batch/s]

KeyboardInterrupt: 

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

  0%|          | 0/10000000 [00:00<?, ?it/s]

KeyboardInterrupt: 