# Chess Piece Detection Model Trainer

This notebook consolidates the entire workflow for training a YOLOv8-based chess piece detection model. 
It handles dependency installation, asset downloading, synthetic data generation, training, and ONNX export.

In [None]:
# Cell 1: Setup
# Install system dependencies
# !sudo apt-get update && sudo apt-get install -y libcairo2-dev libpango1.0-dev

# Install Python packages
!pip install ultralytics --upgrade
!pip install aiohttp cairosvg gitpython tqdm opencv-python pyyaml

import os
import sys
import subprocess
import torch
import shutil
from pathlib import Path
import logging

# Configure logging
# Send INFO-and-above records to stdout so they appear in the cell output.
# force=True replaces any handlers already attached to the root logger, so this
# configuration takes effect even if logging was configured earlier (for
# example by the notebook runtime or a previous run of this cell).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)
# Shared logger used by the asset-download and ONNX-export cells below.
logger = logging.getLogger(__name__)

# Set up environment
# Expose CUDA devices 0 and 1 to this process.
# NOTE(review): hard-coded for a two-GPU machine — confirm on other hardware.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Directory the notebook was started from; train_model() resolves the dataset
# YAML against it and copies the best weights into it.
WORKING_DIR = os.getcwd()

# Check GPU
# Report the PyTorch build and every accelerator this session can see.
print(f"PyTorch version: {torch.__version__}")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

if not torch.cuda.is_available():
    print("No GPU available")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    for gpu_index in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"  GPU {gpu_index}: {gpu_name}")

In [None]:
# Cell 2: Asset Download
# Embeds logic from training/get_pieces.py

import asyncio
import aiohttp
import cairosvg
import git
import stat
import os
from pathlib import Path
from tqdm.asyncio import tqdm as tqdm_async
from tqdm import tqdm
from typing import List, Tuple, Optional
from concurrent.futures import ProcessPoolExecutor
import functools

# Constants
CHESSCOM_BASE_URL = "https://images.chesscomfiles.com/chess-themes/pieces"
DEFAULT_OUTPUT_DIR = Path("./assets/pieces")
COLORS = ["w", "b"]
PIECES = ["k", "q", "r", "b", "n", "p"]

CHESSCOM_PIECE_THEMES = [
    "neo", "game_room", "wood", "glass", "gothic", "classic", "metal", "bases",
    "neo_wood", "icy_sea", "club", "ocean", "newspaper", "blindfold", "space",
    "cases", "condal", "3d_chesskid", "8_bit", "marble", "book", "alpha",
    "bubblegum", "dash", "graffiti", "light", "lolz", "luca", "maya", "modern",
    "nature", "neon", "sky", "tigers", "tournament", "vintage", "3d_wood",
    "3d_staunton", "3d_plastic", "real_3d",
]

LICHESS_REPO_URL = "https://github.com/lichess-org/lila.git"

async def download_chesscom_piece(
    session: aiohttp.ClientSession,
    set_name: str,
    color: str,
    piece: str,
    output_dir: Path,
    retries: int = 3
) -> bool:
    """Download one Chess.com piece image, retrying transient failures.

    The image is stored as ``<output_dir>/<set_name>/<color><PIECE>.png`` with
    an upper-case piece letter (``wK.png``, ``bR.png``, ...). The data
    generation script looks pieces up under exactly those names, so keeping
    the lower-case name from the URL would hide every Chess.com set on a
    case-sensitive filesystem.

    Args:
        session: Open aiohttp session used for the request.
        set_name: Chess.com theme name (one URL path segment).
        color: ``"w"`` or ``"b"``.
        piece: Lower-case piece letter as used in the Chess.com URL.
        output_dir: Root directory the theme folders are created in.
        retries: Maximum number of attempts.

    Returns:
        True if the file already existed or was downloaded; False on a 404 or
        once all attempts have failed.
    """
    url = f"{CHESSCOM_BASE_URL}/{set_name}/150/{color}{piece}.png"
    save_path = output_dir / set_name / f"{color}{piece.upper()}.png"

    if save_path.exists():
        return True

    for attempt in range(retries):
        try:
            async with session.get(url) as response:
                if response.status == 200:
                    content = await response.read()
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    save_path.write_bytes(content)
                    return True
                if response.status == 404:
                    # The theme does not ship this piece; retrying cannot help.
                    return False
                logger.debug(
                    "Attempt %d/%d for %s returned HTTP %s",
                    attempt + 1, retries, url, response.status,
                )
        except Exception as exc:
            # Best effort: one bad request must not abort the whole download,
            # but leave a trace of what went wrong.
            logger.debug(
                "Attempt %d/%d for %s failed: %r", attempt + 1, retries, url, exc
            )

        if attempt < retries - 1:
            # Linear back-off before the next attempt.
            await asyncio.sleep(0.5 * (attempt + 1))

    return False

async def download_chesscom_pieces(output_dir: Path, concurrency: int = 100):
    """Fetch every theme/colour/piece combination from Chess.com.

    Args:
        output_dir: Root directory that receives one sub-folder per theme.
        concurrency: Connection limit handed to the aiohttp connector.

    The number of successful downloads is logged; nothing is returned.
    """
    logger.info(f"Starting Chess.com downloads to {output_dir}...")

    session_options = {
        "connector": aiohttp.TCPConnector(limit=concurrency),
        "timeout": aiohttp.ClientTimeout(total=30),
    }

    async with aiohttp.ClientSession(**session_options) as session:
        # One coroutine per (theme, colour, piece) combination.
        tasks = [
            download_chesscom_piece(session, set_name, color, piece, output_dir)
            for set_name in CHESSCOM_PIECE_THEMES
            for color in COLORS
            for piece in PIECES
        ]

        # Await them in completion order so the progress bar advances steadily.
        results = [
            await finished
            for finished in tqdm_async.as_completed(tasks, desc="Downloading Chess.com pieces")
        ]

    success_count = sum(results)
    logger.info(f"Chess.com downloads complete: {success_count}/{len(tasks)} successful")

def on_rm_error(func, path, exc_info):
    """Error hook for ``shutil.rmtree``: make *path* writable and retry.

    Args:
        func: The removal function that failed; it is called again with *path*.
        path: Filesystem entry that could not be removed.
        exc_info: Exception details supplied by ``rmtree`` (unused).

    A failure of the retry is logged rather than raised.
    """
    try:
        # Clear a read-only flag, then repeat the operation that failed.
        os.chmod(path, stat.S_IWRITE)
        func(path)
    except Exception as err:
        logger.error(f"Failed to remove {path}: {err}")

def convert_single_svg(args):
    """Render one SVG file to a 150x150 PNG (worker for the parallel pool).

    Args:
        args: ``(svg_path, png_path)`` pair; packed into a single argument so
            the function can be used with ``executor.map``.

    Returns:
        True when the PNG was written, False if creating the target directory
        or rendering raised an exception.
    """
    source_svg, target_png = args
    try:
        target_png.parent.mkdir(parents=True, exist_ok=True)
        render_options = {
            "url": str(source_svg),
            "write_to": str(target_png),
            "output_height": 150,
            "output_width": 150,
        }
        cairosvg.svg2png(**render_options)
    except Exception:
        return False
    return True

def _collect_lichess_svg_jobs(source_path: Path, output_dir: Path) -> List[Tuple[Path, Path]]:
    """Pair each theme SVG under *source_path* with its not-yet-existing PNG target."""
    jobs = []
    for root, _, files in os.walk(source_path):
        root_path = Path(root)
        # Files directly inside the piece folder do not belong to any theme.
        if root_path == source_path:
            continue

        target_dir = output_dir / os.path.relpath(root, source_path)
        for file in files:
            if not file.lower().endswith(".svg"):
                continue
            png_path = target_dir / (os.path.splitext(file)[0] + ".png")
            # Skip if already exists
            if not png_path.exists():
                jobs.append((root_path / file, png_path))
    return jobs


def _convert_svgs_in_parallel(svg_files: List[Tuple[Path, Path]]) -> int:
    """Convert the (svg, png) jobs in a process pool; return the number of failures."""
    if not svg_files:
        return 0

    # os.cpu_count() may return None; leave one core free when it is known.
    max_workers = max(1, (os.cpu_count() or 2) - 1)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(
            executor.map(convert_single_svg, svg_files),
            total=len(svg_files),
            desc="Converting Lichess SVGs (Parallel)"
        ))
    return results.count(False)


async def download_lichess_pieces(output_dir: Path):
    """Clone the Lichess repository and convert its piece SVGs to PNGs.

    The repository is shallow-cloned into ``temp_lila_repo``, every SVG below
    ``public/piece/<theme>/`` is rendered to ``<output_dir>/<theme>/<name>.png``
    (existing PNGs are kept), and the clone is removed again on every exit
    path.

    Clone and conversion both run in worker threads: this coroutine shares the
    event loop with the Chess.com downloads, and blocking the loop for the
    whole conversion would stall those requests.

    Args:
        output_dir: Root directory that receives one sub-folder per theme.
    """
    repo_dir = Path("temp_lila_repo")
    target_subfolder = "public/piece"

    logger.info("Cloning Lichess repo (this may take a moment)...")

    if repo_dir.exists():
        shutil.rmtree(repo_dir, onerror=on_rm_error)

    try:
        try:
            await asyncio.to_thread(git.Repo.clone_from, LICHESS_REPO_URL, str(repo_dir), depth=1)
        except git.exc.GitError as e:
            # GitError is GitPython's base exception class.
            logger.error(f"Failed to clone Lichess repo: {e}")
            return

        logger.info("Clone complete. Converting SVGs to PNGs...")
        source_path = repo_dir / target_subfolder

        if not source_path.exists():
            logger.error(f"Could not find piece directory in repo: {source_path}")
            return

        svg_files = _collect_lichess_svg_jobs(source_path, output_dir)
        logger.info(f"Found {len(svg_files)} SVG files to convert")

        # CPU-bound conversion runs in a process pool, driven from a thread.
        failed_count = await asyncio.to_thread(_convert_svgs_in_parallel, svg_files)

        if failed_count > 0:
            logger.warning(f"Note: {failed_count} SVG files failed to convert")

        logger.info("Lichess conversion complete. Cleaning up repo...")
    finally:
        # Also removes a partial clone left behind by a failed clone attempt.
        if repo_dir.exists():
            shutil.rmtree(repo_dir, onerror=on_rm_error)
    logger.info("Cleanup complete.")

async def download_assets():
    """Download the Chess.com and Lichess piece sets concurrently.

    Both sources write into ``DEFAULT_OUTPUT_DIR``, which is created first if
    it does not exist yet.
    """
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)

    await asyncio.gather(
        download_chesscom_pieces(output_dir),
        download_lichess_pieces(output_dir),
    )
    logger.info(f"All assets downloaded to {output_dir.absolute()}")

# Run download
# Top-level `await` relies on the notebook's already-running event loop
# (IPython autoawait); a plain script would need
# asyncio.run(download_assets()) instead.
await download_assets()

In [None]:
# Cell 3: Data Generation (Parallelized)
# This cell writes the generation logic to a script and runs it to ensure multiprocessing works correctly.

import os
from pathlib import Path

# Define the script content
data_gen_script = """
import os
import sys
import random
import cv2
import numpy as np
import gc
import shutil
import time
from pathlib import Path
from PIL import Image
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from typing import List, Tuple, Optional

# --- Constants ---
# Each theme is a (light_square_color, dark_square_color) pair; one theme is
# picked at random per generated sample.
# NOTE(review): the triples look like RGB values (e.g. 240, 217, 181) —
# confirm the channel order against how draw_chess_board uses them.
BOARD_THEMES = [
    ((240, 217, 181), (181, 136, 99)),  # Classic brown
    ((238, 238, 210), (118, 150, 86)),  # Green
    ((222, 227, 230), (140, 162, 173)),  # Blue
    ((255, 206, 158), (209, 139, 71)),  # Wood
    ((235, 236, 208), (115, 149, 82)),  # Tournament
    ((240, 240, 240), (120, 120, 120)),  # Gray
    ((255, 228, 181), (205, 133, 63)),  # Tan
    ((250, 235, 215), (139, 90, 43)),  # Marble
    ((173, 216, 230), (100, 149, 237)),  # Sky blue
    ((255, 218, 185), (160, 82, 45)),  # Peach
    ((245, 222, 179), (139, 69, 19)),  # Wheat
    ((230, 230, 250), (147, 112, 219)),  # Purple
]

# Global cache for workers
# Maps piece-set name -> {piece name ("wK", "bR", ...) -> image array}.
# Each process holds its own copy of this module-level dict.
PIECE_CACHE = {}

# --- Core Logic ---

def random_fen_generation() -> str:
    '''Return a FEN string for a random, not necessarily legal, position.

    Every square is occupied with probability 0.5 by a piece drawn uniformly
    from all twelve piece types. The trailing FEN fields are fixed.
    '''
    symbols = ["r", "n", "b", "q", "k", "p", "R", "N", "B", "Q", "K", "P"]
    ranks = []
    for _ in range(8):
        # Decide occupancy first, then draw the piece (same draw order per square).
        squares = [
            random.choice(symbols) if random.random() < 0.5 else None
            for _ in range(8)
        ]

        # Run-length encode consecutive empty squares, as FEN requires.
        tokens = []
        gap = 0
        for occupant in squares:
            if occupant is None:
                gap += 1
                continue
            if gap > 0:
                tokens.append(str(gap))
                gap = 0
            tokens.append(occupant)
        if gap > 0:
            tokens.append(str(gap))
        ranks.append("".join(tokens))
    return "/".join(ranks) + " w KQkq - 0 1"

def generate_empty_board_fen() -> str:
    '''Return a FEN string for an empty or thinly populated board.

    One of four board types is chosen at random: completely empty, or filled
    with a per-square piece probability of 0.05 (very_sparse), 0.15 (sparse)
    or 0.35 (hard_negative_midgame).
    '''
    board_type = random.choice(["completely_empty", "very_sparse", "sparse", "hard_negative_midgame"])
    if board_type == "completely_empty":
        return "8/8/8/8/8/8/8/8 w KQkq - 0 1"

    symbols = ["r", "n", "b", "q", "k", "p", "R", "N", "B", "Q", "K", "P"]
    density_by_type = {"very_sparse": 0.05, "sparse": 0.15}
    piece_probability = density_by_type.get(board_type, 0.35)

    ranks = []
    for _ in range(8):
        squares = [
            random.choice(symbols) if random.random() < piece_probability else None
            for _ in range(8)
        ]

        # Collapse runs of empty squares into their count.
        encoded = ""
        gap = 0
        for occupant in squares:
            if occupant is None:
                gap += 1
            else:
                encoded += (str(gap) if gap > 0 else "") + occupant
                gap = 0
        if gap > 0:
            encoded += str(gap)
        ranks.append(encoded)
    return "/".join(ranks) + " w KQkq - 0 1"

def generate_royalty_focused_fen() -> str:
    '''Return a FEN string for a random position rich in kings, queens and bishops.

    Each square is occupied with probability 0.4; an occupied square gets a
    king, queen or bishop with probability 0.7 and a rook, knight or pawn
    otherwise.
    '''
    royalty_pieces = ["k", "q", "b", "K", "Q", "B"]
    support_pieces = ["r", "n", "p", "R", "N", "P"]

    def draw_square():
        # Occupancy is decided first, then the pool, then the piece itself.
        if random.random() >= 0.4:
            return None
        pool = royalty_pieces if random.random() < 0.7 else support_pieces
        return random.choice(pool)

    ranks = []
    for _ in range(8):
        squares = [draw_square() for _ in range(8)]

        tokens = []
        gap = 0
        for occupant in squares:
            if occupant is None:
                gap += 1
                continue
            if gap > 0:
                tokens.append(str(gap))
                gap = 0
            tokens.append(occupant)
        if gap > 0:
            tokens.append(str(gap))
        ranks.append("".join(tokens))
    return "/".join(ranks) + " w KQkq - 0 1"

def fen_to_grid(fen: str) -> List[int]:
    '''Flatten the piece-placement field of *fen* into a list of piece ids.

    Squares are listed rank by rank in FEN order. Empty squares are 0; pieces
    are numbered r=1, n=2, b=3, q=4, k=5, p=6, R=7, N=8, B=9, Q=10, K=11, P=12.
    '''
    piece_to_id = {symbol: number for number, symbol in enumerate("rnbqkpRNBQKP", start=1)}
    placement = fen.split()[0]
    grid = []
    for rank in placement.split("/"):
        for symbol in rank:
            # A digit stands for that many consecutive empty squares.
            cells = [0] * int(symbol) if symbol.isdigit() else [piece_to_id[symbol]]
            grid.extend(cells)
    return grid

def fen_to_piece_list(fen: str) -> List[Tuple[str, int, int]]:
    '''List the pieces in *fen* as (piece_name, row, col) tuples.

    Piece names are the sprite names "bR", "wK", ...; lower-case FEN letters
    map to the "b" prefix and upper-case letters to "w". Rows and columns are
    zero-based and follow FEN order.
    '''
    piece_map = {
        symbol: ("b" if symbol.islower() else "w") + symbol.upper()
        for symbol in "rnbqkpRNBQKP"
    }
    placement = fen.split()[0]
    found = []
    for row_idx, rank in enumerate(placement.split("/")):
        col_idx = 0
        for symbol in rank:
            if symbol.isdigit():
                # Skip over the run of empty squares.
                col_idx += int(symbol)
                continue
            found.append((piece_map[symbol], row_idx, col_idx))
            col_idx += 1
    return found

def get_available_piece_sets(assets_dir: Path) -> List[str]:
    '''Return the sorted names of usable piece sets under <assets_dir>/pieces.

    A set counts as usable when it is a directory containing wK.png and its
    lower-cased name is not blacklisted. A missing pieces directory yields an
    empty list.
    '''
    pieces_dir = assets_dir / "pieces"
    if not pieces_dir.exists():
        return []

    blacklist = ["blindfold"]
    return sorted(
        theme_dir.name
        for theme_dir in pieces_dir.iterdir()
        if theme_dir.is_dir()
        and (theme_dir / "wK.png").exists()
        and theme_dir.name.lower() not in blacklist
    )

def draw_chess_board(fen: str, piece_set: str, board_theme: Tuple, assets_dir: Path, board_size: int = 640) -> Optional[np.ndarray]:
    '''Render the position in *fen* as a BGR image with flat-coloured squares.

    Args:
        fen: FEN string; only the piece-placement field is used.
        piece_set: Folder name under <assets_dir>/pieces holding the piece PNGs.
        board_theme: (light_square, dark_square) pair of RGB triples.
        assets_dir: Root asset directory.
        board_size: Requested edge length in pixels; the image edge is
            8 * (board_size // 8).

    Returns:
        A uint8 array of shape (edge, edge, 3) in BGR channel order. A square
        whose piece sprite is missing or unusable is left empty.
    '''
    global PIECE_CACHE
    square_size = board_size // 8
    light_color, dark_color = board_theme

    # 8x8 grid of square colours; the top-left square is light.
    rows, cols = np.indices((8, 8))
    is_light = (rows + cols) % 2 == 0
    grid_colors = np.zeros((8, 8, 3), dtype=np.uint8)
    # Theme colours are RGB triples, but this image is BGR: the sprites come
    # from cv2.imread and the caller reverses the channels when saving. Flip
    # the triples here so the saved board shows the intended colours.
    grid_colors[is_light] = tuple(light_color)[::-1]
    grid_colors[~is_light] = tuple(dark_color)[::-1]

    # Scale up to full board
    board = grid_colors.repeat(square_size, axis=0).repeat(square_size, axis=1)

    target_shape = (square_size, square_size)

    # Fallback load for sets the worker has not cached yet (slow path, once per set).
    if piece_set not in PIECE_CACHE:
        piece_dir = assets_dir / "pieces" / piece_set
        loaded = {}
        for piece_name in ["bR", "bN", "bB", "bQ", "bK", "bP", "wR", "wN", "wB", "wQ", "wK", "wP"]:
            p = piece_dir / f"{piece_name}.png"
            if not p.exists():
                continue
            img = cv2.imread(str(p), cv2.IMREAD_UNCHANGED)
            if img is None:
                continue
            if img.shape[:2] != target_shape:
                # Same interpolation as worker_init so both load paths agree.
                img = cv2.resize(img, target_shape, interpolation=cv2.INTER_AREA)
            loaded[piece_name] = img
        PIECE_CACHE[piece_set] = loaded

    piece_dict = PIECE_CACHE.get(piece_set, {})

    for piece_name, row, col in fen_to_piece_list(fen):
        piece_img = piece_dict.get(piece_name)
        if piece_img is None:
            continue

        # The cache is keyed by set name only, so a sprite cached for another
        # board size is rescaled here instead of being dropped.
        if piece_img.shape[:2] != target_shape:
            piece_img = cv2.resize(piece_img, target_shape, interpolation=cv2.INTER_AREA)
        # Single-channel sprites are expanded so they can be pasted like colour ones.
        if piece_img.ndim == 2:
            piece_img = cv2.cvtColor(piece_img, cv2.COLOR_GRAY2BGR)

        y, x = row * square_size, col * square_size
        roi = board[y : y + square_size, x : x + square_size]

        try:
            if piece_img.shape[2] == 4:
                # Alpha-blend the sprite over the square colour.
                alpha = piece_img[:, :, 3:4] / 255.0
                blended = alpha * piece_img[:, :, :3] + (1 - alpha) * roi
                roi[...] = blended.astype(np.uint8)
            else:
                # No alpha channel: overwrite the square.
                roi[...] = piece_img[:, :, :3]
        except ValueError:
            # Unsupported channel layout (e.g. gray + alpha): skip this piece.
            continue

    return board

def add_gray_dots_to_board(board: np.ndarray, num_dots: Optional[int] = None) -> np.ndarray:
    '''Return a copy of *board* with semi-transparent gray dots drawn on it.

    Args:
        board: Board image; its height divided by 8 is taken as the square size.
            None is passed straight through.
        num_dots: Number of dots; a random count from 1 to 5 when omitted.

    Each dot is placed near the centre of a random square with a random
    radius, gray level and opacity. A dot that fails to draw is skipped.
    '''
    if board is None:
        return board

    canvas = board.copy()
    square_size = board.shape[0] // 8
    dot_total = random.randint(1, 5) if num_dots is None else num_dots

    for _ in range(dot_total):
        row = random.randint(0, 7)
        col = random.randint(0, 7)
        center_y = int(row * square_size + square_size // 2)
        center_x = int(col * square_size + square_size // 2)
        dot_radius = int(random.randint(square_size // 8, square_size // 4))
        gray_value = int(random.randint(120, 160))
        offset_x = int(random.randint(-square_size // 6, square_size // 6))
        offset_y = int(random.randint(-square_size // 6, square_size // 6))
        try:
            # Paint the dot on a scratch copy, then blend it into the canvas.
            overlay = canvas.copy()
            dot_center = (center_x + offset_x, center_y + offset_y)
            cv2.circle(overlay, dot_center, dot_radius, (gray_value, gray_value, gray_value), -1)
            alpha = random.uniform(0.5, 0.8)
            cv2.addWeighted(overlay, alpha, canvas, 1 - alpha, 0, canvas)
        except Exception:
            pass
    return canvas

def generate_training_sample(assets_dir, piece_sets, output_size=640, is_negative=False, mode="normal") -> Tuple[Optional[np.ndarray], str]:
    '''Create one board image together with the FEN it was rendered from.

    Args:
        assets_dir: Root asset directory passed on to draw_chess_board.
        piece_sets: Piece-set names to choose from.
        output_size: Board edge length in pixels.
        is_negative: Use a sparse/empty position when no special mode applies.
        mode: "hard_negative_dots" (sparse board plus gray dots),
            "royalty_focus", or anything else for the default behaviour.

    Returns:
        (image, fen); the image is None when the board could not be drawn.
    '''
    if mode == "royalty_focus":
        fen = generate_royalty_focused_fen()
    elif mode == "hard_negative_dots" or is_negative:
        fen = generate_empty_board_fen()
    else:
        fen = random_fen_generation()

    piece_set = random.choice(piece_sets)
    board_theme = random.choice(BOARD_THEMES)

    board = draw_chess_board(fen, piece_set, board_theme, assets_dir, output_size)
    if board is None:
        return None, fen

    if mode == "hard_negative_dots":
        board = add_gray_dots_to_board(board)

    # Optional: Add noise/brightness (removed perspective/scale here as requested)
    apply_jitter = random.random() < 0.3
    if apply_jitter:
        try:
            alpha = random.uniform(0.9, 1.1)
            beta = random.uniform(-5, 5)
            board = cv2.convertScaleAbs(board, alpha=alpha, beta=beta)
        except Exception:
            pass

    return board, fen

# --- Worker Function ---

def worker_init(assets_dir_str, piece_sets, square_size):
    '''Reset PIECE_CACHE and preload the sprites of *piece_sets*.

    Args:
        assets_dir_str: Root asset directory as a string.
        piece_sets: Names of the piece sets to load.
        square_size: Sprites whose height differs are resized to
            square_size x square_size.

    Sets without a directory are skipped, as are missing or unreadable PNGs.
    '''
    global PIECE_CACHE
    PIECE_CACHE.clear()
    pieces_root = Path(assets_dir_str) / "pieces"
    piece_names = ["bR", "bN", "bB", "bQ", "bK", "bP", "wR", "wN", "wB", "wQ", "wK", "wP"]

    # Optimization: Only load assigned piece sets
    for piece_set in piece_sets:
        piece_dir = pieces_root / piece_set
        if not piece_dir.exists():
            continue

        sprites = PIECE_CACHE[piece_set] = {}
        for piece_name in piece_names:
            png_path = piece_dir / f"{piece_name}.png"
            if not png_path.exists():
                continue
            img = cv2.imread(str(png_path), cv2.IMREAD_UNCHANGED)
            if img is None:
                continue
            # Resize once here so drawing does not have to.
            if img.shape[0] != square_size:
                img = cv2.resize(img, (square_size, square_size), interpolation=cv2.INTER_AREA)
            sprites[piece_name] = img

def generate_batch(start_idx, n_samples, split_name, piece_sets, assets_dir_str, output_dir_str, mode, seed, save_clean_boards, boards_dir_str):
    '''Generate and save one batch of samples; runs inside a worker process.

    Args:
        start_idx: Index of the first sample; used in the file names.
        n_samples: Number of samples to attempt.
        split_name: "train" or "val"; selects the output sub-folder.
        piece_sets: Piece-set names this batch may draw from.
        assets_dir_str: Root asset directory as a string.
        output_dir_str: Dataset root directory as a string.
        mode: "normal", "neg", "hard_negative_dots" or "royalty_focus".
        seed: Seed for the random and numpy.random generators.
        save_clean_boards: Also store each image in the boards directory.
        boards_dir_str: Boards directory as a string, or None.

    Returns:
        Number of samples written. A failed sample is reported on stderr and
        any partial files are removed, so an image never remains without its
        label file.
    '''
    # Initialize worker logic (re-seed)
    random.seed(seed)
    np.random.seed(seed % (2**32))

    board_size = 640
    max_reported_failures = 3

    assets_dir = Path(assets_dir_str)
    output_dir = Path(output_dir_str)
    boards_dir = Path(boards_dir_str) if boards_dir_str else None

    # Ensure caching (if worker_init wasn't called or lost context)
    global PIECE_CACHE
    if not PIECE_CACHE:
        worker_init(assets_dir_str, piece_sets, board_size // 8)

    images_dir = output_dir / split_name / "images"
    labels_dir = output_dir / split_name / "labels"

    processed = 0
    failures = 0
    for i in range(n_samples):
        idx = start_idx + i

        board, fen = generate_training_sample(assets_dir, piece_sets, output_size=board_size, is_negative=(mode == "neg"), mode=mode)
        if board is None:
            continue

        img_path = images_dir / f"sample_{idx:06d}.png"
        lbl_path = labels_dir / f"sample_{idx:06d}.txt"

        try:
            # The board is BGR; PIL expects RGB.
            img_rgb = board[:, :, ::-1]
            Image.fromarray(img_rgb).save(str(img_path))

            # One YOLO row per occupied square: class, x_center, y_center,
            # width, height, all normalised; every box is exactly one square.
            labels = [
                f"{piece_id - 1} {(square % 8 + 0.5) / 8:.6f} {(square // 8 + 0.5) / 8:.6f} {1.0/8:.6f} {1.0/8:.6f}"
                for square, piece_id in enumerate(fen_to_grid(fen))
                if piece_id != 0
            ]
            with open(lbl_path, "w") as f:
                f.write("\\n".join(labels) + "\\n")

            # Save clean board if requested
            if save_clean_boards and boards_dir:
                Image.fromarray(img_rgb).save(str(boards_dir / f"{split_name}_{idx:06d}.png"))

            processed += 1
        except Exception as e:
            # Keep the batch going, but say why a sample was dropped.
            failures += 1
            if failures <= max_reported_failures:
                print(f"[{split_name}] sample {idx:06d} failed: {e!r}", file=sys.stderr)
            # Remove partial output so no image is left without labels.
            for leftover in (img_path, lbl_path):
                try:
                    leftover.unlink()
                except OSError:
                    pass

    if failures > max_reported_failures:
        print(f"[{split_name}] {failures} samples failed in the batch starting at {start_idx}", file=sys.stderr)

    return processed

# --- Orchestration ---

def generate_dataset_parallel(assets_dir, output_dir, count, negative_ratio, hard_negative_dots_count, royalty_focus_count, save_clean_boards):
    '''Generate the YOLO dataset with a pool of worker processes.

    Args:
        assets_dir: Directory containing the pieces folder.
        output_dir: Dataset root; an existing directory is deleted and rebuilt.
        count: Number of "normal" samples.
        negative_ratio: Share of sparse/empty samples relative to normal plus
            negative samples; must satisfy 0 <= negative_ratio < 1.
        hard_negative_dots_count: Number of sparse boards with gray dots.
        royalty_focus_count: Number of king/queen/bishop-heavy boards.
        save_clean_boards: Also store every image under <assets_dir>/boards.

    Every sample type is split 80/20 into train and val, and the two splits
    use disjoint piece sets whenever more than one set is available.
    data.yaml is written to the dataset root at the end.

    Raises:
        ValueError: If negative_ratio is outside [0, 1).
    '''
    if not 0 <= negative_ratio < 1:
        raise ValueError(f"negative_ratio must be in [0, 1), got {negative_ratio}")

    assets_path = Path(assets_dir)
    output_path = Path(output_dir)

    # Check the assets before touching anything on disk, so a failed download
    # does not wipe a previously generated dataset.
    piece_sets = get_available_piece_sets(assets_path)
    if not piece_sets:
        print(f"ERROR: No piece sets found in {assets_path}/pieces!")
        print("Please check your asset download.")
        return

    # Setup Dirs
    if output_path.exists():
        shutil.rmtree(output_path)
    for split in ["train", "val"]:
        (output_path / split / "images").mkdir(parents=True, exist_ok=True)
        (output_path / split / "labels").mkdir(parents=True, exist_ok=True)

    boards_path = None
    if save_clean_boards:
        boards_path = assets_path / "boards"
        if boards_path.exists():
            shutil.rmtree(boards_path)
        boards_path.mkdir(parents=True, exist_ok=True)

    print(f"Found {len(piece_sets)} piece sets. Starting generation...")
    random.shuffle(piece_sets)

    # Splits
    split_idx = max(1, int(len(piece_sets) * 0.8))
    train_sets, val_sets = piece_sets[:split_idx], piece_sets[split_idx:]
    if not val_sets:
        # A single piece set leaves nothing for validation; workers would
        # fail choosing from an empty list, so reuse the training set.
        print("WARNING: only one piece set available; validation reuses the training piece set.")
        val_sets = train_sets

    # Calc Counts
    total_samples = int(count / (1 - negative_ratio))
    neg_count = int(total_samples * negative_ratio)
    pos_count = count

    # Work Item: (start_idx, count, split, mode, sets)
    work_items = []
    global_idx = 0
    batch_size = 1000

    def schedule(n, split, sets, mode):
        # Cut n samples into batches with globally unique, consecutive indices.
        nonlocal global_idx
        remaining = n
        while remaining > 0:
            b = min(batch_size, remaining)
            work_items.append({
                "start_idx": global_idx,
                "n": b,
                "split": split,
                "sets": sets,
                "mode": mode,
                "seed": random.randint(0, 1000000)
            })
            global_idx += b
            remaining -= b

    sample_plan = [
        ("normal", pos_count),
        ("neg", neg_count),
        ("hard_negative_dots", hard_negative_dots_count),
        ("royalty_focus", royalty_focus_count),
    ]

    # Train
    for mode, n_total in sample_plan:
        schedule(int(n_total * 0.8), "train", train_sets, mode)

    # Val
    for mode, n_total in sample_plan:
        schedule(n_total - int(n_total * 0.8), "val", val_sets, mode)

    print(f"Scheduled {len(work_items)} batches. Total samples: {global_idx}")

    # Run Parallel
    max_workers = os.cpu_count() or 4
    print(f"Using {max_workers} workers.")

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                generate_batch,
                item["start_idx"],
                item["n"],
                item["split"],
                item["sets"],
                str(assets_path),
                str(output_path),
                item["mode"],
                item["seed"],
                save_clean_boards,
                str(boards_path) if boards_path else None
            )
            for item in work_items
        ]

        # Progress
        total_done = 0
        for f in tqdm(as_completed(futures), total=len(futures), desc="Generating"):
            total_done += f.result()

    print(f"Generation complete. {total_done} images generated.")

    # Create YAML
    class_names = ['r', 'n', 'b', 'q', 'k', 'p', 'R', 'N', 'B', 'Q', 'K', 'P']
    yaml_lines = [
        f"path: {output_path.absolute().as_posix()}",
        "train: train/images",
        "val: val/images",
        "names:",
    ]
    yaml_lines.extend(f"  {i}: {n}" for i, n in enumerate(class_names))
    with open(output_path / "data.yaml", "w") as f:
        f.write("\\n".join(yaml_lines) + "\\n")

# Script entry point. The __main__ guard keeps worker processes that re-import
# this module (spawn start method) from starting generation again.
if __name__ == "__main__":
    generate_dataset_parallel(
        assets_dir="./assets",  # expects ./assets/pieces/<set>/<piece>.png
        output_dir="yolo_pieces",  # deleted and rebuilt on every run
        count=15000,  # number of "normal" samples
        negative_ratio=0.15,
        hard_negative_dots_count=2000,
        royalty_focus_count=3000,
        save_clean_boards=True
    )
"""

# Write to file
# The generator is written out as its own module and run as a script so that
# multiprocessing works correctly (see the note at the top of this cell).
with open("data_gen.py", "w") as f:
    f.write(data_gen_script)

# Run the script
print("Running data generation script...")
# IPython shell escape: runs the generator in a separate Python process.
!python data_gen.py


In [None]:
# Cell 4: Training
!pip install "numpy<2"

from ultralytics import YOLO
from IPython.display import Image, display
import os
import shutil

def train_model(config_yaml, project_name, base_model='yolov8n.pt', epochs=50, imgsz=640, device_ids=(0, 1)):
    """Train a YOLO model and copy its best weights into the working directory.

    Args:
        config_yaml: Dataset YAML path, relative to ``WORKING_DIR``.
        project_name: Ultralytics project folder; also the prefix of the
            exported ``<project_name>_best.pt``.
        base_model: Checkpoint or model name to start from.
        epochs: Maximum number of epochs (early stopping patience is 20).
        imgsz: Training image size.
        device_ids: A GPU index, a sequence of GPU indices, or ``'cpu'``.

    Returns:
        Path of the copied best checkpoint, or None if the base model could
        not be loaded, training raised, or no best weights were produced.
    """
    try:
        model = YOLO(base_model)
    except Exception as e:
        print(f"Error loading base model {base_model}: {e}")
        return None

    print(f"\n--- Starting Training for {project_name} using {base_model} ---")

    # Normalise the device spec to a list of indices; strings such as 'cpu'
    # are passed through unchanged.
    if isinstance(device_ids, int):
        device_ids = [device_ids]
    elif isinstance(device_ids, tuple):
        device_ids = list(device_ids)

    print(f"Using device: {device_ids}")

    # Optimize workers based on CPU count (cpu_count() may return None).
    workers = min(os.cpu_count() or 1, 16)
    print(f"Using {workers} workers for data loading")

    try:
        # Train the model
        model.train(
            data=os.path.join(WORKING_DIR, config_yaml),
            epochs=epochs,
            imgsz=imgsz,
            device=device_ids,
            patience=20,
            batch=32,
            workers=workers,
            # Augmentations settings for "Screen/Screenshot" use case
            # Disable geometric distortions
            augment=True, # Enables other augs like HSV, Mosaic
            perspective=0.0,
            degrees=0.0,    # No rotation
            translate=0.1,  # Small translation is OK (crop position variation)
            scale=0.0,      # No scaling
            shear=0.0,
            mosaic=0.0,     # Disable mosaic to preserve rigid grid structure
            mixup=0.0,
            flipud=0.0,     # No vertical flip
            fliplr=0.0,     # No horizontal flip (chess board is asymmetric)
            # Color/Lighting augs are still good for different screens/themes
            hsv_v=0.4,
            hsv_s=0.7,

            project=project_name,
            name='train',
            cache=False,
            exist_ok=False
        )
    except Exception as e:
        print(f"An error occurred during training: {e}")
        return None

    # Retrieve the actual save directory from the model trainer
    if hasattr(model, 'trainer') and hasattr(model.trainer, 'save_dir'):
        save_dir = model.trainer.save_dir
        best_model_path = os.path.join(save_dir, 'weights', 'best.pt')
        print(f"Model saved to {save_dir}")
    else:
        print("Warning: Could not retrieve save_dir from model.trainer. Checking default path.")
        best_model_path = os.path.join(project_name, 'train', 'weights', 'best.pt')

    final_path = os.path.join(WORKING_DIR, f'{project_name}_best.pt')

    if not os.path.exists(best_model_path):
        print(f"Best model weights not found at {best_model_path}")
        return None

    shutil.copy(best_model_path, final_path)
    print(f"Saved best model to {final_path}")

    # Showing the training curves is a convenience; never fail the run over it.
    try:
        results_png = os.path.join(os.path.dirname(os.path.dirname(best_model_path)), 'results.png')
        if os.path.exists(results_png):
            display(Image(filename=results_png))
    except Exception as e:
        print(f"Could not display training curves: {e}")
    return final_path

# Determine device
# Train on every visible GPU, or fall back to the CPU when there is none.
device_config = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else 'cpu'

# Start training
training_options = {
    "base_model": 'yolov8s.pt',
    "epochs": 150,
    "imgsz": 640,
    "device_ids": device_config,
}
MODEL_PATH = train_model('yolo_pieces/data.yaml', 'PieceDetector_Small', **training_options)

In [None]:
# Cell 5: Export to ONNX
def convert_model_to_onnx(model_path: Path, output_path: Optional[Path] = None) -> bool:
    """Export a trained YOLO checkpoint to ONNX.

    Args:
        model_path: Checkpoint to load.
        output_path: Optional destination; the exported file is moved there
            unless it already resolves to that location.

    Returns:
        True on success, False if loading, exporting or moving raised.
    """
    logger.info(f"Converting {model_path}...")
    try:
        onnx_file = YOLO(str(model_path)).export(format="onnx", simplify=True)
        logger.info(f"Successfully converted to {onnx_file}")

        if not output_path:
            return True

        generated_path = Path(onnx_file)
        if generated_path.resolve() == output_path.resolve():
            # The export already landed where the caller wants it.
            return True

        shutil.move(generated_path, output_path)
        logger.info(f"Moved to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to convert {model_path}: {e}")
        return False

# Export only when training produced a checkpoint that is still on disk.
if not (MODEL_PATH and os.path.exists(MODEL_PATH)):
    print("No model found to export.")
else:
    print(f"Converting trained model: {MODEL_PATH}")
    trained_weights = Path(MODEL_PATH)
    convert_model_to_onnx(trained_weights, trained_weights.with_suffix('.onnx'))