In [1]:
!pip install chess

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/6.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m6.1/6.1 MB[0m [31m193.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m114.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=19b7c6d50cbdc02424a3813f8e840cfc01ee14caa41b465828adf44baa105383
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.2


In [2]:
# Colab-only: mount Google Drive under /content/drive. Later cells read the
# PGN file from a path below this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import chess, chess.pgn, numpy as np

Encode the board state as a stack of feature planes

In [4]:
def board_to_plane(board: chess.Board):
    """
    Encode a position as a float32 array of shape (18, 8, 8).

    The spatial axes are indexed as [rank, file] (rank 0 = first rank,
    file 0 = a-file).

    Channels:
       0-5  : White pawn, knight, bishop, rook, queen, king (one-hot per square)
       6-11 : Black pawn, knight, bishop, rook, queen, king (one-hot per square)
       12   : side to move (all ones when White is to move, else all zeros)
       13   : White kingside castling right  (constant plane)
       14   : White queenside castling right (constant plane)
       15   : Black kingside castling right  (constant plane)
       16   : Black queenside castling right (constant plane)
       17   : en-passant file (ones on every square of the file of
              ``board.ep_square`` when it is set)
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    # (piece type, colour) -> channel; White occupies 0-5, Black 6-11.
    piece_order = (chess.PAWN, chess.KNIGHT, chess.BISHOP,
                   chess.ROOK, chess.QUEEN, chess.KING)
    channel_of = {
        (piece_type, color): offset + i
        for color, offset in ((True, 0), (False, 6))
        for i, piece_type in enumerate(piece_order)
    }

    # Piece planes
    for square, piece in board.piece_map().items():
        channel = channel_of[(piece.piece_type, piece.color)]
        planes[channel, chess.square_rank(square), chess.square_file(square)] = 1

    # Side-to-move plane
    if board.turn == chess.WHITE:
        planes[12, :, :] = 1

    # Castling rights: each right fills a whole plane with 0 or 1.
    castling_channels = (
        (13, board.has_kingside_castling_rights,  chess.WHITE),
        (14, board.has_queenside_castling_rights, chess.WHITE),
        (15, board.has_kingside_castling_rights,  chess.BLACK),
        (16, board.has_queenside_castling_rights, chess.BLACK),
    )
    for channel, has_right, color in castling_channels:
        planes[channel, :, :] = int(has_right(color))

    # En-passant: mark the whole file of the target square.
    ep_square = board.ep_square
    if ep_square is not None:
        planes[17, :, chess.square_file(ep_square)] = 1

    return planes

Encode a move as an integer action id (and decode it back)

In [5]:
# Eight compass directions as (d_rank, d_file) unit steps, clockwise from north.
# The position in this list is the direction index used by the move encoding.
direction_table = [
    (+1, 0), (+1, +1), (0, +1), (-1, +1),    # N, NE, E, SE
    (-1, 0), (-1, -1), (0, -1), (+1, -1),    # S, SW, W, NW
]

# Eight knight jumps as (d_rank, d_file); the position in this list is the
# knight-plane offset used by the move encoding.
knight_moves = [
    (+2, +1),
    (+1, +2),
    (-1, +2),
    (-2, +1),
    (-2, -1),
    (-1, -2),
    (+1, -2),
    (+2, -1),
]

def sign(x: int) -> int:
    """Return -1, 0 or +1 according to the sign of ``x``."""
    # bool - bool yields a plain int: (True - False) == 1, (False - True) == -1.
    return (x > 0) - (x < 0)

# Under-promotion piece order; the position in this list selects the
# under-promotion plane group (queen promotions are not listed here).
promo_order = [chess.ROOK, chess.BISHOP, chess.KNIGHT]

# Planes per origin square: 8 directions * 7 steps + 8 knight jumps
# + 3 under-promotion pieces * 3 move types = 73.
ACTION_PLANES = 8 * 7 + 8 + 3 * 3

def encode_move(move: chess.Move) -> int:
    """
    Map a chess.Move to an action index ``from_square * 73 + plane``.

    Plane layout (per origin square):
      0-55  : queen-like moves, ``direction * 7 + (distance - 1)``
      56-63 : knight jumps, in ``knight_moves`` order
      64-72 : under-promotions, ``piece * 3 + move_type`` where piece follows
              ``promo_order`` and move_type is 0 = same file, 1 = file
              decreases, 2 = file increases

    Queen promotions carry no dedicated plane: they fall through to the
    queen-like branch and are encoded as an ordinary one-step move.
    """
    origin = move.from_square               # 0..63
    target = move.to_square
    delta = (
        chess.square_rank(target) - chess.square_rank(origin),
        chess.square_file(target) - chess.square_file(origin),
    )
    base = origin * ACTION_PLANES

    # Knight jumps are recognised purely by their displacement.
    if delta in knight_moves:
        return base + 56 + knight_moves.index(delta)

    d_rank, d_file = delta

    # Under-promotions (anything but a queen).
    if move.promotion and move.promotion != chess.QUEEN:
        piece_slot = promo_order.index(move.promotion)
        if d_file == 0:
            move_type = 0
        elif d_file < 0:
            move_type = 1
        else:
            move_type = 2
        return base + 56 + 8 + (piece_slot * 3 + move_type)

    # Everything else is treated as a sliding move along one of 8 directions.
    heading = direction_table.index((sign(d_rank), sign(d_file)))
    distance = max(abs(d_rank), abs(d_file))  # 1..7
    return base + heading * 7 + (distance - 1)


def decode_move(idx: int, board: "chess.Board | None" = None) -> chess.Move:
    """
    Decode an action index produced by ``encode_move`` back into a chess.Move.

    Args:
        idx: action index, ``from_square * ACTION_PLANES + plane``.
        board: optional position the move is played in. ``encode_move`` stores
            queen promotions as plain one-step queen-like moves, so the
            promotion piece cannot be recovered from ``idx`` alone. When
            ``board`` is given and the moving piece is a pawn arriving on the
            first or last rank, the returned move carries
            ``promotion=chess.QUEEN``. Without ``board`` the behaviour is the
            historical one: queen-like planes never carry a promotion.

    Returns:
        The decoded move. No bounds check is performed on the destination, so
        only indices that correspond to an on-board move give a meaningful
        result.
    """
    fr, plane = divmod(idx, ACTION_PLANES)

    rank_fr = chess.square_rank(fr)
    file_fr = chess.square_file(fr)

    # 1) Queen-like moves (0..55)
    if plane < 56:
        d, step_minus_one = divmod(plane, 7)
        step = step_minus_one + 1
        dr, dc = direction_table[d]
        rank_to = rank_fr + dr * step
        file_to = file_fr + dc * step
        to = chess.square(file_to, rank_to)

        # Restore the queen promotion that encode_move folded into this plane.
        promotion = None
        if (board is not None
                and rank_to in (0, 7)
                and board.piece_type_at(fr) == chess.PAWN):
            promotion = chess.QUEEN
        return chess.Move(fr, to, promotion=promotion)

    # 2) Knight jumps (56..63)
    if plane < 56 + 8:
        dr, dc = knight_moves[plane - 56]
        to = chess.square(file_fr + dc, rank_fr + dr)
        return chess.Move(fr, to)

    # 3) Under-promotions (64..72)
    pidx, t = divmod(plane - (56 + 8), 3)
    promo = promo_order[pidx]

    # t: 0 = same file, 1 = file decreases, 2 = file increases
    dc = 0 if t == 0 else (-1 if t == 1 else +1)
    # A pawn starting on rank index 6 is a White pawn stepping up to the last
    # rank; any other origin rank is treated as a Black pawn stepping down.
    dr = +1 if rank_fr == 6 else -1
    to = chess.square(file_fr + dc, rank_fr + dr)
    return chess.Move(fr, to, promotion=promo)

In [6]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import IterableDataset, DataLoader

In [7]:
class LichessIterable(IterableDataset):
  """
  Stream supervised training samples from a PGN file, one per mainline move.

  Each sample is a tuple ``(planes, move_idx, z, mask)``:
    planes   : tensor from ``board_to_plane`` for the position before the move
    move_idx : tensor holding ``encode_move`` of the move actually played
    z        : float32 scalar from the game's Result header
               (1 for "1-0", -1 for "0-1", 0 otherwise)
    mask     : bool tensor of length 64 * ACTION_PLANES, True at the index of
               every legal move in the position

  When iterated inside DataLoader worker processes, games are split
  round-robin between workers (game i goes to worker ``i % num_workers``) so
  that no sample is produced more than once. Without this, every worker would
  replay the whole file and each sample would appear ``num_workers`` times.

  ``max_samples`` caps the number of samples produced by each iterating
  process (i.e. per worker); ``None`` means no cap.
  """

  def __init__(self, pgn_path, max_samples = None):
    self.pgn_path = pgn_path
    self.max_samples = max_samples

  def __iter__(self):
    # Returns None in the main process (num_workers=0).
    worker = torch.utils.data.get_worker_info()
    worker_id   = 0 if worker is None else worker.id
    num_workers = 1 if worker is None else worker.num_workers

    result_to_value = {"1-0": 1, "0-1": -1, "1/2-1/2": 0}
    yielded = 0
    game_index = -1

    with open(self.pgn_path) as f:
      while self.max_samples is None or yielded < self.max_samples:
        game_index += 1

        # Games owned by another worker are skipped without being parsed.
        if game_index % num_workers != worker_id:
          if not chess.pgn.skip_game(f):
            break  # end of file
          continue

        game = chess.pgn.read_game(f)
        if game is None:
          break

        # Replay the game, emitting one sample per position/move pair.
        board = game.board()
        res   = result_to_value.get(game.headers.get("Result", "*"), 0)
        for move in game.mainline_moves():
          yielded += 1

          planes   = torch.tensor(board_to_plane(board))
          move_idx = torch.tensor(encode_move(move))
          z        = torch.tensor(res, dtype=torch.float32)

          mask = torch.zeros(64 * ACTION_PLANES, dtype=torch.bool)
          for m in board.legal_moves:
            mask[encode_move(m)] = True

          yield planes, move_idx, z, mask

          board.push(move)

          if self.max_samples is not None and yielded >= self.max_samples:
            break

In [8]:
import os
# os.cpu_count() may return None; treat that as a single core.
cpu_cores = os.cpu_count() or 1
# Keep one core free for the main process, but always use at least one worker.
num_workers = cpu_cores - 1 if cpu_cores > 2 else 1

In [9]:
# Streaming dataset over one month of rated Lichess games stored on Drive.
dataset = LichessIterable(
    "/content/drive/MyDrive/ChessEngine/lichess_db_standard_rated_2013-01.pgn"
)
batch_size = 2048

# Worker processes stay alive between epochs (persistent_workers) and each one
# keeps two batches queued ahead of the training loop (prefetch_factor).
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)

In [10]:
class Block(nn.Module):
    """
    Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to the input
    and passed through a final ReLU. Channel count and spatial size are
    preserved.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Both convolutions are bias-free because a BatchNorm follows each one.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2   = nn.BatchNorm2d(channels)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        # Skip connection (in place on the BN output, the input is untouched).
        hidden += residual
        return self.relu(hidden)

In [11]:
class MiniChessNet(nn.Module):
    """
    Small residual policy/value network.

    Input : (N, 18, 8, 8) feature planes.
    Output: a pair ``(policy_logits, value)`` where ``policy_logits`` has
            4672 entries per sample and ``value`` is a single tanh-squashed
            scalar per sample.
    """

    def __init__(self):
        super().__init__()
        # Stem lifts the 18 input planes to 64 channels.
        stem = [
            nn.Conv2d(18, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # Three residual blocks at constant width.
        tower = [Block(64) for _ in range(3)]
        self.backbone = nn.Sequential(*stem, *tower)

        # Both heads read the flattened 64 x 8 x 8 feature map.
        flat_features = 64 * 8 * 8
        self.policy = nn.Linear(flat_features, 4672)
        self.value  = nn.Linear(flat_features, 1)

    def forward(self, x):
        features = torch.flatten(self.backbone(x), 1)
        policy_logits = self.policy(features)
        value = torch.tanh(self.value(features))
        return policy_logits, value

In [12]:
from torch.amp import autocast, GradScaler
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

net = MiniChessNet().to(device)

# Policy head: cross-entropy against the played move.
# Value head: mean squared error against the game result.
ce_loss  = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

# AdamW with a cosine learning-rate decay spread over 200k scheduler steps.
optimizer = optim.AdamW(
    net.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=200_000,
)

# Loss scaler used together with autocast in the training step.
scaler = GradScaler()

In [13]:
def train_step(batch):
    """
    Run one optimisation step on a batch and report its metrics.

    Args:
        batch: tuple ``(planes, move_id, z, mask)`` as collated by the
            DataLoader — input planes, target move indices, game-result
            targets, and a boolean legal-move mask over the policy logits.

    Returns:
        ``(loss, top1)`` as Python floats: the combined loss
        (cross-entropy + 0.25 * MSE) and the fraction of samples whose
        highest-scoring legal move equals the target move.

    Side effects: updates ``net`` through ``optimizer``, then advances
    ``scaler`` and ``scheduler`` by one step.
    """
    planes, move_id, z, mask = batch
    planes  = planes.to(device, non_blocking=True)
    move_id = move_id.to(device, non_blocking=True)
    z       = z.unsqueeze(1).to(device, non_blocking=True)  # match the value head's trailing dim
    mask    = mask.to(device, non_blocking=True)

    # Follow the device actually in use; mixed precision is enabled only on
    # CUDA, so a CPU run executes in full precision instead of requesting
    # CUDA autocast on a machine that has no GPU.
    with autocast(device_type=device.type, enabled=(device.type == "cuda")):
        logits, v = net(planes)
        # Push illegal moves far down; -1e4 is finite and representable in fp16.
        logits = logits.masked_fill(~mask, -1e4)

        loss_p = ce_loss(logits, move_id)
        loss_v = mse_loss(v, z)
        loss   = loss_p + 0.25 * loss_v # loss = Cross-Entropy + Lambda * MSE

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()

    # Metric only: no autograd bookkeeping needed.
    with torch.no_grad():
        top1 = (logits.argmax(1) == move_id).float().mean().item()

    return loss.item(), top1

In [14]:
from tqdm.notebook import tqdm

num_epochs = 5

# Optional: resume from a saved checkpoint by uncommenting the two lines below.
# ckpt = torch.load("/content/drive/MyDrive/ChessEngine/policy_epoch1.pth", map_location=device)
# net.load_state_dict(ckpt)

for epoch in range(1, num_epochs+1):
    net.train()
    running_loss = 0.0
    running_acc  = 0.0
    count        = 0

    for batch in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
        loss, top1 = train_step(batch)
        bs = batch[0].shape[0]

        # Weight by batch size so a short final batch does not skew the averages.
        running_loss += loss * bs
        running_acc  += top1  * bs
        count        += bs

    # An empty loader (missing or empty PGN) would otherwise surface as an
    # opaque ZeroDivisionError in the averages below.
    if count == 0:
        raise RuntimeError(
            f"Epoch {epoch}: the data loader produced no samples; "
            "check the PGN path and contents."
        )

    avg_loss = running_loss / count
    avg_acc  = running_acc  / count

    print(f">>> Epoch {epoch:02d} — Loss: {avg_loss:.4f}, Top1 Acc: {avg_acc:.4f}")
    torch.save(net.state_dict(), f"policy_epoch{epoch}.pth")

Epoch 1: 0batch [00:00, ?batch/s]

>>> Epoch 01 — Loss: 2.2185, Top1 Acc: 0.4130


Epoch 2: 0batch [00:00, ?batch/s]

>>> Epoch 02 — Loss: 2.0781, Top1 Acc: 0.4325


Epoch 3: 0batch [00:00, ?batch/s]

>>> Epoch 03 — Loss: 1.9809, Top1 Acc: 0.4557


Epoch 4: 0batch [00:00, ?batch/s]

>>> Epoch 04 — Loss: 1.9244, Top1 Acc: 0.4697


Epoch 5: 0batch [00:00, ?batch/s]

>>> Epoch 05 — Loss: 1.8999, Top1 Acc: 0.4758
