In [None]:
%%capture
!pip install chess numpy matplotlib

In [None]:
import chess
from chess import pgn
import numpy as np
import pandas as pd
import csv

csv_file = 'lichess_elite_2016-03.pgn.csv'

In [None]:
def fen_to_input(fen: str) -> np.ndarray:
    board = chess.Board(fen)

    # Initialize 19 feature planes
    planes = []

    # Turn (1 plane)
    turn = np.full((8, 8), board.turn, dtype=np.float32)  # True for White, False for Black
    planes.append(turn)

    # Castling rights (4 planes)
    castling = [
        np.full((8, 8), board.has_queenside_castling_rights(chess.WHITE), dtype=np.float32),
        np.full((8, 8), board.has_kingside_castling_rights(chess.WHITE), dtype=np.float32),
        np.full((8, 8), board.has_queenside_castling_rights(chess.BLACK), dtype=np.float32),
        np.full((8, 8), board.has_kingside_castling_rights(chess.BLACK), dtype=np.float32)
    ]
    planes.extend(castling)

    # Fifty-move counter (1 plane)
    counter = np.ones((8, 8), dtype=np.float32) if board.can_claim_fifty_moves() else np.zeros((8, 8), dtype=np.float32)
    planes.append(counter)

    # Piece positions (12 planes: 6 piece types × 2 colors)
    for color in chess.COLORS:
        for piece in chess.PIECE_TYPES:
            array = np.zeros((8, 8), dtype=np.float32)
            for index in board.pieces(piece, color):
                array[chess.square_rank(index), chess.square_file(index)] = 1
            planes.append(array)

    # En passant (1 plane)
    en_passant = np.zeros((8, 8), dtype=np.float32)
    if board.has_legal_en_passant() and board.ep_square is not None:
        en_passant[chess.square_rank(board.ep_square), chess.square_file(board.ep_square)] = 1
    planes.append(en_passant)

    # Stack planes and reshape to (19, 8, 8)
    r = np.stack(planes, axis=0)  # Shape: (19, 8, 8)
    del board
    return r

In [None]:
from enum import Enum
import chess
from chess import PieceType
import numpy as np

class QueenDirection(Enum):
  NORTHWEST = 0
  NORTH = 1
  NORTHEAST = 2
  EAST = 3
  SOUTHEAST = 4
  SOUTH = 5
  SOUTHWEST = 6
  WEST = 7

class KnightMove(Enum):
  NORTH_LEFT = 0  # diff == -15
  NORTH_RIGHT = 1  # diff == -17
  EAST_UP = 2  # diff == -6
  EAST_DOWN = 3  # diff == 10
  SOUTH_RIGHT = 4  # diff == 15
  SOUTH_LEFT = 5  # diff == 17
  WEST_DOWN = 6  # diff == 6
  WEST_UP = 7  # diff == -10

class UnderPromotion(Enum):
  KNIGHT = 0
  BISHOP = 1
  ROOK = 2

class Mapping:
  """
  The mapper is a dictionary of moves.

  ```
  * the index is the type of move
  * the value is the plane's index, or an array of plane indices (for distance)
  """
  # knight moves from north_left to west_up (clockwise)
  knight_mappings = [-15, -17, -6, 10, 15, 17, 6, -10]

  def get_index(self, piece_type: PieceType, direction: Enum, distance: int = 1) -> int:
      if piece_type == chess.KNIGHT:
          return 56 + KnightMove(direction).value
      else:
          return QueenDirection(direction).value * 8 + distance

  @staticmethod
  def get_underpromotion_move(piece_type: PieceType, from_square: int, to_square: int) -> tuple[UnderPromotion, int]:
      piece = UnderPromotion(piece_type - 2)
      diff = from_square - to_square
      direction = 0
      if to_square < 8:
          # black promotes (1st rank)
          direction = diff - 8
      elif to_square > 55:
          # white promotes (8th rank)
          direction = diff + 8
      return (piece, direction)

  @staticmethod
  def get_knight_move(from_square: int, to_square: int) -> KnightMove:
      return KnightMove(Mapping.knight_mappings.index(from_square - to_square))

  @staticmethod
  def get_queenlike_move(from_square: int, to_square: int) -> tuple[QueenDirection, int]:
      diff = from_square - to_square
      if diff % 8 == 0:
          # north and south
          if diff > 0:
              direction = QueenDirection.SOUTH
          else:
              direction = QueenDirection.NORTH
          distance = int(diff / 8)
      elif diff % 9 == 0:
          # southwest and northeast
          if diff > 0:
              direction = QueenDirection.SOUTHWEST
          else:
              direction = QueenDirection.NORTHEAST
          distance = np.abs(int(diff / 8))
      elif from_square // 8 == to_square // 8:
          # east and west
          if diff > 0:
              direction = QueenDirection.WEST
          else:
              direction = QueenDirection.EAST
          distance = np.abs(diff)
      elif diff % 7 == 0:
          if diff > 0:
              direction = QueenDirection.SOUTHEAST
          else:
              direction = QueenDirection.NORTHWEST
          distance = np.abs(int(diff / 8)) + 1
      else:
          raise Exception("Invalid queen-like move")
      return (direction, distance)

  mapper = {
      # queens
      QueenDirection.NORTHWEST: [0, 1, 2, 3, 4, 5, 6],
      QueenDirection.NORTH: [7, 8, 9, 10, 11, 12, 13],
      QueenDirection.NORTHEAST: [14, 15, 16, 17, 18, 19, 20],
      QueenDirection.EAST: [21, 22, 23, 24, 25, 26, 27],
      QueenDirection.SOUTHEAST: [28, 29, 30, 31, 32, 33, 34],
      QueenDirection.SOUTH: [35, 36, 37, 38, 39, 40, 41],
      QueenDirection.SOUTHWEST: [42, 43, 44, 45, 46, 47, 48],
      QueenDirection.WEST: [49, 50, 51, 52, 53, 54, 55],
      # knights
      KnightMove.NORTH_LEFT: 56,
      KnightMove.NORTH_RIGHT: 57,
      KnightMove.EAST_UP: 58,
      KnightMove.EAST_DOWN: 59,
      KnightMove.SOUTH_RIGHT: 60,
      KnightMove.SOUTH_LEFT: 61,
      KnightMove.WEST_DOWN: 62,
      KnightMove.WEST_UP: 63,
      # underpromotions
      UnderPromotion.KNIGHT: [64, 65, 66],
      UnderPromotion.BISHOP: [67, 68, 69],
      UnderPromotion.ROOK: [70, 71, 72]
  }

mapper = Mapping()

In [None]:
def generate_policy_vector(fen: str, best_move: chess.Move) -> np.ndarray:
    policy = np.zeros((73, 8, 8), dtype=np.float32)
    board = chess.Board(fen)
    legal_moves = list(board.legal_moves)

    best_entry = map_to_policy_index(best_move, board)
    policy[best_entry[0], best_entry[1], best_entry[2]] = 0.9

    other_weight = 0.1 / (len(legal_moves) - 1) if len(legal_moves) > 1 else 0
    for move in legal_moves:
        if move == best_move:
            continue
        entry = map_to_policy_index(move, board)
        policy[entry[0], entry[1], entry[2]] = other_weight

    return policy


def map_to_policy_index(move: chess.Move, board: chess.Board) -> tuple[int, int, int]:
    from_square = move.from_square
    to_square = move.to_square
    piece = board.piece_at(from_square)

    if move.promotion and move.promotion != chess.QUEEN:
        promo_type, direction = mapper.get_underpromotion_move(move.promotion, from_square, to_square)
        plane = mapper.mapper[promo_type][1 - direction]
    elif piece.piece_type == chess.KNIGHT:
        direction = mapper.get_knight_move(from_square, to_square)
        plane = mapper.mapper[direction]
    else:
        direction, distance = mapper.get_queenlike_move(from_square, to_square)
        plane = mapper.mapper[direction][abs(distance) - 1]

    row = from_square % 8
    col = 7 - (from_square // 8)
    return (plane, row, col)


In [None]:
csv_file = 'lichess_elite_2016-03.pgn.csv'
H5PATH = 'lichess_elite_2016-03.pgn.csv.h5'
CHUNK_SIZE = 500_000
chunks = pd.read_csv(csv_file, chunksize=CHUNK_SIZE)

In [None]:
import gc
gc.collect()  # Force garbage collection to release memory
del df

NameError: name 'df' is not defined

In [None]:
df = next(chunks)
input = fen_to_input(df.iloc[0].fen)
output_policy = generate_policy_vector(df.iloc[0].fen, chess.Move.from_uci(df.iloc[0].move))
print(f'input shape: {input.shape}')
print(f'policy output shape: {output_policy.shape}')
print(f'value output shape: {df.iloc[0].value.shape} = {df.iloc[0].value}')

input shape: (19, 8, 8)
policy output shape: (73, 8, 8)
value output shape: () = -1.0


In [None]:
import h5py
import numpy as np
from tqdm import tqdm
import gc

def save_to_hdf5_incrementally(
    inputs_batch,
    policies_batch,
    values_batch,
    h5_file,
    compression="gzip",
    compression_opts=4,
):
    # Create datasets only once with compression
    if "inputs" not in h5_file:
        h5_file.create_dataset(
            "inputs",
            data=inputs_batch,
            maxshape=(None, *inputs_batch.shape[1:]),
            chunks=(32, *inputs_batch.shape[1:]),
            compression=compression,
            compression_opts=compression_opts,
        )
    else:
        h5_file["inputs"].resize(
            h5_file["inputs"].shape[0] + inputs_batch.shape[0], axis=0
        )
        h5_file["inputs"][-inputs_batch.shape[0] :] = inputs_batch

    if "policies" not in h5_file:
        h5_file.create_dataset(
            "policies",
            data=policies_batch,
            maxshape=(None, *policies_batch.shape[1:]),
            chunks=(32, *policies_batch.shape[1:]),
            compression=compression,
            compression_opts=compression_opts,
        )
    else:
        h5_file["policies"].resize(
            h5_file["policies"].shape[0] + policies_batch.shape[0], axis=0
        )
        h5_file["policies"][-policies_batch.shape[0] :] = policies_batch

    if "values" not in h5_file:
        h5_file.create_dataset(
            "values",
            data=values_batch,
            maxshape=(None, values_batch.shape[1]),
            chunks=(32, values_batch.shape[1]),
            compression=compression,
            compression_opts=compression_opts,
        )
    else:
        h5_file["values"].resize(
            h5_file["values"].shape[0] + values_batch.shape[0], axis=0
        )
        h5_file["values"][-values_batch.shape[0] :] = values_batch


# Set batch size to control memory consumption
BATCH_SIZE = 50000

# Preallocate NumPy arrays for the batch
inputs_batch = np.zeros((BATCH_SIZE, 19, 8, 8), dtype=np.float32)
policies_batch = np.zeros((BATCH_SIZE, 73, 8, 8), dtype=np.float32)
values_batch = np.zeros((BATCH_SIZE, 1), dtype=np.float32)

# Open HDF5 file once
with h5py.File(H5PATH, "a") as h5_file:
    batch_idx = 0
    # Use itertuples for faster iteration
    for df in chunks:
      for row in tqdm(df.itertuples(), total=len(df)):
          try:
              board_input = fen_to_input(row.fen)  # shape: (8, 8, 19)
              move = chess.Move.from_uci(row.move)
              policy_output = generate_policy_vector(row.fen, move)  # shape: (73, 8, 8)
              value_output = float(row.value)  # scalar

              # Store in preallocated arrays
              inputs_batch[batch_idx] = board_input
              policies_batch[batch_idx] = policy_output
              values_batch[batch_idx, 0] = value_output
              batch_idx += 1

              # When batch size is reached, save and reset
              if batch_idx >= BATCH_SIZE:
                  save_to_hdf5_incrementally(
                      inputs_batch[:batch_idx],
                      policies_batch[:batch_idx],
                      values_batch[:batch_idx],
                      h5_file,
                  )
                  batch_idx = 0
                  gc.collect()  # Force garbage collection to release memory

          except Exception as e:
              print(f"Skipping row {row.Index} due to error: {e}")
              continue

    # Save remaining data if any
    if batch_idx > 0:
        save_to_hdf5_incrementally(
            inputs_batch[:batch_idx],
            policies_batch[:batch_idx],
            values_batch[:batch_idx],
            h5_file,
        )
        gc.collect()  # Final garbage collection

100%|██████████| 500000/500000 [06:13<00:00, 1337.59it/s]
100%|██████████| 500000/500000 [06:19<00:00, 1317.85it/s]
100%|██████████| 500000/500000 [06:10<00:00, 1350.83it/s]
100%|██████████| 500000/500000 [06:16<00:00, 1326.54it/s]
100%|██████████| 500000/500000 [06:14<00:00, 1336.49it/s]
100%|██████████| 411902/411902 [05:05<00:00, 1349.98it/s]


In [None]:
with h5py.File('chess_dataset-1M.h5', 'r') as f:
    print(f['inputs'].shape)
    print(f['policies'].shape)
    print(f['values'].shape)

(1010000, 19, 8, 8)
(1010000, 73, 8, 8)
(1010000, 1)


In [None]:
from google.colab import files
files.download('chess_dataset-1M.h5')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>