[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/one-ware/OneAI_demo_datasets/blob/main/dataset_generator/create_dataset.ipynb)


# Synthetic Dataset Generator

Generate synthetic training datasets by placing objects on backgrounds with automatic labeling.
You can create object detection datasets made of single images, or of reference/test image pairs for reference-based detection.


## Quick Start Guide

### Step 1: Choose Your Environment

**Option A - Local Use:**
- Skip cells marked with 🌐 (Google Colab only)
- Use the configuration cell marked with 💻 (Local)

**Option B - Google Colab:**
- Run cells marked with 🌐 (Google Colab)
- Use the configuration cell marked with 🌐 (Colab)

### Step 2: Prepare Your Images

**Object Images (Required):**
- Format: PNG with **transparent background**
- Organize into folders by category (e.g., `/birds/`, `/drones/`)
- Examples: bird.png, drone_01.png, etc.

**Background Images (Required):**
- Format: PNG or JPG
- Place all in one folder
- For video backgrounds: Name sequentially (e.g., `frame_0001.jpg`, `frame_0002.jpg`)
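
Object images without an alpha channel are an easy mistake to miss. The sketch below is not part of `dataset_generator.py`; the helper names are hypothetical. It reads only the PNG header and reports whether the file declares an alpha channel (PNG color type 4 or 6). It is a rough pre-check: palette PNGs can carry transparency in a separate chunk that this check does not inspect, and a file with an alpha channel can still be fully opaque.

```python
import struct
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_has_alpha_channel(header: bytes) -> bool:
    """Return True if the PNG header declares an alpha channel."""
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError("not a PNG file")
    # IHDR data starts at byte 16: width (4), height (4), bit depth (1), color type (1)
    _width, _height, _bit_depth, color_type = struct.unpack(">IIBB", header[16:26])
    return color_type in (4, 6)  # 4 = grayscale + alpha, 6 = RGBA


def list_pngs_without_alpha(folder):
    """Hypothetical helper: names of PNG files in `folder` that declare no alpha channel."""
    names = []
    for path in sorted(Path(folder).glob("*.png")):
        with open(path, "rb") as f:
            if not png_has_alpha_channel(f.read(26)):
                names.append(path.name)
    return names
```

Calling `list_pngs_without_alpha` on each category folder before generating gives a quick list of files to re-export.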

### Step 3: Configure & Generate

1. Set paths in the configuration cell
2. Adjust parameters with sliders
3. Run preview to check results
4. Generate full dataset

## Setup

Import necessary libraries and get images for the dataset.

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import numpy as np
from pathlib import Path
import os
import sys
import cv2

#### 🌐 Google Colab Only: Video Frame Extraction

*(Skip this if working locally or if you already have frame images)*

In [None]:
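def extract_frames(video_path, output_folder, frame_skip=5, max_frames=500):
    """Save every `frame_skip`-th frame of a video as frame_XXXX.jpg.

    Reading stops after `max_frames` frames have been read from the video.
    """
    os.makedirs(output_folder, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    frame_count = 0   # frames read so far
    saved_count = 0   # frames written to disk

    while cap.isOpened() and frame_count < max_frames: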
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_count % frame_skip == 0:
            output_path = os.path.join(output_folder, f"frame_{saved_count:04d}.jpg")
            cv2.imwrite(output_path, frame)
            saved_count += 1
        
        frame_count += 1

    cap.release()
    print(f"Extracted {saved_count} frames to {output_folder}")

video_path = "/content/video.mp4"
output_folder = "/content/output_images"
extract_frames(video_path, output_folder, frame_skip=5)

#### 🌐 Google Colab Only: File Upload

Upload your object and background images to Colab.

In [None]:
# 🌐 Google Colab Only - Upload files
# Skip this cell if working locally

from google.colab import files
uploaded = files.upload()

# After upload, organize files into folders:
# !mkdir -p /content/objects/birds
# !mkdir -p /content/objects/drones
# !mkdir -p /content/backgrounds
# Then move uploaded files to appropriate folders

#### 🌐 Google Colab Only: Download Dataset Generator

Download the required Python module from GitHub.

In [None]:
# 🌐 Google Colab Only
# Local users: Make sure dataset_generator.py is in the same folder as this notebook

import requests

# Replace this with the actual raw URL of your Python file
file_url = "https://raw.githubusercontent.com/one-ware/OneAI_demo_datasets/refs/heads/main/dataset_generator/dataset_generator.py"
file_name = "dataset_generator.py" # Name you want to save the file as

try:
    response = requests.get(file_url)
    response.raise_for_status()

    with open(file_name, 'w') as f:
        f.write(response.text)
    print(f"Successfully downloaded '{file_name}' from {file_url}")

except requests.exceptions.RequestException as e:
    print(f"Error downloading the file: {e}")
    print("Please ensure the URL is correct and accessible (e.g., raw.githubusercontent.com for GitHub files).")

## Import Required Functions

Run this cell to load the dataset generator functions.

In [2]:
from dataset_generator import (
    load_background_metadata, 
    preload_backgrounds, 
    get_image_dataset_from_folder, 
    create_dataset_from_generator
)

## 1. Configure Paths and Categories

### Object Image Requirements:
- Must be PNG format with **transparent background**
- Organize into separate folders by category
- Example structure:
  ```
  /objects/
    /birds/
      bird_01.png
      bird_02.png
    /drones/
      drone_01.png
      drone_02.png
  ```
- Start class IDs at 1 (the `class_id` values in the configuration cell)
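
With many categories, the list of category dictionaries can be derived from the folder layout instead of being typed by hand. The following is a minimal sketch, assuming one subfolder per category under a common root; `build_categories` is a hypothetical helper, not a function of `dataset_generator.py`. The dictionary keys match those used in the configuration cell.

```python
from pathlib import Path


def build_categories(objects_root):
    """Create one category per subfolder that contains at least one PNG."""
    categories = []
    subfolders = sorted(p for p in Path(objects_root).iterdir() if p.is_dir())
    for folder in subfolders:
        if not any(folder.glob("*.png")):
            continue  # skip folders without object images
        categories.append({
            "name": folder.name,
            "folder": str(folder),
            "class_id": len(categories) + 1,  # class IDs start at 1
        })
    return categories
```

Because the subfolders are sorted by name, the same folder layout always produces the same class IDs.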

### Background Image Requirements:
- PNG or JPG format
- All in one folder
- For video backgrounds: Use sequential naming (frame_0001.jpg, frame_0002.jpg, etc.)
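
Gaps in the frame numbering are hard to spot by eye in a large folder. As a rough sketch (the helper name is hypothetical, and how the generator reacts to missing frames is not covered here), the function below extracts the number in front of the file extension and lists every number missing between the lowest and highest frame.

```python
import re

# Trailing digits directly before a .jpg/.jpeg/.png extension
FRAME_NUMBER = re.compile(r"(\d+)\.(?:jpe?g|png)$", re.IGNORECASE)


def find_missing_frames(filenames):
    """Return frame numbers missing between the lowest and highest frame found."""
    numbers = set()
    for name in filenames:
        match = FRAME_NUMBER.search(name)
        if match:
            numbers.add(int(match.group(1)))
    if not numbers:
        return []
    return [n for n in range(min(numbers), max(numbers) + 1) if n not in numbers]
```

For a real folder, pass `os.listdir(background_folder)` as `filenames`.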

In [None]:
# Define your object categories and adjust paths accordingly
categories = [
    {"name": "bird", "folder": "/content/bird", "class_id": 1},
    {"name": "drone", "folder": "/content/drone", "class_id": 2},
]

# Background folder
background_folder = "/content/bg"

# Output path
output_path = "./generated_dataset"

## 2. Choose the Dataset Type

You can generate two types of datasets:

1. **Reference Dataset (Overlay Mode)**:
  - Generates **two images per sample**:
    - **Reference Image**: Background only
    - **Test Image**: Background with objects
  - Useful for tasks like **change detection** or **object tracking**.
  - Set `is_overlay` to `True` to enable this mode.

2. **Single Detection Dataset**:
  - Generates **one image per sample**:
    - **Detection Image**: Background with objects
  - Useful for tasks like **object detection**.
  - Set `is_overlay` to `False` to use this mode.

**How to Choose:**
- Simply set `is_overlay` to `True` or `False` below to switch between these modes.

In [None]:
# Set if reference-test image dataset or single image detection dataset
is_overlay = False  # Set to True for overlay mode, False for single image mode

## 3. Set Parameters

Adjust these parameters to control how objects are placed on the background.
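
Sizes are entered in centimeters, so it helps to estimate how large an object will appear in pixels. The sketch below assumes that the generator scales objects linearly against the background width. This is an assumption: the scaling code in `dataset_generator.py` is not shown here, and `cm_to_pixels` is a hypothetical helper.

```python
def cm_to_pixels(size_cm, bg_width_cm, output_width_px):
    """Approximate on-image size in pixels, ASSUMING linear scaling by background width."""
    if bg_width_cm <= 0:
        raise ValueError("bg_width_cm must be positive")
    return size_cm * output_width_px / bg_width_cm


# Default widget values: objects of 11 to 15 cm, a 150 cm wide background, 1080 px output
smallest_px = cm_to_pixels(11.0, 150.0, 1080)
largest_px = cm_to_pixels(15.0, 150.0, 1080)
print(f"Objects span roughly {smallest_px:.0f} to {largest_px:.0f} px")
```

If the estimated sizes look too small for your detector, increase the object size range or lower the background width.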

In [None]:
widget_config = {
    "Object Parameters": [
        {"type": "FloatRangeSlider", "name": "Object Size (cm)", "key": "object_size", "value": [11.0, 15.0], "min": 1.0, "max": 50.0, "step": 0.5, "description": "Minimum and maximum size of objects in centimeters."},
        {"type": "FloatRangeSlider", "name": "Rotation (°)", "key": "rotation", "value": [-1.0, 1.0], "min": -180.0, "max": 180.0, "step": 1.0, "description": "Range of rotation angles for objects in degrees (e.g., -1° to 1° for slight rotation)."},
        {"type": "IntRangeSlider", "name": "Objects/Image", "key": "num_objects", "value": [1, 4], "min": 1, "max": 20, "step": 1, "description": "Number of objects per image (e.g., 1 to 4 objects per image)."},
        {"type": "FloatSlider", "name": "Min Distance (cm)", "key": "min_distance", "value": 2.0, "min": 0.0, "max": 10.0, "step": 0.5, "description": "Minimum distance between objects in centimeters to prevent crowding."},
        {"type": "Checkbox", "name": "Allow Overlap", "key": "allow_overlap", "value": False, "description": "Whether objects are allowed to overlap."},
        {"type": "FloatSlider", "name": "Max Overlap (%)", "key": "max_overlap", "value": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Maximum allowable overlap between objects as a percentage of their area (0.0 to 1.0)."},
        {"type": "Text", "name": "Category Weights", "key": "category_weights", "value": "1.0, 1.0", "description": "Comma-separated weights for each category (e.g., '1.0, 2.0' means category 2 is twice as likely to appear). Leave empty for equal weights."}
    ],
    "Background Parameters": [
        {"type": "FloatText", "name": "BG Width (cm)", "key": "bg_width", "value": 150.0, "description": "Default width of the background in centimeters."},
        {"type": "FloatRangeSlider", "name": "Placement Area Width (%)", "key": "placement_area_width", "value": [0.05, 0.95], "min": 0.0, "max": 1.0, "step": 0.01, "description": "Placement area as percentages of the background width (e.g., 5% to 95%)."},
        {"type": "FloatRangeSlider", "name": "Placement Area Height (%)", "key": "placement_area_height", "value": [0.05, 0.95], "min": 0.0, "max": 1.0, "step": 0.01, "description": "Placement area as percentages of the background height (e.g., 5% to 95%)."},
        {"type": "Checkbox", "name": "Video Background", "key": "is_video_background", "value": True, "description": "Whether the background is a video with different reference and test frames. Reference images can be one frame before/after. Deactivate for same reference and test background."}
    ],
    "Placement Logic": [
        {"type": "FloatSlider", "name": "Edge Margin", "key": "edge_margin", "value": 0.1, "min": 0.0, "max": 0.5, "step": 0.05, "description": "Margin to keep objects away from the edges (0.0 = no margin, 0.5 = large margin)."},
        {"type": "FloatSlider", "name": "Center Bias", "key": "center_bias", "value": 0.3, "min": 0.0, "max": 1.0, "step": 0.1, "description": "Bias for placing objects closer to the center (0.0 = uniform, 1.0 = strongly favors center)."}
    ],
    "Blending and Visual Quality": [
        {"type": "FloatSlider", "name": "Blending Strength", "key": "object_blending_strength", "value": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, "description": "Strength of blending at object edges (0.0 = sharp edges, 1.0 = heavy blurring)."},
        {"type": "IntText", "name": "Output Channels", "key": "output_channels", "value": 3, "description": "Number of color channels in output images (3 for RGB, 1 for grayscale)."}
    ],
    "Output Parameters": [
        {"type": "IntText", "name": "Width (px)", "key": "output_width", "value": 1080, "description": "Width of the generated images in pixels."},
        {"type": "IntText", "name": "Height (px)", "key": "output_height", "value": 720, "description": "Height of the generated images in pixels."},
        {"type": "IntText", "name": "Total Images", "key": "num_images", "value": 10, "description": "Total number of images to generate in the dataset."},
        {"type": "IntText", "name": "Max Labels", "key": "max_labels", "value": 10, "description": "Maximum number of label boxes per image."}
    ]
}

def create_widgets_from_config(config):
    widgets_dict = {}
    widget_groups = []
    
    for group_name, params in config.items():
        group_widgets = [widgets.HTML(f"<h3 style='margin-bottom: 15px; color: #333;'>{group_name}</h3>")]
        for param in params:
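            widget_type = getattr(widgets, param["type"])
            # Default step (1 for Int widgets, 0.1 otherwise) if not provided
            options = {
                "min": param.get("min"),
                "max": param.get("max"),
                "step": param.get("step", 1 if "Int" in param["type"] else 0.1),
            }
            # Pass only options this widget type supports (e.g., Checkbox and Text have no min/max/step)
            supported = widget_type.class_trait_names()
            options = {k: v for k, v in options.items() if k in supported and v is not None}
            widget = widget_type(
                value=param["value"],
                description=param["name"],
                style={"description_width": "150px"},
                layout=widgets.Layout(width="90%"),
                **options,
            )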
            widgets_dict[param["key"]] = widget
            
            # Add description below the widget with better styling
            description_html = widgets.HTML(
                f"<div style='color: #666; font-size: 12px; margin-left: 160px; margin-top: -5px; margin-bottom: 5px; line-height: 1.4;'>"
                f"{param.get('description', '')}</div>"
            )
            
            # Combine widget and description in a VBox
            widget_with_desc = widgets.VBox(
                [widget, description_html],
                layout=widgets.Layout(margin="0 0 2px 0")
            )
            group_widgets.append(widget_with_desc)
        
        widget_groups.append(widgets.VBox(
            group_widgets, 
            layout=widgets.Layout(
                width="90%", 
                margin="5px 0",
                padding="5px",
                border="1px solid #ddd",
                border_radius="5px"
            )
        ))
    
    return widgets_dict, widget_groups

# Create widgets
widgets_dict, widget_groups = create_widgets_from_config(widget_config)

# Display widget groups
for group in widget_groups:
    display(group)

output = widgets.Output()
save_button = widgets.Button(
    description="Save Settings",
    button_style="success",
    icon="check",
    layout=widgets.Layout(width="200px", height="40px", margin="10px 0")
)

def save_settings(b):
    with output:
        output.clear_output()
        print("✓ Settings saved successfully!")

save_button.on_click(save_settings)

button_container = widgets.VBox(
    [save_button, output],
    layout=widgets.Layout(margin="20px 0")
)
display(button_container)

## 4. Preview

**Generate 3 sample images** to verify your settings before creating the full dataset.

What the preview shows:
- **Overlay mode**: the background image (reference frame) on the left, and the background with objects and bounding boxes (test frame) on the right
- **Single image mode**: one image with objects and bounding boxes

**Check:**
- ✅ Object sizes look realistic
- ✅ Objects are placed in good positions
- ✅ Bounding boxes are accurate
- ✅ Class labels are correct

If something looks wrong, adjust parameters above and run this cell again.
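
Part of the bounding box check can be automated. The sketch below assumes the label row layout used by the preview cell, `(exists, xmin, ymin, xmax, ymax, class_id)` in absolute pixels, and flags boxes that are inverted or extend past the image. `find_invalid_boxes` is a hypothetical helper. Whether the generator clips boxes at the image border is not stated here, so an "outside image" result is a prompt to look closer, not proof of a bug.

```python
def find_invalid_boxes(labels, image_width, image_height):
    """Return (row index, reason) for each existing box that looks wrong."""
    problems = []
    for i, (exists, xmin, ymin, xmax, ymax, _class_id) in enumerate(labels):
        if exists <= 0.5:
            continue  # padding row without an object
        if xmin >= xmax or ymin >= ymax:
            problems.append((i, "empty or inverted box"))
        elif xmin < 0 or ymin < 0 or xmax > image_width or ymax > image_height:
            problems.append((i, "outside image"))
    return problems
```

Inside the preview loop, `labels.numpy().tolist()` would give rows in the shape this function expects.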

In [None]:
# Load data
background_metadata = load_background_metadata(
    [background_folder],
    default_bg_width_cm=widgets_dict["bg_width"].value,
    default_placement_area_pct=(widgets_dict["placement_area_width"].value[0], widgets_dict["placement_area_height"].value[0], widgets_dict["placement_area_width"].value[1], widgets_dict["placement_area_height"].value[1])
)
background_images, bg_widths, bg_placement_areas = preload_backgrounds(background_metadata)

category_datasets = []
category_class_ids = []
for cat in categories:
    ds = get_image_dataset_from_folder(cat["folder"])
    category_datasets.append(ds)
    category_class_ids.append(cat["class_id"])

# Build parameters from widgets and configuration
params = {
    "rotation_range": widgets_dict["rotation"].value,
    "number_of_objects_per_image": widgets_dict["num_objects"].value,
    "min_object_distance_cm": widgets_dict["min_distance"].value,
    "allow_overlap": widgets_dict["allow_overlap"].value,
    "max_overlap_pct": widgets_dict["max_overlap"].value,
    "default_bg_width_cm": widgets_dict["bg_width"].value,
    "default_placement_area_pct": (
        widgets_dict["placement_area_width"].value[0], 
        widgets_dict["placement_area_height"].value[0], 
        widgets_dict["placement_area_width"].value[1], 
        widgets_dict["placement_area_height"].value[1]
    ),
    "is_video_background": widgets_dict["is_video_background"].value,
    "edge_avoidance": widgets_dict["edge_margin"].value,
    "prefer_center": widgets_dict["center_bias"].value,
    "output_width": widgets_dict["output_width"].value,
    "output_height": widgets_dict["output_height"].value,
    "num_images": widgets_dict["num_images"].value,
    "max_labels": widgets_dict["max_labels"].value,
    "is_overlay": is_overlay,
    "object_blending_strength": widgets_dict["object_blending_strength"].value,
    "output_channels": widgets_dict["output_channels"].value
}

# Set min_scale_cm and max_scale_cm based on the object_size widget
params["min_scale_cm"] = widgets_dict["object_size"].value[0]
params["max_scale_cm"] = widgets_dict["object_size"].value[1]

# Validate parameters
if params["min_scale_cm"] >= params["max_scale_cm"]:
    print(f"Warning: min_scale_cm ({params['min_scale_cm']}) >= max_scale_cm ({params['max_scale_cm']})")
    print("Setting max_scale_cm to min_scale_cm + 1.0")
    params["max_scale_cm"] = params["min_scale_cm"] + 1.0

if params["number_of_objects_per_image"][0] > params["number_of_objects_per_image"][1]:
    print(f"Warning: Invalid object range, fixing...")
    params["number_of_objects_per_image"] = (1, max(2, params["number_of_objects_per_image"][1]))

# Parse category_weights from text input
category_weights_str = widgets_dict["category_weights"].value.strip()
if category_weights_str:
    try:
        category_weights = [float(w.strip()) for w in category_weights_str.split(",")]
        if len(category_weights) != len(categories):
            print(f"Warning: Number of weights ({len(category_weights)}) doesn't match number of categories ({len(categories)}). Using equal weights.")
            category_weights = None
    except ValueError:
        print("Warning: Invalid category weights format. Using equal weights.")
        category_weights = None
else:
    category_weights = None

# Check channel correctness
if params["output_channels"] not in [1, 3]:
    print(f"Warning: output_channels ({params['output_channels']}) is not 1 or 3. Setting to 3.")
    params["output_channels"] = 3

params["category_weights"] = category_weights

# Generate preview
preview_dataset = create_dataset_from_generator(
    overlay=params['is_overlay'],
    params=params,
    background_images=background_images,
    bg_widths=bg_widths,
    bg_placement_areas=bg_placement_areas,
    category_datasets=category_datasets,
    category_class_ids=category_class_ids,
    num_samples=3
)

# Display
display_images = 2 if params['is_overlay'] else 1
fig, axes = plt.subplots(3, display_images, figsize=(12, 16))

for idx, (images, labels) in enumerate(preview_dataset.take(3)):
    if params['is_overlay']:
        axes[idx, 0].imshow(images[0].numpy())
        axes[idx, 0].set_title(f'Background {idx+1}')
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(images[1].numpy())
        axes[idx, 1].set_title(f'With Objects {idx+1}')
        axes[idx, 1].axis('off')
    else:
        axes[idx].imshow(images.numpy())
        axes[idx].set_title(f'Sample {idx+1}')
        axes[idx].axis('off')
    
    # Draw bounding boxes (each label row: exists, xmin, ymin, xmax, ymax, class_id; coordinates in absolute pixels)
    for label in labels:
        exists, xmin, ymin, xmax, ymax, class_id = label.numpy()
        if exists > 0.5:
            # Coordinates are already in absolute pixels
            box_w = xmax - xmin
            box_h = ymax - ymin
            
            # Draw on the image that contains the objects
            ax = axes[idx, 1] if params['is_overlay'] else axes[idx]

            rect = plt.Rectangle((xmin, ymin), box_w, box_h,
                                linewidth=2, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)

            # Class label lookup by class_id
            cat_name = next((c['name'] for c in categories if c['class_id'] == int(class_id)), f"Class {int(class_id)}")
            ax.text(
                xmin, ymin - 5, cat_name,
                color='lime', fontsize=10, weight='bold',
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.5)
            )
plt.tight_layout()
plt.show()

## 5. Generate Full Dataset

**Only run this after the preview looks good!**

This will generate all images and may take a while depending on the number of images.

### Output Format:

**Reference Dataset (Overlay Mode)**:
- `image_XXXXX_temp.png`: Reference image (background only)
- `image_XXXXX_test.png`: Test image (background + objects)
- `image_XXXXX_test.txt`: Labels (bounding boxes)

**Single Detection Dataset**:
- `image_XXXXX.png`: Detection image (background + objects)
- `image_XXXXX.txt`: Labels (bounding boxes)
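
The naming scheme above can be used to confirm that a generated folder is complete. This is a small sketch with hypothetical helper names; it mirrors the file names listed in this section and takes `XXXXX` to be the zero-padded, five-digit sample index starting at 0.

```python
def output_file_names(idx, is_overlay):
    """File names expected for sample `idx` in the given mode."""
    stem = f"image_{idx:05d}"
    if is_overlay:
        return [f"{stem}_temp.png", f"{stem}_test.png", f"{stem}_test.txt"]
    return [f"{stem}.png", f"{stem}.txt"]


def missing_files(existing_names, num_images, is_overlay):
    """Expected files for samples 0..num_images-1 that are not in `existing_names`."""
    existing = set(existing_names)
    return [
        name
        for idx in range(num_images)
        for name in output_file_names(idx, is_overlay)
        if name not in existing
    ]
```

After generation, `missing_files(os.listdir(output_path), num_images, is_overlay)` should return an empty list.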

### Label Format (for OneAI):
Each line in a `.txt` file describes one bounding box: `xmin ymin xmax ymax class_id`
- Coordinates in absolute pixels
- class_id matches your category configuration
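
To read the labels back, for example to verify them or convert them to another format, each line can be split into five integers. The following is a minimal sketch; `Box` and `parse_label_text` are hypothetical names, not part of the generator.

```python
from typing import NamedTuple


class Box(NamedTuple):
    xmin: int
    ymin: int
    xmax: int
    ymax: int
    class_id: int


def parse_label_text(text):
    """Parse the contents of one label file into a list of Box tuples."""
    boxes = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # ignore blank lines
        fields = line.split()
        if len(fields) != 5:
            raise ValueError(f"line {line_no}: expected 5 fields, got {len(fields)}")
        boxes.append(Box(*(int(value) for value in fields)))
    return boxes
```

A label file is then parsed with `parse_label_text(Path(label_file).read_text())`.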

In [None]:
# Set the number of images to generate
num_images = 10  # Adjust as needed

In [None]:
# Generate dataset
params["num_composite_images"] = num_images
print(f"Generating {params['num_composite_images']} images...")
full_dataset = create_dataset_from_generator(
    overlay=params['is_overlay'],
    params=params,
    background_images=background_images,
    bg_widths=bg_widths,
    bg_placement_areas=bg_placement_areas,
    category_datasets=category_datasets,
    category_class_ids=category_class_ids,
    num_samples=params['num_composite_images']
)

# Save to disk
export_path = Path(output_path)
export_path.mkdir(parents=True, exist_ok=True)

print(f"Saving to {export_path}...")
for idx, (images, labels) in enumerate(full_dataset):
    # Save images
    if params['is_overlay']:
        bg_img = (images[0].numpy() * 255).astype(np.uint8)
        obj_img = (images[1].numpy() * 255).astype(np.uint8)
        tf.io.write_file(str(export_path / f"image_{idx:05d}_temp.png"), tf.image.encode_png(bg_img))
        tf.io.write_file(str(export_path / f"image_{idx:05d}_test.png"), tf.image.encode_png(obj_img))
    else:
        img = (images.numpy() * 255).astype(np.uint8)
        tf.io.write_file(str(export_path / f"image_{idx:05d}.png"), tf.image.encode_png(img))
    
   
    # Save labels in corner format: xmin ymin xmax ymax class_id (absolute pixel coordinates; class IDs start at 1)
    if params['is_overlay']:
        label_file = export_path / f"image_{idx:05d}_test.txt"
    else:
        label_file = export_path / f"image_{idx:05d}.txt"
    with open(label_file, 'w') as f:
        for label in labels:
            exists, xmin, ymin, xmax, ymax, class_id = label.numpy()
            if exists > 0.5:
                # Write in format: xmin ymin xmax ymax class_id (absolute pixel coordinates)
                f.write(f"{int(xmin)} {int(ymin)} {int(xmax)} {int(ymax)} {int(class_id)}\n")
    
    if (idx + 1) % 50 == 0:
        print(f"  {idx + 1}/{params['num_composite_images']} images saved...")

print(f"✓ Done! {params['num_composite_images']} images saved to {export_path}")
print(f"Label format: xmin ymin xmax ymax class_id (absolute pixel coordinates, 1-based class IDs)")