In [1]:
import chess

def moves_to_mate(fen, solution_moves):
    """Return the mate distance, in moves of the mating side, of a solution line.

    The line is replayed from ``fen``. If every move is legal and the final
    position is checkmate, the result is ``ceil(plies / 2)``: the side that
    delivers mate always plays the last ply, so this is the number of moves
    that side made, regardless of who moved first or which colour mates.

    Args:
        fen: FEN string of the starting position.
        solution_moves: Iterable of UCI move strings, or a single
            whitespace-separated string of UCI moves.

    Returns:
        The mate distance as a positive int, or 0 when the line contains a
        malformed or illegal move or does not end in checkmate. An empty
        line also yields 0.
    """
    if isinstance(solution_moves, str):
        # Iterating a raw string would yield single characters, not moves.
        solution_moves = solution_moves.split()

    board = chess.Board(fen)
    plies = 0

    for uci in solution_moves:
        try:
            move = chess.Move.from_uci(uci)
        except ValueError:
            return 0  # malformed UCI text, not a valid mate sequence
        if move not in board.legal_moves:
            return 0  # illegal move, not a valid mate sequence
        board.push(move)
        plies += 1

    if not board.is_checkmate():
        return 0

    # Counting only when the turn flips back to White would report N-1 for an
    # odd-length line started by White (a mate in 1 would come out as 0, the
    # same value as "no mate"). Rounding plies up avoids that.
    return (plies + 1) // 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import chess
# Utility: convert FEN to tensor with positional features
def fen_to_tensor(fen):
    """Encode a FEN position as a float32 tensor of shape (23, 8, 8).

    Row 0 of every plane is rank 8 and column 0 is the a-file
    (``row = 7 - sq // 8``, ``col = sq % 8``).

    Channel layout:
        0-5    white P, N, B, R, Q, K occupancy
        6-11   black P, N, B, R, Q, K occupancy
        12     side to move, ``float(board.turn)`` broadcast over the board
        13     squares attacked by at least one white piece
        14     squares attacked by at least one black piece
        15     destination squares of the side to move's legal moves
        16     ``chess.square_distance`` to the white king (0 if absent)
        17     ``chess.square_distance`` to the black king (0 if absent)
        18     1.0 everywhere when the side to move is in check
        19     pieces (either colour) pinned to their own king
        20     squares with more white attackers than black
        21     squares with more black attackers than white
        22     destination squares of legal moves that give check

    Args:
        fen: FEN string accepted by ``chess.Board``.

    Returns:
        ``torch.Tensor`` of dtype float32 and shape (23, 8, 8).
    """
    board = chess.Board(fen)

    def cell(sq):
        # Map a python-chess square index to (row, col) in the planes.
        return 7 - (sq // 8), sq % 8

    def blank():
        return np.zeros((8, 8), dtype=np.float32)

    # Base piece planes: 12 channels (white first, then black).
    planes = np.zeros((12, 8, 8), dtype=np.float32)
    for sq, piece in board.piece_map().items():
        kind = "PNBRQK".index(piece.symbol().upper())
        offset = 0 if piece.color == chess.WHITE else 6
        r, c = cell(sq)
        planes[kind + offset, r, c] = 1

    # Side to move plane.
    stm_plane = np.full((1, 8, 8), float(board.turn), dtype=np.float32)

    # Whole-board check flag for the side to move.
    check_pl = np.full((8, 8), float(board.is_check()), dtype=np.float32)

    attack_w, attack_b = blank(), blank()
    controlled_white, controlled_black = blank(), blank()
    dist_wk, dist_bk = blank(), blank()
    pinned = blank()

    wksq = board.king(chess.WHITE)
    bksq = board.king(chess.BLACK)

    # One pass over the squares fills every per-square plane; the attacker
    # sets are computed once and shared by the attack and control maps.
    for sq in chess.SQUARES:
        r, c = cell(sq)

        n_white = len(board.attackers(chess.WHITE, sq))
        n_black = len(board.attackers(chess.BLACK, sq))
        if n_white:
            attack_w[r, c] = 1
        if n_black:
            attack_b[r, c] = 1
        if n_white > n_black:
            controlled_white[r, c] = 1.0
        elif n_black > n_white:
            controlled_black[r, c] = 1.0

        if wksq is not None:
            dist_wk[r, c] = chess.square_distance(sq, wksq)
        if bksq is not None:
            dist_bk[r, c] = chess.square_distance(sq, bksq)

        piece = board.piece_at(sq)
        if piece is not None and board.is_pinned(piece.color, sq):
            pinned[r, c] = 1

    # Legal-move destinations, and the subset of them that give check.
    legal_mask = blank()
    checking_moves_mask = blank()
    for mv in list(board.legal_moves):
        r, c = cell(mv.to_square)
        legal_mask[r, c] = 1
        if board.gives_check(mv):
            checking_moves_mask[r, c] = 1

    # Stack all planes: 12 + 1 + 2 + 1 + 2 + 1 + 1 + 2 + 1 = 23 channels.
    extra = [
        attack_w, attack_b, legal_mask, dist_wk, dist_bk,
        check_pl, pinned, controlled_white, controlled_black,
        checking_moves_mask,
    ]
    feature_planes = np.stack(extra, axis=0)
    all_planes = np.concatenate([planes, stm_plane, feature_planes], axis=0)

    return torch.from_numpy(all_planes)


In [None]:
# Number of passes over the training set used by the training loops below.
# The cell previously held the placeholder `???`, which is not valid Python
# and stopped the whole cell from running. 10 is a provisional default.
# TODO: tune for the real dataset.
EPOCHS = 10


class NumberMateCNN(nn.Module):
    """CNN that maps a stack of 8x8 board planes to one value in [0, 1].

    Three 3x3 convolutions with two 2x2 max-pools reduce the 8x8 board to
    128 feature maps of 2x2, followed by a two-layer head and a sigmoid.

    Args:
        in_channels: Number of input planes. Defaults to 23, the number of
            planes ``fen_to_tensor`` stacks (12 piece + 1 side-to-move +
            10 feature planes). The previous hard-coded 22 did not match
            that output. Pass 22 to rebuild the old layout, for instance to
            load weights saved with it.
    """

    def __init__(self, in_channels=23):
        super().__init__()
        # Layer positions inside each Sequential are unchanged, so the
        # state_dict keys stay the same as before.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(), nn.Linear(128*2*2, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Return a tensor of shape (batch,) with values in [0, 1]."""
        out = self.fc(self.conv(x)).squeeze(-1)  # (batch, 1) -> (batch,)
        return torch.sigmoid(out)  # outputs in [0,1]

# Binary classification model
class IsMateCNN(nn.Module):
    """CNN that maps a stack of 8x8 board planes to per-class logits.

    Three 3x3 convolutions with two 2x2 max-pools reduce the 8x8 board to
    128 feature maps of 2x2, followed by a two-layer head. The output is raw
    logits with no softmax applied.

    Args:
        in_channels: Number of input planes. Defaults to 23, the number of
            planes ``fen_to_tensor`` stacks (12 piece + 1 side-to-move +
            10 feature planes). The previous hard-coded 22 did not match
            that output. Pass 22 to rebuild the old layout.
        num_classes: Size of the output layer. When ``None`` (default) the
            module-level ``BINARY_CLASSES`` constant is used, as before.
    """

    def __init__(self, in_channels=23, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = BINARY_CLASSES
        # Layer positions inside each Sequential are unchanged, so the
        # state_dict keys stay the same as before.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(), nn.Linear(128*2*2, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.fc(self.conv(x))


def train_binary(model, loader, device, save_path='binary_model.pt', epochs=None, lr=None):
    """Train a classifier with cross-entropy and checkpoint after every epoch.

    Args:
        model: Module returning class logits for a batch ``x``.
        loader: Iterable of ``(x, y)`` batches, where ``y`` is a target
            accepted by ``nn.CrossEntropyLoss``.
        device: Device the model and every batch are moved to.
        save_path: Prefix for checkpoint files. The suffix
            ``_epoch{n}.pt`` is appended verbatim, so the default produces
            ``binary_model.pt_epoch1.pt``, ``binary_model.pt_epoch2.pt``, ...
        epochs: Number of epochs. ``None`` (default) uses the module-level
            ``EPOCHS``.
        lr: Adam learning rate. ``None`` (default) uses the module-level
            ``LR``.

    Returns:
        None. The model is updated in place; the per-epoch mean loss is
        printed and the state dict is written to disk.
    """
    n_epochs = EPOCHS if epochs is None else epochs
    learning_rate = LR if lr is None else lr

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    model.train()

    for epoch in range(n_epochs):
        total_loss = 0.0
        seen = 0  # samples actually processed this epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            batch = x.size(0)
            total_loss += loss.item() * batch
            seen += batch
        # Average over the samples seen rather than len(loader.dataset): the
        # two differ with drop_last or a sampler, and an empty loader would
        # otherwise divide by zero.
        avg_loss = total_loss / max(seen, 1)
        print(f"Epoch {epoch+1} binary loss: {avg_loss:.4f}")
        # Save weights after each epoch
        torch.save(model.state_dict(), f"{save_path}_epoch{epoch+1}.pt")


def train_distance(model, loader, device, save_path='distance_model.pt', epochs=None, lr=None):
    """Train a mate-distance regressor with MSE and checkpoint every epoch.

    The regression target is the reciprocal of the mate distance
    (``1 / t`` where ``t > 0``, otherwise 0).

    Args:
        model: Module returning one prediction per sample for a batch ``x``.
        loader: Iterable of ``(x, t)`` batches, where ``t`` holds mate
            distances (0 meaning no mate). ``t`` may be integer or float and
            may carry a trailing singleton dimension; it is cast and reshaped
            to match the prediction.
        device: Device the model and every batch are moved to.
        save_path: Prefix for checkpoint files. The suffix
            ``_epoch{n}.pt`` is appended verbatim, so the default produces
            ``distance_model.pt_epoch1.pt``, ...
        epochs: Number of epochs. ``None`` (default) uses the module-level
            ``EPOCHS``.
        lr: Adam learning rate. ``None`` (default) uses the module-level
            ``LR``.

    Returns:
        None. The model is updated in place; the per-epoch mean loss is
        printed and the state dict is written to disk.
    """
    n_epochs = EPOCHS if epochs is None else epochs
    learning_rate = LR if lr is None else lr

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    model.train()

    for epoch in range(n_epochs):
        total_loss = 0.0
        seen = 0  # samples actually processed this epoch
        for x, t in loader:
            x = x.to(device)
            optimizer.zero_grad()
            pred = model(x)
            # Match the prediction's dtype and shape. An integer target can
            # break type promotion in torch.where, and a (B, 1) target would
            # silently broadcast against a (B,) prediction inside MSELoss.
            t = t.to(device=device, dtype=pred.dtype).reshape(pred.shape)
            # target: inverse mate distance (1/mate_moves), 0 if no mate
            t_recip = torch.where(t > 0, 1.0 / t, torch.zeros_like(t))
            loss = criterion(pred, t_recip)
            loss.backward()
            optimizer.step()
            batch = x.size(0)
            total_loss += loss.item() * batch
            seen += batch
        # Average over the samples seen rather than len(loader.dataset): the
        # two differ with drop_last or a sampler, and an empty loader would
        # otherwise divide by zero.
        avg_loss = total_loss / max(seen, 1)
        print(f"Epoch {epoch+1} distance loss: {avg_loss:.4f}")
        torch.save(model.state_dict(), f"{save_path}_epoch{epoch+1}.pt")

