# Generate dataset

Our original dataset contains only 12 piece images (6 per color), plus two empty squares.

This notebook generates many augmented variations of each one.

In [12]:
import os
import random
import uuid
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

import numpy as np
from PIL import Image  # type: ignore
from tensorflow.keras.preprocessing.image import (  # type: ignore
    ImageDataGenerator,
    img_to_array,
    array_to_img
)

## Config

The code below creates the `dataset` directory, including the background squares. Using only one dark and one light square as backgrounds seems to be enough.

In [13]:
base_dir = "data/pieces"  # Piece images, one sub-folder per color
empty_board = "data/empty_board.png"  # Empty board image, cut into 64 squares below
training_data_dir = "dataset/training"
test_data_dir = "dataset/test"
squares_dir = "dataset/squares"
# Two adjacent top-row squares (hence one of each color), used as piece backgrounds.
# They are produced by the board-cropping cell as square_<row>_<col>.png.
backgrounds = [
    os.path.join(squares_dir, "square_0_1.png"),
    os.path.join(squares_dir, "square_0_2.png"),
]
# Composites generated PER BACKGROUND for each piece, so every piece yields
# num_images * len(backgrounds) composites before the 80/20 train/test split.
num_images = 10_000

In [14]:
# Create every output folder up front; exist_ok makes re-running this cell harmless.
for directory in (training_data_dir, test_data_dir, squares_dir):
    os.makedirs(directory, exist_ok=True)

## Create empty squares from empty board

In [15]:
# Cut the empty board image into an 8x8 grid of square tiles.
board_image = Image.open(empty_board)

board_width, board_height = board_image.size
if board_height != board_width:
    raise ValueError("The board image is not square!")

# Integer division: leftover pixels on the right/bottom edge are ignored.
square_size = board_width // 8

count = 0
for index in range(8 * 8):
    # Row-major walk: indices 0..7 are row 0, 8..15 are row 1, and so on.
    row, col = divmod(index, 8)
    box = (
        col * square_size,
        row * square_size,
        (col + 1) * square_size,
        (row + 1) * square_size,
    )
    # Files are named square_<row>_<col>.png; later cells rely on this pattern.
    board_image.crop(box).save(os.path.join(squares_dir, f"square_{row}_{col}.png"))
    count += 1

# Should always be 64
print(f"Empty squares have been generated and saved. Total squares: {count}")

Empty squares have been generated and saved. Total squares: 64


In [16]:
# Augmentation generator for the piece images.
datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.3,
    height_shift_range=0.3,
    shear_range=0.3,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.5, 1.5],
    fill_mode="nearest",
)

# Augmentation generator for the empty squares.
# rotation, width/height shift, shear and zoom ranges are half of the piece
# settings above; flips, brightness_range and fill_mode are left unchanged.
# TODO: not sure milder geometric augmentation is a good idea here — needs testing.
datagen_empty = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.15,
    height_shift_range=0.15,
    shear_range=0.15,
    zoom_range=0.15,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.5, 1.5],
    fill_mode="nearest",
)

In [17]:
# Overlay image on background
def add_background(piece_path, background_paths, num_images_per_bg):
    """Composite a piece image over each background image.

    Args:
        piece_path: path of the piece image; it is converted to RGBA.
        background_paths: paths of background images; each is converted to
            RGBA and resized to the piece's size with LANCZOS resampling.
        num_images_per_bg: number of entries to emit per background.

    Returns:
        list: RGBA PIL images, ``len(background_paths) * num_images_per_bg``
        long, grouped by background. All entries for one background are the
        *same* Image object, because the composite is identical and built
        only once. Callers must not modify these images in place.
    """
    with Image.open(piece_path) as src:
        piece = src.convert("RGBA")

    images = []
    for bg_path in background_paths:
        with Image.open(bg_path) as src:
            background = src.convert("RGBA").resize(
                piece.size, Image.Resampling.LANCZOS
            )
        # Loop-invariant: compositing the same two images always gives the same
        # result, so build it once instead of num_images_per_bg times (which
        # also kept that many identical copies in memory).
        combined = Image.alpha_composite(background, piece)
        images.extend([combined] * num_images_per_bg)
    return images


# Generate and save augmented images for the pieces
def augment_and_save(images, output_path, num_augmented=50, generator=None):
    """Augment each image and save the results as ``augmented_<uuid>.png``.

    Args:
        images: PIL images to augment.
        output_path: output directory; created if missing, even when
            ``images`` is empty.
        num_augmented: target total number of augmentations, spread evenly
            over ``images``. Each image yields
            ``max(1, num_augmented // len(images))`` augmentations, so when
            there are more images than ``num_augmented`` every image still
            gets exactly one.
        generator: augmentation generator exposing ``flow``; defaults to the
            module-level ``datagen`` (the piece settings).
    """
    os.makedirs(output_path, exist_ok=True)
    if not images:
        return
    if generator is None:
        generator = datagen

    # Depends only on the inputs, so compute it once. The max(1, ...) makes
    # explicit the "at least one per image" behaviour that the floor division
    # would otherwise hide when it evaluates to 0.
    per_image = max(1, num_augmented // len(images))

    for source in images:
        x = img_to_array(source)
        x = x.reshape((1,) + x.shape)  # single-sample batch for flow()
        # flow() keeps yielding batches; take exactly per_image of them.
        for batch in islice(generator.flow(x, batch_size=1), per_image):
            augmented = array_to_img(batch[0], scale=True)
            filename = f"augmented_{uuid.uuid4().hex}.png"
            augmented.save(os.path.join(output_path, filename))


# Check if an image is valid and non-empty
def is_valid_image(img):
    """Return True if ``img`` is usable: present, non-degenerate, not all zeros.

    Args:
        img: a PIL image, or any object with a ``size`` tuple that NumPy can
            convert to an array.

    Returns:
        bool: False when ``img`` is None, when either dimension is 0, or when
        every pixel value is 0; True otherwise.
    """
    if img is None:
        return False
    # Reject ANY zero dimension, not only (0, 0). A (0, h) or (w, 0) image
    # converts to an empty array whose mean is NaN, and NaN != 0 would let it
    # pass the all-zero check below.
    if 0 in img.size:
        return False
    return bool(np.asarray(img).mean() != 0)


# Augment one empty square; results are returned (process_empty_squares saves them).
def augment_empty_square(img, num_augmented):
    """Return exactly ``num_augmented`` augmented arrays generated from ``img``.

    Uses the module-level ``datagen_empty``. Generated arrays whose mean is
    exactly 0 are discarded and another batch is drawn.

    Args:
        img: PIL image of an empty square.
        num_augmented: number of augmentations to produce.

    Returns:
        list: augmented arrays, or ``[]`` when ``num_augmented`` is not
        positive or ``img`` converts to an empty array.
    """
    # Without this guard the append-then-check loop below would still return
    # one image when 0 (or a negative count) is requested.
    if num_augmented <= 0:
        return []

    x = img_to_array(img)
    if x.size == 0:
        print(f"Skipping invalid image with shape: {img.size}")
        return []
    x = x.reshape((1,) + x.shape)  # single-sample batch for flow()

    augmented_images = []
    # NOTE(review): flow() keeps yielding batches, so this loops forever if the
    # generator only ever produces all-zero arrays — consider an attempt cap.
    for batch in datagen_empty.flow(x, batch_size=1):
        augmented_img = batch[0]
        if augmented_img.mean() == 0:
            print("Skipping generated empty image")
            continue
        augmented_images.append(augmented_img)
        if len(augmented_images) >= num_augmented:
            break

    return augmented_images


def _parse_square_coords(filename):
    """Parse ``(row, col)`` from a ``<prefix>_<row>_<col>.<ext>`` file name, else None."""
    stem, _ = os.path.splitext(filename)
    parts = stem.split("_")
    if len(parts) != 3:
        return None
    try:
        return int(parts[1]), int(parts[2])
    except ValueError:
        return None


def _rgb_brightness(img):
    """Mean pixel value over the R, G and B channels only (alpha ignored)."""
    return float(np.asarray(img.convert("RGB"), dtype=np.float64).mean())


def load_and_classify_empty_squares(empty_square_dir):
    """Load the cropped empty squares and pick a representative subset.

    Every square from the left column (``col == 0``) and from the bottom row
    (``row == 7``) is kept. From the remaining squares, one square of each
    checkerboard color is kept. Colors are told apart by ``(row + col) % 2``,
    because adjacent squares alternate. The two picks are then ordered
    darker-then-lighter by mean RGB brightness.

    Args:
        empty_square_dir: directory of ``square_<row>_<col>.png`` files.
            Files whose names do not match are skipped, as are files that
            cannot be loaded.

    Returns:
        list: left-column images, then bottom-row images, then
        ``[dark_square, light_square]``. Either of the last two is None if
        no such square could be loaded.
    """
    left_edge = []
    bottom_edge = []
    # First loadable non-edge square seen for each parity of (row + col).
    other_by_parity = {}

    # Sorted so the chosen non-edge squares do not depend on listdir order.
    for filename in sorted(os.listdir(empty_square_dir)):
        coords = _parse_square_coords(filename)
        if coords is None:
            print(f"Skipping unexpected file: {filename}")
            continue
        row, col = coords

        piece_path = os.path.join(empty_square_dir, filename)
        try:
            # convert() returns a new, fully loaded image; `with` closes the file.
            with Image.open(piece_path) as src:
                empty_image = src.convert("RGBA")
            if not is_valid_image(empty_image):
                print(f"Invalid image: {piece_path}")
                continue
        except Exception as e:  # best-effort: skip unreadable files
            print(f"Error loading image {piece_path}: {e}")
            continue

        if col == 0:
            left_edge.append(empty_image)
        elif row == 7:
            bottom_edge.append(empty_image)
        else:
            other_by_parity.setdefault((row + col) % 2, empty_image)

    # Brightness is measured on RGB only. Averaging RGBA would include the
    # alpha channel (255 for opaque pixels) and push dark squares above any
    # fixed threshold.
    picks = sorted(other_by_parity.values(), key=_rgb_brightness)
    dark_square = light_square = None
    if len(picks) == 2:
        dark_square, light_square = picks
    elif picks:
        # Only one color could be loaded: fall back to a fixed threshold.
        if _rgb_brightness(picks[0]) < 127:
            dark_square = picks[0]
        else:
            light_square = picks[0]

    return left_edge + bottom_edge + [dark_square, light_square]


def _save_augmented_arrays(arrays, output_dir):
    """Write each augmented array to ``output_dir`` as ``augmented_<uuid>.png``."""
    os.makedirs(output_dir, exist_ok=True)
    for array in arrays:
        img = array_to_img(array, scale=True)
        filename = f"augmented_{uuid.uuid4().hex}.png"
        img.save(os.path.join(output_dir, filename))


def process_empty_squares(
    unique_images, output_train_dir, output_test_dir, num_augmented
):
    """Augment empty-square images in parallel, split 80/20 and save them.

    Args:
        unique_images: PIL images of empty squares to augment.
        output_train_dir: directory for the training split (created if missing).
        output_test_dir: directory for the test split (created if missing).
        num_augmented: augmentations to generate per input image.

    Note:
        All augmentations are pooled and shuffled before splitting, so
        variants of the same source square can end up in both splits.
    """
    all_augmented_images = []
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(augment_empty_square, img, num_augmented)
            for img in unique_images
        ]
        # result() re-raises any exception from the worker thread.
        for future in futures:
            all_augmented_images.extend(future.result())

    # Split between training and test
    random.shuffle(all_augmented_images)
    split_index = int(len(all_augmented_images) * 0.8)

    _save_augmented_arrays(all_augmented_images[:split_index], output_train_dir)
    _save_augmented_arrays(all_augmented_images[split_index:], output_test_dir)


def process_piece(
    color, filename, backgrounds, num_images, training_data_dir, test_data_dir
):
    """Build, split and augment the dataset for a single piece image.

    The piece at ``<base_dir>/<color>/<filename>`` is composited over every
    background, shuffled, split 80/20 into train/test, and each split is
    augmented into ``<split_dir>/<color[0]>_<piece_type>``.

    ``piece_type`` is the text between the first ``_`` and the next ``.`` of
    ``filename``. This presumably assumes names like ``<color>_<type>.png``;
    TODO confirm against ``data/pieces``.
    """
    piece_label = f"{color[0]}_{filename.split('_')[1].split('.')[0]}"

    composites = add_background(
        os.path.join(base_dir, color, filename), backgrounds, num_images
    )
    random.shuffle(composites)
    cut = int(len(composites) * 0.8)

    # Training split first, then test, each into its own class folder.
    for subset, split_root in (
        (composites[:cut], training_data_dir),
        (composites[cut:], test_data_dir),
    ):
        class_dir = os.path.join(split_root, piece_label)
        os.makedirs(class_dir, exist_ok=True)
        augment_and_save(subset, class_dir)

In [18]:
# Augment and save pieces: one parallel task per piece image.
with ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(
            process_piece,
            color,
            filename,
            backgrounds,
            num_images,
            training_data_dir,
            test_data_dir,
        )
        for color in ["black", "white"]
        for filename in os.listdir(os.path.join(base_dir, color))
    ]

# A task's exception is stored on its Future and only surfaces via result().
# Without this loop a failing piece would be dropped silently and the success
# message below printed anyway.
for future in futures:
    future.result()

print("Done generating pieces dataset")

Done generating pieces dataset


In [21]:
# Augment and save empty squares
candidate_squares = load_and_classify_empty_squares(squares_dir)
# The dark/light slots are None when no such square was found; drop them.
unique_images = [img for img in candidate_squares if img is not None]

# NOTE(review): the divisor 15 is not explained anywhere; it sets how many
# augmentations each kept empty square receives — confirm the intent.
augmentations_per_square = round(num_images / 15)

process_empty_squares(
    unique_images,
    os.path.join(training_data_dir, "empty"),
    os.path.join(test_data_dir, "empty"),
    augmentations_per_square,
)

print("Done generating empty squares dataset")

Done generating empty squares dataset
