In [10]:
import chess
import chess.pgn
import numpy as np
import os
import multiprocessing
import gc
import mmap
from tqdm.notebook import tqdm

# =================================================================
# 1. LOGIC: CHESS ENCODER (Robust & Normal Approach)
# =================================================================
class ChessEncoder:
    """Turn a ``chess.Board`` position into an ``(8, 8, 41)`` float32 tensor.

    The first two axes are (rank, file) as produced by ``divmod(square, 8)``;
    the last axis holds the feature planes:

    ======  ==============================================================
    Plane   Content
    ======  ==============================================================
    0-11    Piece placement (white P,N,B,R,Q,K then black P,N,B,R,Q,K)
    12      Side to move (all ones when white is to move)
    13-16   Castling rights: white K-side, white Q-side, black K-side,
            black Q-side
    17      En-passant target square
    18      Halfmove clock / 100
    19      ``move_count`` / 200, capped at 1
    20      Side to move is in check
    21-22   Squares attacked by white / by black
    23-24   Position has occurred at least 2 / at least 3 times
    25-26   Origin / destination square of the last move
    27-28   Never written (always zero)
    29-30   White / black material, divided by 40 and capped at 1
    31      Last move was a capture
    32      Never written (always zero)
    33      Last move was a promotion
    34-35   White / black passed pawns
    36      Legal-move count of the side to move / 50, capped at 1
    37      Always zero (opponent mobility is not computed)
    38-39   Attacked squares adjacent to the white / black king
    40      Squares attacked by both sides
    ======  ==============================================================
    """

    # Plane offset of each piece type inside a colour's block of six planes.
    PIECE_MAP = {chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2, chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5}

    # Conventional material values. Kept separate from PIECE_MAP: plane
    # indices are not piece values (a pawn would be worth 0 and a king 5).
    PIECE_VALUES = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3, chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0}

    @staticmethod
    def get_passed_pawns(board, color):
        """Return an 8x8 mask with 1.0 on every passed pawn of ``color``.

        A pawn counts as passed when no enemy pawn stands ahead of it (from
        ``color``'s point of view) on its own file or on an adjacent file.
        """
        mask = np.zeros((8, 8), dtype=np.float32)
        pawns = board.pieces(chess.PAWN, color)
        opp_pawns = board.pieces(chess.PAWN, not color)
        for sq in pawns:
            r, f = divmod(sq, 8)
            # Own file plus neighbours, clipped to the board.
            files = {f, f - 1, f + 1} & {0, 1, 2, 3, 4, 5, 6, 7}
            # Ranks the pawn still has to cross: upwards for white, downwards for black.
            ranks = range(r + 1, 8) if color == chess.WHITE else range(0, r)
            blocked = any(
                (rank * 8 + check_file) in opp_pawns
                for check_file in files
                for rank in ranks
            )
            if not blocked:
                mask[r, f] = 1.0
        return mask

    @staticmethod
    def get_attack_map(board, color):
        """Return an 8x8 mask with 1.0 on every square attacked by ``color``."""
        mask = np.zeros((8, 8), dtype=np.float32)
        for sq in range(64):
            if board.is_attacked_by(color, sq):
                r, f = divmod(sq, 8)
                mask[r, f] = 1.0
        return mask

    @staticmethod
    def encode_board(board, move_count):
        """Encode ``board`` into the 41-plane tensor described on the class.

        Args:
            board: Position to encode. Its move stack is read to derive the
                last-move planes; the last move is briefly popped and pushed
                back, so the board is unchanged on return.
            move_count: Counter written (scaled by 1/200, capped at 1) into
                plane 19.

        Returns:
            ``np.ndarray`` of shape ``(8, 8, 41)`` and dtype ``float32``.
        """
        tensor = np.zeros((8, 8, 41), dtype=np.float32)
        pieces = board.piece_map()

        # 0-11: Pieces
        for sq, piece in pieces.items():
            r, f = divmod(sq, 8)
            layer = ChessEncoder.PIECE_MAP[piece.piece_type] + (0 if piece.color == chess.WHITE else 6)
            tensor[r, f, layer] = 1

        # 12: Turn
        if board.turn == chess.WHITE: tensor[:, :, 12] = 1

        # 13-16: Castling
        if board.has_kingside_castling_rights(chess.WHITE): tensor[:, :, 13] = 1
        if board.has_queenside_castling_rights(chess.WHITE): tensor[:, :, 14] = 1
        if board.has_kingside_castling_rights(chess.BLACK): tensor[:, :, 15] = 1
        if board.has_queenside_castling_rights(chess.BLACK): tensor[:, :, 16] = 1

        # 17: En Passant
        if board.ep_square is not None:
            r, f = divmod(board.ep_square, 8)
            tensor[r, f, 17] = 1

        # 18-19: Rule 50 & Move Count
        tensor[:, :, 18] = board.halfmove_clock / 100.0
        tensor[:, :, 19] = min(move_count / 200.0, 1.0)

        # 20: Check
        if board.is_check(): tensor[:, :, 20] = 1

        # 21-22: Attacks
        w_att = ChessEncoder.get_attack_map(board, chess.WHITE)
        b_att = ChessEncoder.get_attack_map(board, chess.BLACK)
        tensor[:, :, 21] = w_att
        tensor[:, :, 22] = b_att

        # 23-24: Repetition
        if board.is_repetition(2): tensor[:, :, 23] = 1
        if board.is_repetition(3): tensor[:, :, 24] = 1

        # 25-26, 31, 33: Last Move
        if board.move_stack:
            # is_capture() must be asked on the position *before* the move.
            # On the current board the destination square already holds the
            # mover's own piece, which makes the answer meaningless. Undo the
            # move, ask, and restore it.
            m = board.pop()
            try:
                was_capture = board.is_capture(m)
            finally:
                board.push(m)

            r1, f1 = divmod(m.from_square, 8)
            r2, f2 = divmod(m.to_square, 8)
            tensor[r1, f1, 25] = 1
            tensor[r2, f2, 26] = 1
            if was_capture: tensor[:, :, 31] = 1
            if m.promotion: tensor[:, :, 33] = 1

        # 29-30: Material (conventional piece values, normalised by 40)
        w_mat = sum(ChessEncoder.PIECE_VALUES[p.piece_type] for p in pieces.values() if p.color == chess.WHITE)
        b_mat = sum(ChessEncoder.PIECE_VALUES[p.piece_type] for p in pieces.values() if p.color == chess.BLACK)
        # Promotions can push material past the starting 39 points; cap like the other scalar planes.
        tensor[:, :, 29] = min(w_mat / 40.0, 1.0)
        tensor[:, :, 30] = min(b_mat / 40.0, 1.0)

        # 34-35: Passed Pawns
        tensor[:, :, 34] = ChessEncoder.get_passed_pawns(board, chess.WHITE)
        tensor[:, :, 35] = ChessEncoder.get_passed_pawns(board, chess.BLACK)

        # 36-37: Mobility (Approximate)
        cur_mob = board.legal_moves.count() / 50.0
        tensor[:, :, 36] = min(cur_mob, 1.0)
        tensor[:, :, 37] = 0 # Expensive to calculate opp mobility accurately

        # 38-39: King Safety (Simple ring check)
        for color, l in [(chess.WHITE, 38), (chess.BLACK, 39)]:
            king = board.king(color)
            # Compare against None explicitly: a king on a1 is square 0, which is falsy.
            if king is not None:
                for sq in chess.SQUARES:
                    if chess.square_distance(king, sq) == 1 and board.is_attacked_by(not color, sq):
                        r, f = divmod(sq, 8)
                        tensor[r, f, l] = 1

        # 40: Tension
        tensor[:, :, 40] = np.clip(w_att * b_att, 0, 1)

        return tensor

# =================================================================
# 2. WORKER (CPU Task)
# =================================================================
def worker_task(args):
    """Encode sampled positions from a chunk of games in one PGN file.

    Designed to run inside a ``multiprocessing`` pool, so it takes a single
    tuple argument and never raises for ordinary data problems: unreadable
    games are skipped and reported instead.

    Args:
        args: ``(pgn_path, offsets, start_idx)`` where ``offsets`` are the
            file positions at which games start and ``start_idx`` is the
            chunk number (used only in diagnostic messages).

    Returns:
        ``(X, y)``: a list of encoded position tensors and a parallel list of
        one-hot float32 result labels (index 0 = ``1-0``, 1 = ``0-1``,
        2 = ``1/2-1/2``).
    """
    pgn_path, offsets, start_idx = args
    X, y = [], []
    result_mapper = {'1-0': 0, '0-1': 1, '1/2-1/2': 2}
    skipped = 0

    try:
        with open(pgn_path, "r", encoding="utf-8", errors="replace") as f:
            for offset in offsets:
                try:
                    f.seek(offset)
                    game = chess.pgn.read_game(f)
                    if game is None: continue

                    # Games without a decisive/drawn result carry no label.
                    res = game.headers.get("Result")
                    if res not in result_mapper: continue
                    label_idx = result_mapper[res]

                    board = game.board()
                    for i, move in enumerate(game.mainline_moves()):
                        board.push(move)
                        # Counts half-moves (plies): one per pushed move.
                        move_count = i + 1

                        # FILTER: skip the first 20 plies, then sample every 5th ply
                        if move_count > 20 and move_count % 5 == 0:
                            tensor = ChessEncoder.encode_board(board, move_count)
                            X.append(tensor)

                            lbl = np.zeros(3, dtype=np.float32)
                            lbl[label_idx] = 1.0
                            y.append(lbl)
                except Exception:
                    # Best effort: one broken game must not abort the chunk.
                    # `Exception` (not a bare except) lets KeyboardInterrupt
                    # and SystemExit stop the worker.
                    skipped += 1
                    continue
    except Exception as exc:
        # Typically the file could not be opened; keep whatever was gathered.
        print(f"[!] Worker chunk {start_idx}: could not read {pgn_path}: {exc!r}")

    if skipped:
        print(f"[!] Worker chunk {start_idx}: skipped {skipped} unreadable game(s).")
    return X, y

# =================================================================
# 3. MANAGER (Processing Logic)
# =================================================================
def _write_shard(output_dir, shard_idx, samples_X, samples_y):
    """Save one compressed ``shard_NNN.npz`` file and return its sample count."""
    shard_name = os.path.join(output_dir, f"shard_{shard_idx:03d}.npz")
    arr_X = np.array(samples_X, dtype=np.float32)
    arr_y = np.array(samples_y, dtype=np.float32)
    np.savez_compressed(shard_name, X=arr_X, y=arr_y)
    return len(samples_X)


def process_pgn_to_shards(input_pgn, output_dir, shard_size=30000, num_workers=2, chunk_size=1500):
    """Convert one PGN file into ``.npz`` shards of encoded positions.

    The file is indexed for game start offsets, the offsets are split into
    chunks handed to a process pool running ``worker_task``, and the returned
    samples are written to ``output_dir`` as ``shard_000.npz``,
    ``shard_001.npz``, ... each holding arrays ``X`` and ``y``.

    Args:
        input_pgn: Path of the PGN file to read.
        output_dir: Directory receiving the shards (created if missing).
        shard_size: Number of samples per shard; only the final shard may be
            smaller.
        num_workers: Size of the process pool (default 2).
        chunk_size: Number of games handed to a worker per task.

    Returns:
        None. Progress and totals are printed.

    Raises:
        ValueError: If ``shard_size``, ``num_workers`` or ``chunk_size`` is
            smaller than 1.
    """
    if shard_size < 1 or num_workers < 1 or chunk_size < 1:
        raise ValueError("shard_size, num_workers and chunk_size must all be >= 1")

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n[file] Processing: {os.path.basename(input_pgn)}")
    offsets = []

    # Fast Indexing
    # The trailing space matters: a bare b"[Event" would also match
    # "[EventDate", "[EventType", ... and register the same game twice.
    game_start = b"[Event "
    try:
        with open(input_pgn, "rb") as f:
            # mmap cannot map a zero-length file; an empty PGN simply has no games.
            if os.fstat(f.fileno()).st_size > 0:
                with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm:
                    pos = mm.find(game_start)
                    while pos != -1:
                        offsets.append(pos)
                        pos = mm.find(game_start, pos + 1)
    except FileNotFoundError:
        print(f"[!] Error: File not found at {input_pgn}")
        return

    total_games = len(offsets)
    print(f"[*] Found {total_games} games.")

    if total_games == 0:
        # Nothing to encode; avoid spinning up a pool for no work.
        print(f"[*] Completed {os.path.basename(input_pgn)} -> 0 samples.")
        return

    offset_chunks = [offsets[i:i+chunk_size] for i in range(0, total_games, chunk_size)]
    worker_args = [(input_pgn, chunk, i) for i, chunk in enumerate(offset_chunks)]

    buffer_X, buffer_y = [], []
    shard_count = 0
    total_samples = 0

    with multiprocessing.Pool(processes=num_workers) as pool:
        for res_X, res_y in tqdm(pool.imap_unordered(worker_task, worker_args), total=len(worker_args), desc="Progress"):
            buffer_X.extend(res_X)
            buffer_y.extend(res_y)

            # Flush full shards; a single worker result can fill more than one.
            while len(buffer_X) >= shard_size:
                total_samples += _write_shard(
                    output_dir, shard_count, buffer_X[:shard_size], buffer_y[:shard_size]
                )
                shard_count += 1
                # Keep only the overflow for the next shard.
                del buffer_X[:shard_size]
                del buffer_y[:shard_size]
                gc.collect()

    # Save leftovers
    if buffer_X:
        total_samples += _write_shard(output_dir, shard_count, buffer_X, buffer_y)

    print(f"[*] Completed {os.path.basename(input_pgn)} -> {total_samples} samples.")

# =================================================================
# 4. EXECUTION (FIXED DIRECTORY FOR COLAB)
# =================================================================
if __name__ == "__main__":

    # --- Locations ---------------------------------------------------
    # Files uploaded directly to Colab live under "/content/".
    BASE_DIR = "/content"

    # Work list: each entry pairs a PGN file name with the folder
    # (relative to BASE_DIR) that receives its shards.
    files_to_process = [
        ('train.pgn', 'processed_train'),
        ('validation.pgn', 'processed_val'),
        ('test.pgn', 'processed_test'),
    ]

    print(f"[*] Starting batch processing in {BASE_DIR}...")

    # --- Batch run ---------------------------------------------------
    for pgn_name, shard_folder in files_to_process:
        source_pgn = os.path.join(BASE_DIR, pgn_name)

        # Missing inputs are reported and skipped rather than treated as fatal.
        if not os.path.exists(source_pgn):
            print(f"[!] Warning: {pgn_name} not found in Colab files. Skipping.")
            continue

        shard_dir = os.path.join(BASE_DIR, shard_folder)
        process_pgn_to_shards(source_pgn, shard_dir, shard_size=40000)

    print("\n[SUCCESS] All files processed.")

[*] Starting batch processing in /content...

[file] Processing: train.pgn
[*] Found 240000 games.


Progress:   0%|          | 0/160 [00:00<?, ?it/s]

[*] Completed train.pgn -> 2798692 samples.

[file] Processing: validation.pgn
[*] Found 30000 games.


Progress:   0%|          | 0/20 [00:00<?, ?it/s]

[*] Completed validation.pgn -> 350834 samples.

[file] Processing: test.pgn
[*] Found 30000 games.


Progress:   0%|          | 0/20 [00:00<?, ?it/s]

[*] Completed test.pgn -> 350534 samples.

[SUCCESS] All files processed.


In [13]:
# Bundle each processed shard folder into one zip archive under /content.
# Because absolute paths are passed, entries are stored with a "content/..."
# prefix (see the "adding:" lines in the output). The .npz shards are already
# compressed, so zip only deflates them by roughly 9-10%.
!zip -r /content/train.zip /content/processed_train
!zip -r /content/valid.zip /content/processed_val
!zip -r /content/test.zip /content/processed_test

  adding: content/processed_train/ (stored 0%)
  adding: content/processed_train/shard_032.npz (deflated 9%)
  adding: content/processed_train/shard_009.npz (deflated 9%)
  adding: content/processed_train/shard_036.npz (deflated 10%)
  adding: content/processed_train/shard_011.npz (deflated 10%)
  adding: content/processed_train/shard_034.npz (deflated 10%)
  adding: content/processed_train/shard_021.npz (deflated 9%)
  adding: content/processed_train/shard_053.npz (deflated 10%)
  adding: content/processed_train/shard_044.npz (deflated 9%)
  adding: content/processed_train/shard_007.npz (deflated 9%)
  adding: content/processed_train/shard_043.npz (deflated 9%)
  adding: content/processed_train/shard_015.npz (deflated 10%)
  adding: content/processed_train/shard_022.npz (deflated 10%)
  adding: content/processed_train/shard_008.npz (deflated 9%)
  adding: content/processed_train/shard_023.npz (deflated 9%)
  adding: content/processed_train/shard_049.npz (deflated 9%)
  adding: content