[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/one-ware/OneAI_demo_datasets/blob/main/dataset_generator/create_dataset.ipynb)


# Synthetic Dataset Generator

Generate synthetic training datasets by placing objects on backgrounds with automatic labeling.



## Quick Start Guide

### Step 1: Choose Your Environment

**Option A - Local Use:**
- Skip cells marked with 🌐 (Google Colab only)
- Use the configuration cell marked with 💻 (Local)

**Option B - Google Colab:**
- Run cells marked with 🌐 (Google Colab)
- Use the configuration cell marked with 🌐 (Colab)

### Step 2: Prepare Your Images

**Object Images (Required):**
- Format: PNG with **transparent background**
- Organize into folders by category (e.g., `/birds/`, `/drones/`)
- Example file names: `bird.png`, `drone_01.png`
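
Before generating anything, it can help to confirm that the object PNGs really carry transparency. The following is a minimal sketch that reads only the PNG header using the standard library; the function names `png_has_alpha_channel` and `find_objects_without_alpha` are hypothetical helpers, not part of `dataset_generator.py`. It reports only true alpha-channel PNGs (color types 4 and 6), so palette images that use a `tRNS` chunk for transparency would be flagged as well.

```python
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_has_alpha_channel(path):
    """Return True if the PNG header declares an alpha channel.

    Reads the IHDR chunk only: byte 25 of the file is the color type,
    where 4 = grayscale + alpha and 6 = RGBA.
    """
    with open(path, "rb") as f:
        header = f.read(26)
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        return False  # not a valid PNG file
    return header[25] in (4, 6)


def find_objects_without_alpha(folder):
    """List all PNG files below `folder` that lack an alpha channel."""
    return sorted(
        p for p in Path(folder).rglob("*.png") if not png_has_alpha_channel(p)
    )
```

Calling `find_objects_without_alpha("/content/objects")` (the path is only an example) would return the files that may need their background removed first.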

**Background Images (Required):**
- Format: PNG or JPG
- Place all in one folder
- For video backgrounds: Name sequentially (e.g., `frame_0001.jpg`, `frame_0002.jpg`)

### Step 3: Configure & Generate

1. Set paths in the configuration cell
2. Adjust parameters with sliders
3. Run preview to check results
4. Generate full dataset

# Setup

Import necessary libraries and get images for the dataset.

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import numpy as np
from pathlib import Path
import os
import sys
import cv2

#### 🌐 Google Colab Only: Video Frame Extraction

*(Skip this if working locally or if you already have frame images)*

In [None]:
def extract_frames(video_path, output_folder, frame_skip=5):
    """Save every `frame_skip`-th frame of a video as a sequentially named JPG."""
    os.makedirs(output_folder, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    saved_count = 0

    # Only the first 500 frames of the video are read
    while cap.isOpened() and frame_count < 500:
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_count % frame_skip == 0:
            output_path = os.path.join(output_folder, f"frame_{saved_count:04d}.jpg")
            cv2.imwrite(output_path, frame)
            saved_count += 1
        
        frame_count += 1

    cap.release()
    print(f"Extracted {saved_count} frames to {output_folder}")

video_path = "/content/video.mp4"
output_folder = "/content/output_images"
extract_frames(video_path, output_folder, frame_skip=5)

#### 🌐 Google Colab Only: File Upload

Upload your object and background images to Colab.

In [None]:
# 🌐 Google Colab Only - Upload files
# Skip this cell if working locally

from google.colab import files
uploaded = files.upload()

# After upload, organize files into folders:
# !mkdir -p /content/objects/birds
# !mkdir -p /content/objects/drones
# !mkdir -p /content/backgrounds
# Then move uploaded files to appropriate folders

#### 🌐 Google Colab Only: Download Dataset Generator

Download the required Python module from GitHub.

In [None]:
# 🌐 Google Colab Only
# Local users: Make sure dataset_generator.py is in the same folder as this notebook

import requests

# Raw URL of dataset_generator.py in the one-ware/OneAI_demo_datasets repository
file_url = "https://raw.githubusercontent.com/one-ware/OneAI_demo_datasets/refs/heads/main/dataset_generator/dataset_generator.py"
file_name = "dataset_generator.py" # Name you want to save the file as

try:
    response = requests.get(file_url, timeout=30)
    response.raise_for_status()

    # Write the raw bytes so the module is saved exactly as served
    with open(file_name, 'wb') as f:
        f.write(response.content)
    print(f"Successfully downloaded '{file_name}' from {file_url}")

except requests.exceptions.RequestException as e:
    print(f"Error downloading the file: {e}")
    print("Please ensure the URL is correct and accessible (e.g., raw.githubusercontent.com for GitHub files).")

## Import Required Functions

Run this cell to load the dataset generator functions.

In [2]:
from dataset_generator import (
    load_background_metadata, 
    preload_backgrounds, 
    get_image_dataset_from_folder, 
    create_dataset_from_generator
)

## 1. Configure Paths and Categories

### Object Image Requirements:
- Must be PNG format with **transparent background**
- Organize into separate folders by category
- Example structure:
  ```
  /objects/
    /birds/
      bird_01.png
      bird_02.png
    /drones/
      drone_01.png
      drone_02.png
  ```
- Start class IDs at 1 (the `class_id` field in the configuration cell below)
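
If the object folders follow the structure above, the `categories` list can be derived from the subfolder names instead of being typed by hand. This is a minimal sketch under that assumption; `build_categories` is a hypothetical helper, and it uses the folder name as the category name and assigns class IDs in alphabetical order starting at 1.

```python
from pathlib import Path


def build_categories(objects_root):
    """Create one category entry per subfolder of `objects_root`.

    Subfolders are sorted by name so that class IDs stay stable
    between runs; IDs start at 1.
    """
    subfolders = sorted(p for p in Path(objects_root).iterdir() if p.is_dir())
    return [
        {"name": folder.name, "folder": str(folder), "class_id": class_id}
        for class_id, folder in enumerate(subfolders, start=1)
    ]
```

For the example structure, `build_categories("/objects")` would produce entries for `birds` (class 1) and `drones` (class 2), in the same dictionary shape as the configuration cell.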

### Background Image Requirements:
- PNG or JPG format
- All in one folder
- For video backgrounds: Use sequential naming (e.g., `frame_0001.jpg`, `frame_0002.jpg`)
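
Because video frames are ordered by file name, inconsistent zero-padding (for example `frame_2.jpg` next to `frame_10.jpg`) can silently scramble the sequence. The sketch below checks whether sorting by name gives the same order as sorting by frame number; `frames_sort_consistently` is a hypothetical helper, and it assumes the frame number is the last run of digits before the file extension.

```python
import re
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def frames_sort_consistently(folder):
    """Return True if name order and numeric frame order agree."""
    names = sorted(
        p.name for p in Path(folder).iterdir() if p.suffix.lower() in IMAGE_SUFFIXES
    )
    numbers = []
    for name in names:
        # Last run of digits directly before the extension, e.g. "0001" in frame_0001.jpg
        match = re.search(r"(\d+)(?=\.[^.]+$)", name)
        if match is None:
            return False  # a file without a frame number breaks the sequence
        numbers.append(int(match.group(1)))
    return numbers == sorted(numbers)
```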

In [None]:
# Define your object categories and adjust paths accordingly
categories = [
    {"name": "bird", "folder": "/content/bird", "class_id": 1},
    {"name": "drone", "folder": "/content/drone", "class_id": 2},
]

# Background folder
background_folder = "/content/bg"

# Output path
output_path = "./generated_dataset"

## 2. Set Parameters

Adjust these parameters to control how objects are placed on the background.

### Explanation of Parameters

The following parameters control the behavior of the synthetic dataset generation process:

**Object Parameters**
- **`min_scale_cm`**: Minimum size of objects in centimeters.
- **`max_scale_cm`**: Maximum size of objects in centimeters.
- **`rotation_range`**: Range of rotation angles for objects, specified in degrees (e.g., `(-1.0, 1.0)` for slight rotation).
- **`allow_overlap`**: Boolean flag indicating whether objects are allowed to overlap. Controlled by the `allow_overlap` widget.
- **`max_overlap_pct`**: Maximum allowable overlap between objects, as a share of their area. The widget takes a fraction between `0.0` and `1.0` (e.g., `0.1` for 10%).
- **`min_object_distance_cm`**: Minimum distance between objects in centimeters to prevent crowding.

**Placement Parameters**
- **`edge_avoidance`**: Margin to keep objects away from the edges of the background, where `0.0` means no margin and `0.5` means a large margin.
- **`prefer_center`**: Bias for placing objects closer to the center of the background, where `0.0` is uniform placement and `1.0` strongly favors the center.
- **`category_weights`**: List of weights for each object category, determining the likelihood of each category appearing in the dataset.
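
To make one category appear more often than another, the weights can be derived from relative frequencies. This is a minimal sketch; `normalize_weights` is a hypothetical helper, and whether `dataset_generator.py` requires the weights to sum to 1 is an assumption made here for safety rather than something this notebook states.

```python
def normalize_weights(raw_weights):
    """Scale non-negative weights so they sum to 1.0."""
    total = float(sum(raw_weights))
    if total <= 0 or any(w < 0 for w in raw_weights):
        raise ValueError("weights must be non-negative and not all zero")
    return [w / total for w in raw_weights]


# Example: three times as many objects from the first category as from the second
weights = normalize_weights([3, 1])
```

The result has one entry per category, in the same order as the `categories` list.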

**Blending and Visual Quality**
- **`object_blending_strength`**: Strength of blending at the edges of objects, where `0.0` results in sharp edges and `1.0` results in heavy blurring.

**Dataset Composition**
- **`number_of_objects_per_image`**: Tuple specifying the range of the number of objects per image (e.g., `(1, 4)` means 1 to 4 objects per image).
- **`num_composite_images`**: Total number of images to generate in the dataset.

**Background Parameters**
- **`default_bg_width_cm`**: Default width of the background in centimeters.
- **`default_placement_area_pct`**: Tuple specifying the placement area as fractions of the background dimensions, in the order `(x_min, y_min, x_max, y_max)` (e.g., `(0.05, 0.05, 0.95, 0.6)` limits placement to 5% to 95% of the width and 5% to 60% of the height).

**Output Parameters**
- **`output_width`**: Width of the generated images in pixels.
- **`output_height`**: Height of the generated images in pixels.
- **`output_channels`**: Number of color channels in the output images (e.g., `3` for RGB).

**Labeling and Validation**
- **`max_labels`**: Maximum number of label boxes per image.
- **`is_video_background`**: Boolean flag indicating whether the backgrounds are video frames, so that the reference and test images use different frames. When enabled, the reference image can be the frame directly before or after the test frame, and the frames are assumed to sort into the correct order by file name. Disable this setting to use the same background image for both reference and test.

Most of these parameters are set through the widgets below, allowing for fine-tuning of the dataset generation process; `category_weights` and `output_channels` are set directly in the preview cell.
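
To get a feeling for how the centimeter and fraction values relate to the output image, the sketch below converts them to pixels. Both helpers are hypothetical, and the linear scaling (object width in pixels proportional to its share of the background width) is an assumption about how the generator interprets these values, not a description of its actual code.

```python
def cm_to_pixels(size_cm, bg_width_cm, output_width_px):
    """Approximate on-image size, assuming the background spans the full output width."""
    return size_cm / bg_width_cm * output_width_px


def placement_area_to_pixels(area_pct, output_width_px, output_height_px):
    """Convert (x_min, y_min, x_max, y_max) fractions to pixel coordinates."""
    x_min, y_min, x_max, y_max = area_pct
    return (
        round(x_min * output_width_px),
        round(y_min * output_height_px),
        round(x_max * output_width_px),
        round(y_max * output_height_px),
    )


# Default widget values: 11 cm object, 150 cm background width, 1080x720 output
object_px = cm_to_pixels(11.0, 150.0, 1080)
area_px = placement_area_to_pixels((0.05, 0.05, 0.95, 0.95), 1080, 720)
```

Under this assumption, an 11 cm object on a 150 cm wide background would be roughly 79 pixels wide in a 1080 pixel wide image.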

In [None]:
# Object size range in centimeters
object_size_slider = widgets.FloatRangeSlider(
    value=[11.0, 15.0], min=1.0, max=50.0, step=0.5,
    description='Object Size (cm):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Rotation range in degrees
rotation_slider = widgets.FloatRangeSlider(
    value=[-1.0, 1.0], min=-180.0, max=180.0, step=1.0,
    description='Rotation (°):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Minimum distance between objects in centimeters
min_distance_slider = widgets.FloatSlider(
    value=2.0, min=0.0, max=10.0, step=0.5,
    description='Min Distance (cm):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Keep objects away from edges (0.0 = no margin, 0.5 = large margin)
edge_margin_slider = widgets.FloatSlider(
    value=0.1, min=0.0, max=0.5, step=0.05,
    description='Edge Margin:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Prefer center placement (0.0 = uniform, 1.0 = strongly centered)
center_bias_slider = widgets.FloatSlider(
    value=0.3, min=0.0, max=1.0, step=0.1,
    description='Center Bias:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Blending strength (0.0 = sharp edges, 1.0 = heavily blurred)
blending_slider = widgets.FloatSlider(
    value=0.3, min=0.0, max=1.0, step=0.1,
    description='Blending:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Number of objects per image
num_objects_slider = widgets.IntRangeSlider(
    value=[1, 4], min=1, max=20, step=1,
    description='Objects/Image:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Number of images to generate
num_images_input = widgets.IntText(
    value=10, description='Total Images:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Image dimensions
width_input = widgets.IntText(
    value=1080, description='Width (px):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)
height_input = widgets.IntText(
    value=720, description='Height (px):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Background settings
default_bg_width_cm = widgets.FloatText(
    value=150.0, description='BG Width (cm):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)
default_placement_area_width = widgets.FloatRangeSlider(
    value=[0.05, 0.95], min=0.0, max=1.0, step=0.01,
    description='Placement Area Width (%):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

default_placement_area_height = widgets.FloatRangeSlider(
    value=[0.05, 0.95], min=0.0, max=1.0, step=0.01,
    description='Placement Area Height (%):', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Set background info
is_video_background = widgets.Checkbox(
    value=True, 
    description='Video Background (different ref/test frames)',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='90%')
)

# Overlap settings
allow_overlap = widgets.Checkbox(
    value=False, description='Allow Overlap',
    layout=widgets.Layout(width='90%')
)
max_overlap_pct = widgets.FloatSlider(
    value=0.1, min=0.0, max=1.0, step=0.01,
    description='Max Overlap (%):',
    layout=widgets.Layout(width='90%')
)

# Maximum number of label boxes per image
max_labels = widgets.IntText(
    value=10, description='Max Labels:', style={'description_width': '150px'},
    layout=widgets.Layout(width='90%')
)

# Display all controls with full width
display(widgets.VBox([
    default_bg_width_cm,
    default_placement_area_width,
    default_placement_area_height,
    allow_overlap,
    is_video_background,
    max_overlap_pct,
    max_labels
], layout=widgets.Layout(width='90%')))

display(widgets.VBox([
    object_size_slider,
    rotation_slider,
    min_distance_slider,
    edge_margin_slider,
    center_bias_slider,
    blending_slider,
    num_objects_slider,
    num_images_input,
    width_input,
    height_input
], layout=widgets.Layout(width='90%')))

## 3. Preview

**Generate 3 sample images** to verify your settings before creating the full dataset.

The preview shows:
- Left: Background image (reference frame)
- Right: Background + objects (test frame) with bounding boxes

**Check:**
- ✅ Object sizes look realistic
- ✅ Objects are placed in good positions
- ✅ Bounding boxes are accurate
- ✅ Class labels are correct

If something looks wrong, adjust parameters above and run this cell again.

In [None]:
# Load data
background_metadata = load_background_metadata(
    [background_folder],
    default_bg_width_cm=default_bg_width_cm.value,
    default_placement_area_pct=(default_placement_area_width.value[0], default_placement_area_height.value[0], default_placement_area_width.value[1], default_placement_area_height.value[1])
)
background_images, bg_widths, bg_placement_areas = preload_backgrounds(background_metadata)

category_datasets = []
category_class_ids = []
for cat in categories:
    ds = get_image_dataset_from_folder(cat["folder"])
    category_datasets.append(ds)
    category_class_ids.append(cat["class_id"])

# Build parameters from widgets and configuration
params = {
    "min_scale_cm": object_size_slider.value[0],
    "max_scale_cm": object_size_slider.value[1],
    "rotation_range": tuple(rotation_slider.value),
    "allow_overlap": allow_overlap.value,
    "max_overlap_pct": max_overlap_pct.value,
    "min_object_distance_cm": min_distance_slider.value,
    "edge_avoidance": edge_margin_slider.value,
    "prefer_center": center_bias_slider.value,
    "category_weights": [1.0 / len(categories)] * len(categories),
    "object_blending_strength": blending_slider.value,
    "number_of_objects_per_image": tuple(num_objects_slider.value),
    "num_composite_images": 3,
    "default_bg_width_cm": default_bg_width_cm.value,
    "default_placement_area_pct": (default_placement_area_width.value[0], default_placement_area_height.value[0], default_placement_area_width.value[1], default_placement_area_height.value[1]),
    "output_width": width_input.value,
    "output_height": height_input.value,
    "output_channels": 3,
    "max_labels": max_labels.value,
    "is_video_background": is_video_background.value
}

# Validate parameters
if params["min_scale_cm"] >= params["max_scale_cm"]:
    print(f"Warning: min_scale_cm ({params['min_scale_cm']}) >= max_scale_cm ({params['max_scale_cm']})")
    print("Setting max_scale_cm to min_scale_cm + 1.0")
    params["max_scale_cm"] = params["min_scale_cm"] + 1.0

if params["number_of_objects_per_image"][0] > params["number_of_objects_per_image"][1]:
    print("Warning: Invalid object range, fixing...")
    params["number_of_objects_per_image"] = (1, max(2, params["number_of_objects_per_image"][1]))

# Generate preview
preview_dataset = create_dataset_from_generator(
    overlay=True,
    params=params,
    background_images=background_images,
    bg_widths=bg_widths,
    bg_placement_areas=bg_placement_areas,
    category_datasets=category_datasets,
    category_class_ids=category_class_ids,
    num_samples=3
)

# Display
fig, axes = plt.subplots(3, 2, figsize=(12, 16))
for idx, (images, labels) in enumerate(preview_dataset.take(3)):
    axes[idx, 0].imshow(images[0].numpy())
    axes[idx, 0].set_title(f'Background {idx+1}')
    axes[idx, 0].axis('off')
    
    axes[idx, 1].imshow(images[1].numpy())
    axes[idx, 1].set_title(f'With Objects {idx+1}')
    axes[idx, 1].axis('off')
    
    # Draw bounding boxes (labels are now: exists, xmin, ymin, xmax, ymax, class_id in absolute pixels)
    for label in labels:
        exists, xmin, ymin, xmax, ymax, class_id = label.numpy()
        if exists > 0.5:
            # Coordinates are already in absolute pixels
            box_w = xmax - xmin
            box_h = ymax - ymin
            
            rect = plt.Rectangle((xmin, ymin), box_w, box_h,
                                linewidth=2, edgecolor='lime', facecolor='none')
            axes[idx, 1].add_patch(rect)
            
            # Class label lookup by class_id
            cat_name = next((c['name'] for c in categories if c['class_id'] == int(class_id)), f"Class {int(class_id)}")
            axes[idx, 1].text(
                xmin, ymin - 5, cat_name,
                color='lime', fontsize=10, weight='bold',
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.5)
            )

plt.tight_layout()
plt.show()

## 4. Generate Full Dataset

**Only run this after the preview looks good!**

This will generate all images and may take a while depending on the number of images.

### Output Format:
- `image_XXXXX_temp.png`: Reference image (background only)
- `image_XXXXX_test.png`: Test image (background + objects)
- `image_XXXXX_test.txt`: Labels (bounding boxes)

### Label Format:
Each line in the `.txt` file describes one object: `xmin ymin xmax ymax class_id`
- Coordinates are absolute pixels, written as integers
- `class_id` matches the `class_id` in your category configuration
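
The label files can be read back with a few lines of standard-library Python. This is a minimal sketch for the format described above; `Box` and `read_labels` are hypothetical names introduced for illustration.

```python
from pathlib import Path
from typing import List, NamedTuple


class Box(NamedTuple):
    xmin: int
    ymin: int
    xmax: int
    ymax: int
    class_id: int


def read_labels(label_path) -> List[Box]:
    """Parse one label file with lines of the form: xmin ymin xmax ymax class_id."""
    boxes = []
    for line in Path(label_path).read_text().splitlines():
        fields = line.split()
        if len(fields) != 5:
            continue  # skip blank or malformed lines
        boxes.append(Box(*(int(value) for value in fields)))
    return boxes
```

Each returned `Box` exposes its fields by name, so the box width is `box.xmax - box.xmin`.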

In [None]:
# Update number of images
params["num_composite_images"] = num_images_input.value

# Generate dataset
print(f"Generating {params['num_composite_images']} images...")
full_dataset = create_dataset_from_generator(
    overlay=True,
    params=params,
    background_images=background_images,
    bg_widths=bg_widths,
    bg_placement_areas=bg_placement_areas,
    category_datasets=category_datasets,
    category_class_ids=category_class_ids,
    num_samples=params['num_composite_images']
)

# Save to disk
export_path = Path(output_path)
export_path.mkdir(parents=True, exist_ok=True)

print(f"Saving to {export_path}...")
for idx, (images, labels) in enumerate(full_dataset):
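    # Save images (pixel values are assumed to be floats in [0, 1]; clip to avoid uint8 wrap-around)
    bg_img = np.clip(images[0].numpy() * 255, 0, 255).astype(np.uint8)
    obj_img = np.clip(images[1].numpy() * 255, 0, 255).astype(np.uint8)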
    
    tf.io.write_file(str(export_path / f"image_{idx:05d}_temp.png"), tf.image.encode_png(bg_img))
    tf.io.write_file(str(export_path / f"image_{idx:05d}_test.png"), tf.image.encode_png(obj_img))
    
    # Save labels in corner format: xmin ymin xmax ymax class_id (absolute pixels; class IDs start at 1)
    with open(export_path / f"image_{idx:05d}_test.txt", 'w') as f:
        for label in labels:
            exists, xmin, ymin, xmax, ymax, class_id = label.numpy()
            if exists > 0.5:
                # Write in format: xmin ymin xmax ymax class_id (absolute pixel coordinates)
                f.write(f"{int(xmin)} {int(ymin)} {int(xmax)} {int(ymax)} {int(class_id)}\n")
    
    if (idx + 1) % 50 == 0:
        print(f"  {idx + 1}/{params['num_composite_images']} images saved...")

print(f"✓ Done! {params['num_composite_images']} images saved to {export_path}")
print("Label format: xmin ymin xmax ymax class_id (absolute pixel coordinates, class IDs start at 1)")
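
After generation, a quick consistency check can confirm that every test image has its reference image and label file. The sketch below relies only on the output file naming described in this section; `find_incomplete_samples` is a hypothetical helper, not part of `dataset_generator.py`.

```python
from pathlib import Path


def find_incomplete_samples(dataset_dir):
    """Return (test_image_name, missing_file_names) for samples with missing files."""
    folder = Path(dataset_dir)
    incomplete = []
    for test_png in sorted(folder.glob("image_*_test.png")):
        stem = test_png.name[: -len("_test.png")]  # e.g. "image_00000"
        expected = [folder / f"{stem}_temp.png", folder / f"{stem}_test.txt"]
        missing = [p.name for p in expected if not p.exists()]
        if missing:
            incomplete.append((test_png.name, missing))
    return incomplete
```

An empty list from `find_incomplete_samples("./generated_dataset")` means every sample is complete.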