# LEGO Dataset Preprocessing - Complete Workflow

This notebook:
1. Takes original pile images from `pile_of_augmented_lego_pieces/images`
2. **Creates augmented versions of the pile images** (for better pile detection)
3. Crops individual LEGO pieces from the original pile images
4. Augments the cropped pieces (creates multiple versions)
5. Saves everything back to `pile_of_augmented_lego_pieces/images` and `pile_of_augmented_lego_pieces/labels`

**Final result:** Your original folder will contain:
- Original pile images
- Augmented pile images (with noise, rotation, brightness variations)
- Augmented cropped individual pieces

This gives your model training data for both **pile detection** and **individual piece recognition**!

## Setup: Check GPU and Memory

In [None]:
import tensorflow as tf
print("GPU Available:", tf.config.list_physical_devices('GPU'))

GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:
# Check memory
!free -h

               total        used        free      shared  buff/cache   available
Mem:            52Gi       1.5Gi        48Gi       1.0Mi       3.4Gi        50Gi
Swap:             0B          0B          0B


In [None]:
!nvidia-smi

Thu Oct 23 16:16:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   53C    P8             17W /   72W |       3MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Mount Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive
import os
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## Install Required Packages

In [None]:
!pip install albumentations opencv-python-headless -q

## Import Libraries

In [None]:
import albumentations as A
import cv2
import numpy as np
from pathlib import Path
import shutil
from tqdm import tqdm
import random

## Define Augmentation Pipeline

In [None]:
# Augmentation pipeline for individual LEGO pieces
# BALANCED VERSION - handles mixed quality (clean studio + noisy camera photos)
individual_piece_augmentation = A.Compose([
    # Rotation and flipping
    A.Rotate(limit=180, border_mode=cv2.BORDER_REPLICATE, p=0.8),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),

    # Lighting variations (moderate range for mixed sources)
    A.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.15),   # Slight darkening to moderate brightening
        contrast_limit=0.2,               # Moderate contrast changes
        p=0.6                             # 60% probability
    ),

    # Color variations (moderate)
    A.HueSaturationValue(
        hue_shift_limit=15,
        sat_shift_limit=20,
        val_shift_limit=15,
        p=0.4
    ),

    # Camera effects (MODERATE - won't over-augment noisy images)
    A.OneOf([
        A.MultiplicativeNoise(
            multiplier=(0.88, 1.12),     # Moderate noise
            per_channel=True,
            p=1.0
        ),
        A.ISONoise(
            color_shift=(0.01, 0.06),    # Moderate color shift
            intensity=(0.15, 0.45),      # Moderate intensity
            p=1.0
        ),
        A.GaussNoise(var_limit=(10.0, 40.0), p=1.0),  # Moderate Gaussian
    ], p=0.5),                            # 50% chance - balanced approach

    # Blur (light)
    A.OneOf([
        A.MotionBlur(blur_limit=5, p=1.0),
        A.GaussianBlur(blur_limit=(3, 5), p=1.0),
    ], p=0.3),                            # Only 30% - keeps sharpness

    # Lighting effects (subtle)
    A.RandomToneCurve(scale=0.06, p=0.25),

    # Compression (moderate)
    A.ImageCompression(
        quality_range=(70, 95),           # Good quality range
        p=0.4                             # Moderate probability
    ),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

print("✅ Balanced augmentation pipeline configured!")
print("   - Works for both clean studio AND noisy camera images")
print("   - Moderate augmentations prevent over-processing")
print("   - 50% noise chance - adds variety without overdoing it")
print("   - RandomShadow REMOVED to prevent overly dark pieces")

✅ Balanced augmentation pipeline configured!
   - Works for both clean studio AND noisy camera images
   - Moderate augmentations prevent over-processing
   - 50% noise chance - adds variety without overdoing it
   - RandomShadow REMOVED to prevent overly dark pieces




## Define Helper Functions

In [None]:
def parse_yolo_label(label_path):
    """Parse YOLO format label file with validation"""
    bboxes = []
    class_ids = []

    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 5:
                class_id = int(parts[0])
                x_center, y_center, width, height = map(float, parts[1:5])

                # Clip coordinates to valid range [0, 1]
                x_center = np.clip(x_center, 0.0, 1.0)
                y_center = np.clip(y_center, 0.0, 1.0)
                width = np.clip(width, 0.0, 1.0)
                height = np.clip(height, 0.0, 1.0)

                # Skip invalid bounding boxes (zero or tiny width/height)
                if width <= 0.001 or height <= 0.001:
                    continue

                # Ensure bbox doesn't go outside image bounds
                x_center = np.clip(x_center, width/2, 1.0 - width/2)
                y_center = np.clip(y_center, height/2, 1.0 - height/2)

                bboxes.append([x_center, y_center, width, height])
                class_ids.append(class_id)

    return bboxes, class_ids

def yolo_to_pixel_coords(bbox, img_width, img_height, padding=20):
    """Convert YOLO bbox to pixel coordinates with padding"""
    x_center, y_center, width, height = bbox

    # Convert to pixels
    x_center_px = x_center * img_width
    y_center_px = y_center * img_height
    width_px = width * img_width
    height_px = height * img_height

    # Add padding
    x1 = max(0, int(x_center_px - width_px/2 - padding))
    y1 = max(0, int(y_center_px - height_px/2 - padding))
    x2 = min(img_width, int(x_center_px + width_px/2 + padding))
    y2 = min(img_height, int(y_center_px + height_px/2 + padding))

    return x1, y1, x2, y2

print("✅ Helper functions defined!")

✅ Helper functions defined!
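
To make the coordinate handling concrete, here is a small worked example in plain Python. It repeats the arithmetic of `yolo_to_pixel_coords` so the snippet runs on its own, and adds a hypothetical helper, `bbox_in_crop` (not part of the original notebook), that re-expresses the box relative to the padded crop. The image size and box values are made up for illustration.

```python
def yolo_to_pixel_coords(bbox, img_width, img_height, padding=20):
    """Same arithmetic as the notebook helper: YOLO box -> padded pixel corners."""
    x_center, y_center, width, height = bbox
    x_center_px = x_center * img_width
    y_center_px = y_center * img_height
    width_px = width * img_width
    height_px = height * img_height
    x1 = max(0, int(x_center_px - width_px / 2 - padding))
    y1 = max(0, int(y_center_px - height_px / 2 - padding))
    x2 = min(img_width, int(x_center_px + width_px / 2 + padding))
    y2 = min(img_height, int(y_center_px + height_px / 2 + padding))
    return x1, y1, x2, y2


def bbox_in_crop(bbox, img_width, img_height, crop):
    """Hypothetical helper: the original box in YOLO format, relative to the crop."""
    x1, y1, x2, y2 = crop
    crop_w, crop_h = x2 - x1, y2 - y1
    return (
        (bbox[0] * img_width - x1) / crop_w,   # x center inside the crop
        (bbox[1] * img_height - y1) / crop_h,  # y center inside the crop
        bbox[2] * img_width / crop_w,          # width as a fraction of the crop
        bbox[3] * img_height / crop_h,         # height as a fraction of the crop
    )


# Illustrative values: a 1000x800 image with a box close to the left edge.
bbox = (0.05, 0.5, 0.08, 0.125)
crop = yolo_to_pixel_coords(bbox, 1000, 800, padding=20)
rel = bbox_in_crop(bbox, 1000, 800, crop)
print("crop corners:", crop)
print("box relative to crop:", [round(v, 3) for v in rel])
```

In this example the left padding is clipped at the image border, so the piece is no longer horizontally centered in the crop: its relative x center comes out below 0.5, while the y center stays at 0.5.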


## Main Processing Function

In [None]:
def process_lego_dataset(base_path, num_augmentations=5, num_pile_augmentations=3, padding=20, split=''):
    """
    Complete pipeline:
    1. Augment original pile images
    2. Crop individual pieces from original piles
    3. Augment the cropped pieces
    4. Save everything back to original folder

    Args:
        base_path: Path to pile_of_augmented_lego_pieces folder
        num_augmentations: Number of augmented versions to create per cropped piece
        num_pile_augmentations: Number of augmented versions to create per pile image
        padding: Pixels to add around each piece when cropping
        split: Subfolder name ('' for root, 'train' for train folder)
    """
    base_path = Path(base_path)

    # Define directories
    image_dir = base_path / 'images' / split if split else base_path / 'images'
    label_dir = base_path / 'labels' / split if split else base_path / 'labels'

    # Verify directories exist
    if not image_dir.exists():
        print(f"❌ Image directory not found: {image_dir}")
        return

    if not label_dir.exists():
        print(f"❌ Label directory not found: {label_dir}")
        return

    # Get all original pile images (not augmented ones we'll create)
    # Only process original images without '_pile_aug' or '_piece' in the name,
    # and skip macOS resource-fork files ('._*'), which OpenCV cannot read
    all_images = list(image_dir.glob('*.jpg')) + list(image_dir.glob('*.png'))
    image_files = [
        img for img in all_images
        if '_pile_aug' not in img.stem
        and '_piece' not in img.stem
        and not img.name.startswith('._')
    ]

    if not image_files:
        print(f"❌ No original images found in {image_dir}")
        return

    print(f"\n{'='*60}")
    print(f"LEGO Dataset Processing Pipeline")
    print(f"{'='*60}")
    print(f"Source folder: {base_path}")
    print(f"Original pile images: {len(image_files)}")
    print(f"Augmentations per pile image: {num_pile_augmentations}")
    print(f"Augmentations per cropped piece: {num_augmentations}")
    print(f"Padding around pieces: {padding}px")
    print(f"{'='*60}\n")

    total_pieces_created = 0
    total_augmentations_created = 0
    total_pile_augmentations_created = 0

    # ============================================
    # STEP 1: Augment Original Pile Images
    # ============================================
    print("\n📸 STEP 1: Augmenting original pile images...\n")

    for img_path in tqdm(image_files, desc="Augmenting pile images"):
        # Read image
        image = cv2.imread(str(img_path))
        if image is None:
            print(f"⚠️ Could not read image: {img_path.name}")
            continue

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Read corresponding label
        label_path = label_dir / f"{img_path.stem}.txt"
        if not label_path.exists():
            print(f"⚠️ No label found for: {img_path.name}")
            continue

        bboxes, class_ids = parse_yolo_label(label_path)

        if not bboxes:
            print(f"⚠️ No bounding boxes in label: {label_path.name}")
            continue

        # Create augmented versions of the pile image
        for pile_aug_idx in range(num_pile_augmentations):
            try:
                # Apply augmentation to the full pile image
                augmented = individual_piece_augmentation(
                    image=image_rgb,
                    bboxes=bboxes,
                    class_labels=class_ids
                )

                aug_image = augmented['image']
                aug_bboxes = augmented['bboxes']
                aug_labels = augmented['class_labels']

                # Skip if bboxes were lost during augmentation
                if not aug_bboxes:
                    continue

                # Generate unique filename for pile augmentation
                pile_aug_name = f"{img_path.stem}_pile_aug{pile_aug_idx:02d}"

                # Save augmented pile image
                output_img_path = image_dir / f"{pile_aug_name}{img_path.suffix}"
                cv2.imwrite(str(output_img_path), cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR))

                # Save augmented pile label
                output_label_path = label_dir / f"{pile_aug_name}.txt"
                with open(output_label_path, 'w') as f:
                    for bbox_aug, label_aug in zip(aug_bboxes, aug_labels):
                        f.write(f"{label_aug} {bbox_aug[0]} {bbox_aug[1]} {bbox_aug[2]} {bbox_aug[3]}\n")

                total_pile_augmentations_created += 1

            except Exception as e:
                print(f"⚠️ Pile augmentation failed for {img_path.stem}_pile_aug{pile_aug_idx:02d}: {e}")
                continue

    print(f"\n✅ Created {total_pile_augmentations_created} augmented pile images\n")

    # ============================================
    # STEP 2: Crop and Augment Individual Pieces
    # ============================================
    print("\n✂️ STEP 2: Cropping and augmenting individual pieces...\n")

    # Process each original pile image for cropping
    for img_idx, img_path in enumerate(tqdm(image_files, desc="Processing pile images")):
        # Read image
        image = cv2.imread(str(img_path))
        if image is None:
            print(f"⚠️ Could not read image: {img_path.name}")
            continue

        img_height, img_width = image.shape[:2]

        # Read corresponding label
        label_path = label_dir / f"{img_path.stem}.txt"
        if not label_path.exists():
            print(f"⚠️ No label found for: {img_path.name}")
            continue

        bboxes, class_ids = parse_yolo_label(label_path)

        if not bboxes:
            print(f"⚠️ No bounding boxes in label: {label_path.name}")
            continue

        # Crop each piece from this image
        for piece_idx, (bbox, class_id) in enumerate(zip(bboxes, class_ids)):
            x1, y1, x2, y2 = yolo_to_pixel_coords(bbox, img_width, img_height, padding)

            # Crop the piece
            cropped = image[y1:y2, x1:x2]

            if cropped.size == 0:
                continue

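            # Calculate new bbox for cropped piece. The piece is only centered
            # when the padding was not clipped at an image edge, so derive the
            # center from the original pixel coordinates instead of assuming 0.5.
            crop_width = x2 - x1
            crop_height = y2 - y1
            orig_width_px = bbox[2] * img_width
            orig_height_px = bbox[3] * img_height

            new_x_center = (bbox[0] * img_width - x1) / crop_width
            new_y_center = (bbox[1] * img_height - y1) / crop_height
            new_width = min(1.0, orig_width_px / crop_width)
            new_height = min(1.0, orig_height_px / crop_height)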

            # Convert cropped image to RGB for augmentation
            cropped_rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)

            # Create augmented versions of this cropped piece
            for aug_idx in range(num_augmentations):
                try:
                    # Apply augmentation
                    augmented = individual_piece_augmentation(
                        image=cropped_rgb,
                        bboxes=[[new_x_center, new_y_center, new_width, new_height]],
                        class_labels=[class_id]
                    )

                    aug_image = augmented['image']
                    aug_bboxes = augmented['bboxes']
                    aug_labels = augmented['class_labels']

                    # Skip if bbox was lost during augmentation
                    if not aug_bboxes:
                        continue

                    # Generate unique filename
                    piece_name = f"{img_path.stem}_piece{piece_idx:04d}_aug{aug_idx:02d}"

                    # Save augmented image
                    output_img_path = image_dir / f"{piece_name}{img_path.suffix}"
                    cv2.imwrite(str(output_img_path), cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR))

                    # Save augmented label
                    output_label_path = label_dir / f"{piece_name}.txt"
                    with open(output_label_path, 'w') as f:
                        for bbox_aug, label_aug in zip(aug_bboxes, aug_labels):
                            f.write(f"{label_aug} {bbox_aug[0]} {bbox_aug[1]} {bbox_aug[2]} {bbox_aug[3]}\n")

                    total_augmentations_created += 1

                except Exception as e:
                    print(f"⚠️ Augmentation failed for {img_path.stem}_piece{piece_idx}_aug{aug_idx}: {e}")
                    continue

            total_pieces_created += 1

    # Print summary
    print(f"\n{'='*60}")
    print(f"✅ Processing Complete!")
    print(f"{'='*60}")
    print(f"Original pile images: {len(image_files)}")
    print(f"Augmented pile images: {total_pile_augmentations_created}")
    print(f"Cropped pieces: {total_pieces_created}")
    print(f"Augmented cropped pieces: {total_augmentations_created}")
    print(f"\n📊 Total images in dataset now:")
    print(f"   Original piles: {len(image_files)}")
    print(f"   + Augmented piles: {total_pile_augmentations_created}")
    print(f"   + Augmented cropped pieces: {total_augmentations_created}")
    print(f"   = TOTAL: {len(image_files) + total_pile_augmentations_created + total_augmentations_created}")
    print(f"\n📁 All saved to: {base_path}")
    print(f"   Images: {image_dir}")
    print(f"   Labels: {label_dir}")
    print(f"{'='*60}\n")

print("✅ Main processing function defined!")

✅ Main processing function defined!


## Run the Complete Pipeline

This will:
1. **Augment the original pile images** (creates 3 versions with noise, rotation, brightness)
2. **Crop individual LEGO pieces** from the original pile images
3. **Augment the cropped pieces** (creates 5 versions per piece)
4. Save everything back to the same folder

**Your final dataset will contain:**
- Original pile images (for detecting pieces in cluttered scenes)
- Augmented pile images (more training variety for pile detection)
- Augmented individual pieces (for piece recognition)

**Note:** This will take some time depending on how many pieces are in your images!
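
As a rough planning aid, the number of images the pipeline writes follows directly from the two augmentation counts. The sketch below uses a hypothetical helper, `estimate_output_images`, and an assumed average of 20 labeled pieces per pile image. That figure is purely illustrative, not a measurement from this dataset, and the estimate is an upper bound because it assumes no augmentation is skipped.

```python
def estimate_output_images(num_piles, avg_pieces_per_pile,
                           num_pile_augmentations=3, num_augmentations=5):
    """Upper-bound estimate of image counts after processing (hypothetical helper)."""
    pile_augs = num_piles * num_pile_augmentations
    piece_augs = num_piles * avg_pieces_per_pile * num_augmentations
    return {
        "original_piles": num_piles,
        "augmented_piles": pile_augs,
        "augmented_pieces": piece_augs,
        "total": num_piles + pile_augs + piece_augs,
    }


# avg_pieces_per_pile=20 is an assumption chosen only for illustration.
estimate = estimate_output_images(num_piles=2000, avg_pieces_per_pile=20)
for key, value in estimate.items():
    print(f"{key}: {value:,}")
```

Because the cropped-piece term scales with the number of labeled pieces, it tends to dominate both the file count and the run time.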

In [None]:
# Set your base path
base_path = '/content/drive/MyDrive/lego-training/pile_of_augmented_lego_pieces'

# Run the complete pipeline
process_lego_dataset(
    base_path=base_path,
    num_augmentations=5,          # Create 5 augmented versions per cropped piece
    num_pile_augmentations=3,     # Create 3 augmented versions per pile image
    padding=20,                   # 20 pixels of padding around each piece
    split=''                      # Use '' if no train subfolder, 'train' if you have one
)


LEGO Dataset Processing Pipeline
Source folder: /Volumes/lego-Images/pile_of_augmented_lego_pieces
Original pile images: 2000
Augmentations per pile image: 3
Augmentations per cropped piece: 5
Padding around pieces: 20px


📸 STEP 1: Augmenting original pile images...



Augmenting pile images:  30%|██▉       | 597/2000 [31:21<1:12:08,  3.09s/it]

⚠️ Could not read image: 550.png


Augmenting pile images:  43%|████▎     | 867/2000 [43:32<32:41,  1.73s/it]

⚠️ Could not read image: 1001.png


Augmenting pile images:  50%|████▉     | 999/2000 [49:38<1:01:36,  3.69s/it]

⚠️ Could not read image: ._1646.png


Augmenting pile images:  87%|████████▋ | 1732/2000 [1:46:59<19:12,  4.30s/it]

⚠️ Could not read image: 1646.png


Augmenting pile images:  93%|█████████▎| 1864/2000 [1:59:01<09:41,  4.27s/it]

⚠️ Could not read image: 1094.png


Augmenting pile images:  97%|█████████▋| 1931/2000 [2:05:11<06:15,  5.44s/it]

⚠️ Could not read image: 1334.png


Augmenting pile images: 100%|██████████| 2000/2000 [2:11:14<00:00,  3.94s/it]



✅ Created 5982 augmented pile images


✂️ STEP 2: Cropping and augmenting individual pieces...



Processing pile images:  30%|██▉       | 598/2000 [5:26:33<10:08:32, 26.04s/it]

⚠️ Could not read image: 550.png


Processing pile images:  43%|████▎     | 867/2000 [8:16:01<9:10:12, 29.14s/it]

⚠️ Could not read image: 1001.png


Processing pile images:  50%|█████     | 1000/2000 [9:43:53<8:03:11, 28.99s/it]

⚠️ Could not read image: ._1646.png


Processing pile images:  76%|███████▌  | 1517/2000 [15:55:48<5:04:19, 37.80s/it]


KeyboardInterrupt: 

## Verify Results

In [None]:
# Verify the output
base_path = Path('/content/drive/MyDrive/lego-training/pile_of_augmented_lego_pieces')
image_dir = base_path / 'images'
label_dir = base_path / 'labels'

print(f"\n📊 Dataset Summary")
print(f"{'='*60}")

if image_dir.exists():
    all_images = list(image_dir.glob('*.jpg')) + list(image_dir.glob('*.png'))
    # Classify by the suffixes written by process_lego_dataset
    piece_images = [img for img in all_images if '_piece' in img.stem]
    pile_aug_images = [img for img in all_images if '_pile_aug' in img.stem]
    original_images = [img for img in all_images
                       if '_piece' not in img.stem and '_pile_aug' not in img.stem]

    print(f"Total images: {len(all_images)}")
    print(f"  - Original pile images: {len(original_images)}")
    print(f"  - Augmented pile images: {len(pile_aug_images)}")
    print(f"  - Augmented cropped pieces: {len(piece_images)}")
else:
    print("❌ Image directory not found")

if label_dir.exists():
    all_labels = list(label_dir.glob('*.txt'))
    print(f"\nTotal labels: {len(all_labels)}")
else:
    print("❌ Label directory not found")

print(f"{'='*60}\n")

## Proper Split

In [None]:
import os
import shutil
from pathlib import Path
import random

print("🔄 SPLIT USING REAL DRIVE PATH (NO SYMLINK)\n")
print("="*60)

# USE REAL DRIVE PATH ONLY!
dataset_path = Path('/content/drive/MyDrive/lego-training/pile_of_augmented_lego_pieces')
images_dir = dataset_path / 'images'
labels_dir = dataset_path / 'labels'

# Step 1: Clean up any existing train/val folders
print("🧹 Cleaning up old train/val folders...")
for folder in ['train', 'val']:
    for parent in [images_dir, labels_dir]:
        folder_path = parent / folder
        if folder_path.exists():
            try:
                shutil.rmtree(str(folder_path))
                print(f"  ✅ Deleted {folder_path}")
            except Exception as e:
                print(f"  ⚠️  {folder_path}: {e}")

# Step 2: Verify all files are in root
print("\n📊 Checking root files...")
image_files = list(images_dir.glob('*.jpg')) + list(images_dir.glob('*.png')) + list(images_dir.glob('*.jpeg'))
label_files = list(labels_dir.glob('*.txt'))

print(f"  Images: {len(image_files):,}")
print(f"  Labels: {len(label_files):,}")

if len(image_files) < 230000:
    print("\n❌ Not enough files! Aborting.")
else:
    # Step 3: Split
    print("\n📦 Splitting...")
    random.seed(42)
    random.shuffle(image_files)

    split_idx = int(len(image_files) * 0.8)
    train_imgs = image_files[:split_idx]
    val_imgs = image_files[split_idx:]

    print(f"  Train: {len(train_imgs):,}")
    print(f"  Val: {len(val_imgs):,}")

    # Create fresh directories
    (images_dir / 'train').mkdir()
    (images_dir / 'val').mkdir()
    (labels_dir / 'train').mkdir()
    (labels_dir / 'val').mkdir()

    # Move training
    print(f"\n🚚 Moving training...")
    for i, img in enumerate(train_imgs):
        if i % 10000 == 0:
            print(f"  {i:,}")
        shutil.move(str(img), str(images_dir / 'train' / img.name))
        lbl = labels_dir / f"{img.stem}.txt"
        if lbl.exists():
            shutil.move(str(lbl), str(labels_dir / 'train' / lbl.name))

    # Move validation
    print(f"\n🚚 Moving validation...")
    for i, img in enumerate(val_imgs):
        if i % 10000 == 0:
            print(f"  {i:,}")
        shutil.move(str(img), str(images_dir / 'val' / img.name))
        lbl = labels_dir / f"{img.stem}.txt"
        if lbl.exists():
            shutil.move(str(lbl), str(labels_dir / 'val' / lbl.name))

    # Verify
    print("\n" + "="*60)
    t_i = len(list((images_dir / 'train').glob('*')))
    v_i = len(list((images_dir / 'val').glob('*')))
    t_l = len(list((labels_dir / 'train').glob('*.txt')))
    v_l = len(list((labels_dir / 'val').glob('*.txt')))

    print(f"✅ train: {t_i:,} imgs / {t_l:,} lbls")
    print(f"✅ val:   {v_i:,} imgs / {v_l:,} lbls")

    if v_i > 40000 and v_l > 40000:
        print("\n✅✅✅ SUCCESS! ✅✅✅")
    else:
        print("\n❌ PROBLEM!")
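
One caveat with a purely random shuffle: files derived from the same pile image (`*_pile_augNN`, `*_pieceNNNN_augNN`) can land on both sides of the split. A possible alternative, sketched below under the assumption that filenames follow the naming used by `process_lego_dataset`, is to split by source image instead. `source_stem` and `group_split` are hypothetical helpers, not part of the original notebook.

```python
import random
import re
from collections import defaultdict

# Suffixes written by process_lego_dataset: _pile_augNN and _pieceNNNN_augNN.
_DERIVED_SUFFIX = re.compile(r"(_pile_aug\d+|_piece\d+_aug\d+)$")


def source_stem(stem):
    """Strip the augmentation suffix to recover the original pile image stem."""
    return _DERIVED_SUFFIX.sub("", stem)


def group_split(stems, train_ratio=0.8, seed=42):
    """Split stems so that all files from one source image share a split."""
    groups = defaultdict(list)
    for stem in stems:
        groups[source_stem(stem)].append(stem)

    sources = sorted(groups)  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(sources)
    cut = int(len(sources) * train_ratio)

    train = [s for src in sources[:cut] for s in groups[src]]
    val = [s for src in sources[cut:] for s in groups[src]]
    return train, val


# Toy example: five source images, each with one pile and one piece augmentation.
stems = []
for i in range(5):
    stems += [f"{i}", f"{i}_pile_aug00", f"{i}_piece0000_aug00"]

train, val = group_split(stems)
shared = {source_stem(s) for s in train} & {source_stem(s) for s in val}
print(f"train: {len(train)}, val: {len(val)}, shared sources: {len(shared)}")
```

The resulting split ratio is exact only at the level of source images; the per-file ratio will vary with how many pieces each pile contains.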

## Upload to GCS

In [None]:
# In Colab
from google.colab import auth
auth.authenticate_user()

# Set project
!gcloud config set project lego-training

# Upload from Drive to GCS (one-time, 20-30 min)
!gsutil -m cp -r /content/drive/MyDrive/lego-training/pile_of_augmented_lego_pieces gs://lego-dataset-di/

print("✅ Uploaded to GCS!")

## Download Dataset from GCS to Colab Local Storage

In [None]:

print("⬇️ Downloading dataset from GCS to Colab local storage\n")
print("="*60)

from google.colab import auth
auth.authenticate_user()

# Create directory structure FIRST
print("📁 Creating directory structure...")
!mkdir -p /content/dataset/pile_of_augmented_lego_pieces/images
!mkdir -p /content/dataset/pile_of_augmented_lego_pieces/labels

# Download with trailing slashes (tells gsutil these are directories)
print("\n📦 Downloading 232K files from GCS...")
print("   This may take 3-5 minutes...\n")

!gsutil -m rsync -r gs://lego-dataset-di/pile_of_augmented_lego_pieces/ /content/dataset/pile_of_augmented_lego_pieces/

print("\n✅ Download complete!")

# Verification
import os
img_count = len([f for f in os.listdir('/content/dataset/pile_of_augmented_lego_pieces/images') if f.endswith('.png')])
lbl_count = len([f for f in os.listdir('/content/dataset/pile_of_augmented_lego_pieces/labels') if f.endswith('.txt')])
print(f"\n📊 Downloaded: {img_count:,} images, {lbl_count:,} labels")
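
The check above counts only `.png` files directly inside `images/`. If the dataset also contains other extensions, or has already been moved into `train`/`val` subfolders, a recursive count may be more informative. A minimal sketch, using a hypothetical helper `count_files` and the same local path as above:

```python
from collections import Counter
from pathlib import Path


def count_files(root, extensions):
    """Count files under root (recursively) per lowercase extension."""
    counts = Counter()
    root = Path(root)
    if not root.exists():
        return counts
    for path in root.rglob("*"):
        # Skip macOS resource-fork files such as '._1646.png'.
        if path.is_file() and not path.name.startswith("._"):
            suffix = path.suffix.lower()
            if suffix in extensions:
                counts[suffix] += 1
    return counts


dataset = Path("/content/dataset/pile_of_augmented_lego_pieces")
print("images:", dict(count_files(dataset / "images", {".png", ".jpg", ".jpeg"})))
print("labels:", dict(count_files(dataset / "labels", {".txt"})))
```

An empty result for `images` simply means no matching files were found under that path, which is worth checking before starting a training run.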


## Cell 1: Configure Paths

In [None]:
# Update these paths to match your dataset location
SOURCE_IMAGES = '/content/dataset/pile_of_augmented_lego_pieces/images'
SOURCE_LABELS = '/content/dataset/pile_of_augmented_lego_pieces/labels'
OUTPUT_DIR = '/content/dataset/lego_split'

# Split ratio (0.8 = 80% train, 20% validation)
TRAIN_RATIO = 0.8

print(f"✅ Source images: {SOURCE_IMAGES}")
print(f"✅ Source labels: {SOURCE_LABELS}")
print(f"✅ Output directory: {OUTPUT_DIR}")
print(f"✅ Train/Val split: {TRAIN_RATIO*100:.0f}% / {(1-TRAIN_RATIO)*100:.0f}%")

## Cell 2: Dataset Splitter Function

In [None]:
import os
import shutil
from pathlib import Path
import random
from collections import defaultdict

def split_yolo_dataset_stratified(source_images_dir, source_labels_dir, output_dir, train_ratio=0.8, seed=42):
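    """
    Split a YOLO dataset into train/val sets with a STRATIFIED class distribution.
    Aims to have every class appear in both the train and validation sets
    (classes missing from either set are reported in the verification step).
    NO DATA LEAKAGE at the file level - each image appears in only one set.
    """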

    random.seed(seed)

    # Convert to Path objects
    source_images = Path(source_images_dir)
    source_labels = Path(source_labels_dir)
    output = Path(output_dir)

    # Create output directory structure
    train_images = output / 'train' / 'images'
    train_labels = output / 'train' / 'labels'
    val_images = output / 'val' / 'images'
    val_labels = output / 'val' / 'labels'

    for folder in [train_images, train_labels, val_images, val_labels]:
        folder.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("🎯 YOLO Dataset Splitter for LEGO Pieces (STRATIFIED)")
    print("=" * 80)

    # Get all image files
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    all_images = [f for f in source_images.iterdir()
                  if f.suffix.lower() in image_extensions]

    print(f"\n📁 Source: {source_images}")
    print(f"📊 Total images found: {len(all_images)}")

    # Map each image to its classes
    image_to_classes = {}
    class_to_images = defaultdict(set)
    valid_pairs = []
    missing_labels = []

    for img_path in all_images:
        label_path = source_labels / f"{img_path.stem}.txt"

        if label_path.exists():
            valid_pairs.append((img_path, label_path))
            image_classes = set()

            # Read label file to track class distribution
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if parts:
                        class_id = int(parts[0])
                        image_classes.add(class_id)
                        class_to_images[class_id].add(img_path.stem)

            image_to_classes[img_path.stem] = image_classes
        else:
            missing_labels.append(img_path.name)

    print(f"✅ Valid image-label pairs: {len(valid_pairs)}")

    if missing_labels:
        print(f"⚠️  Missing labels for {len(missing_labels)} images")

    # LEGO class names
    class_names = {
        0: "Plates Special", 1: "Bars, Ladders and Fences", 2: "Bricks Special",
        3: "Plates", 4: "Technic Pins", 5: "Bricks Curved", 6: "Tiles",
        7: "Tiles Round and Curved", 8: "Technic Connectors", 9: "Technic Special",
        10: "Projectiles / Launchers", 11: "Tiles Special", 12: "Bricks Sloped",
        13: "Bricks", 14: "Hinges, Arms and Turntables", 15: "Plates Angled",
        16: "Plants and Animals", 17: "Plates Round Curved and Dishes",
        18: "Bricks Round and Cones", 19: "Technic Bricks", 20: "Technic Axles",
        21: "Technic Beams", 22: "Technic Bushes", 23: "Minifig Accessories",
        24: "Panels", 25: "Windows and Doors", 26: "Bricks Wedged",
        27: "Duplo, Quatro and Primo", 28: "Supports, Girders and Cranes",
        29: "Technic Beams Special", 30: "Transportation - Land", 31: "Technic Gears",
        32: "Technic Panels", 33: "Technic Steering, Suspension and Engine",
        34: "Wheels and Tyres", 35: "Large Buildable Figures", 36: "Pneumatics",
        37: "String, Bands and Reels", 38: "Transportation - Sea and Air",
        39: "Electronics", 40: "Energy Effects", 41: "Rock", 42: "Minifig Headwear",
        43: "Windscreens and Fuselage", 44: "Containers", 45: "Tools",
        46: "Minifigs", 47: "Minifig Lower Body", 48: "Baseplates",
        49: "Minifig Upper Body", 50: "Flags, Signs, Plastics and Cloth",
        51: "Tubes and Hoses"
    }

    # Display class distribution
    print(f"\n📊 Class Distribution:")
    sorted_classes = sorted(class_to_images.items(), key=lambda x: len(x[1]), reverse=True)
    for class_id, images in sorted_classes:
        count = len(images)
        name = class_names.get(class_id, "Unknown")
        print(f"   Class {class_id:2d}: {count:5d} images - {name}")

    # STRATIFIED SPLIT: Assign each image to train or val, ensuring class balance
    # FIXED: Prevents data leakage by assigning each image only once
    train_set = set()
    val_set = set()

    print(f"\n🔄 Performing stratified split (preventing data leakage)...")

    # Sort classes by size (smallest first) to handle rare classes carefully
    sorted_class_list = sorted(class_to_images.items(), key=lambda x: len(x[1]))

    for class_id, image_stems in sorted_class_list:
        images_list = list(image_stems)
        random.shuffle(images_list)

        # Separate unassigned images from already assigned ones
        unassigned = [img for img in images_list if img not in train_set and img not in val_set]

        if not unassigned:
            continue  # All images of this class already assigned

        # Calculate target split for unassigned images
        target_train = max(1, int(len(unassigned) * train_ratio))
        target_val = len(unassigned) - target_train

        # Ensure at least 1 image in validation if possible
        if target_val == 0 and len(unassigned) > 1:
            target_train = len(unassigned) - 1
            target_val = 1

        # Assign unassigned images
        train_set.update(unassigned[:target_train])
        val_set.update(unassigned[target_train:])

    # Verify no overlap
    overlap = train_set & val_set
    if overlap:
        print(f"❌ ERROR: {len(overlap)} images in both sets! (This shouldn't happen)")
        print(f"   Example: {list(overlap)[:5]}")
    else:
        print(f"✅ No data leakage: {len(train_set)} train, {len(val_set)} val (no overlap)")

    # Convert back to paths
    train_pairs = [(img_path, source_labels / f"{img_path.stem}.txt")
                   for img_path, _ in valid_pairs if img_path.stem in train_set]
    val_pairs = [(img_path, source_labels / f"{img_path.stem}.txt")
                 for img_path, _ in valid_pairs if img_path.stem in val_set]

    print(f"\n📊 Split Summary:")
    print(f"   Training set: {len(train_pairs)} images ({len(train_pairs)/len(valid_pairs)*100:.1f}%)")
    print(f"   Validation set: {len(val_pairs)} images ({len(val_pairs)/len(valid_pairs)*100:.1f}%)")

    # Copy files to train folder
    print(f"\n📋 Copying training files...")
    for i, (img_path, label_path) in enumerate(train_pairs):
        shutil.copy2(img_path, train_images / img_path.name)
        shutil.copy2(label_path, train_labels / label_path.name)
        if (i + 1) % 1000 == 0:
            print(f"   Copied {i + 1}/{len(train_pairs)} train files...")

    # Copy files to val folder
    print(f"\n📋 Copying validation files...")
    for i, (img_path, label_path) in enumerate(val_pairs):
        shutil.copy2(img_path, val_images / img_path.name)
        shutil.copy2(label_path, val_labels / label_path.name)
        if (i + 1) % 1000 == 0:
            print(f"   Copied {i + 1}/{len(val_pairs)} val files...")

    # Verify class distribution in splits
    print(f"\n✅ Verifying stratified split...")
    train_classes = set()
    val_classes = set()
    train_class_counts = defaultdict(int)
    val_class_counts = defaultdict(int)

    for _, label_path in train_pairs:
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if parts:
                    cls = int(parts[0])
                    train_classes.add(cls)
                    train_class_counts[cls] += 1

    for _, label_path in val_pairs:
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if parts:
                    cls = int(parts[0])
                    val_classes.add(cls)
                    val_class_counts[cls] += 1

    print(f"   Training set has {len(train_classes)} classes")
    print(f"   Validation set has {len(val_classes)} classes")

    missing_in_val = train_classes - val_classes
    missing_in_train = val_classes - train_classes

    if missing_in_val:
        print(f"\n⚠️  WARNING: {len(missing_in_val)} classes missing from validation:")
        for cls_id in sorted(missing_in_val):
            print(f"   - Class {cls_id}: {class_names.get(cls_id, 'Unknown')} ({train_class_counts[cls_id]} in train)")

    if missing_in_train:
        print(f"\n⚠️  WARNING: {len(missing_in_train)} classes missing from training:")
        for cls_id in sorted(missing_in_train):
            print(f"   - Class {cls_id}: {class_names.get(cls_id, 'Unknown')} ({val_class_counts[cls_id]} in val)")

    if not missing_in_val and not missing_in_train:
        print(f"\n✅ Perfect! All {len(class_to_images)} classes present in BOTH train and validation sets!")

    # Create updated data.yaml
    yaml_content = f"""# LEGO Pieces Dataset - Stratified Split
path: {output.absolute()}

train: train/images
val: val/images

nc: 52
names:
"""

    for i in range(52):
        yaml_content += f"  {i}: {class_names.get(i, 'Unknown')}\n"

    yaml_path = output / 'data.yaml'
    with open(yaml_path, 'w') as f:
        f.write(yaml_content)

    print(f"\n" + "=" * 80)
    print(f"✅ Stratified dataset split complete!")
    print(f"=" * 80)
    print(f"\n📁 Output structure:")
    print(f"   {output}/")
    print(f"   ├── train/")
    print(f"   │   ├── images/ ({len(train_pairs)} files)")
    print(f"   │   └── labels/ ({len(train_pairs)} files)")
    print(f"   ├── val/")
    print(f"   │   ├── images/ ({len(val_pairs)} files)")
    print(f"   │   └── labels/ ({len(val_pairs)} files)")
    print(f"   └── data.yaml")

    print(f"\n🎯 Ready for training!")

    return str(yaml_path)
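
The per-class split arithmetic used inside the function (at least one training image, and at least one validation image whenever a class has more than one unassigned image) can be checked in isolation. The sketch below is a minimal restatement of that logic; the helper name `split_counts` is hypothetical and not part of the notebook.

```python
from typing import Tuple


def split_counts(n_unassigned: int, train_ratio: float) -> Tuple[int, int]:
    """Return (train, val) counts for one class's unassigned images.

    Hypothetical helper mirroring the arithmetic in the split function.
    """
    if n_unassigned == 0:
        # Mirrors the `continue` for classes whose images are all assigned.
        return 0, 0

    # Always keep at least one image for training.
    target_train = max(1, int(n_unassigned * train_ratio))
    target_val = n_unassigned - target_train

    # Move one image to validation if rounding left it empty.
    if target_val == 0 and n_unassigned > 1:
        target_train = n_unassigned - 1
        target_val = 1

    return target_train, target_val


for n in (1, 2, 5, 10):
    print(n, split_counts(n, 0.8))
```

A class with a single unassigned image always goes to training, so it can only appear in validation if one of its images was already assigned there through another class.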

# Cell 3: Run the Split

In [None]:
# Run the dataset split
new_yaml_path = split_yolo_dataset_stratified(
    source_images_dir=SOURCE_IMAGES,
    source_labels_dir=SOURCE_LABELS,
    output_dir=OUTPUT_DIR,
    train_ratio=TRAIN_RATIO,
    seed=42
)

print(f"\n✨ Your new data.yaml path: {new_yaml_path}")
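
The `data.yaml` written by the split hard-codes `nc: 52`. If the class list ever changes, the header and the `names:` block could instead be derived from the class dictionary. The sketch below shows one way to do that; `build_data_yaml` is a hypothetical helper, and it assumes class IDs are contiguous and start at 0.

```python
def build_data_yaml(dataset_root: str, class_names: dict) -> str:
    """Build data.yaml text with `nc` derived from `class_names`.

    Hypothetical helper; assumes class IDs are 0..N-1 with no gaps.
    """
    ids = sorted(class_names)
    if ids != list(range(len(ids))):
        raise ValueError("class IDs must be contiguous and start at 0")

    lines = [
        "# LEGO Pieces Dataset - Stratified Split",
        f"path: {dataset_root}",
        "",
        "train: train/images",
        "val: val/images",
        "",
        f"nc: {len(ids)}",
        "names:",
    ]
    lines += [f"  {i}: {class_names[i]}" for i in ids]
    return "\n".join(lines) + "\n"


# Two-class illustration only; the real dataset has 52 classes.
demo_yaml = build_data_yaml(
    "/content/dataset/lego_split",
    {0: "Plates Special", 1: "Bars, Ladders and Fences"},
)
print(demo_yaml)
```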

# Cell 4: Verify the Split

In [None]:
import os
from collections import defaultdict, Counter

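def verify_stratified_split(output_dir):
    """
    Verify file counts and per-class coverage of the train/val split.

    Returns True when every class seen in training also appears in validation,
    which is the condition this notebook relies on to avoid val/cls_loss = inf.
    """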

    print("=" * 80)
    print("🔍 COMPREHENSIVE SPLIT VERIFICATION")
    print("=" * 80)

    # 1. Count files
    train_img_dir = os.path.join(output_dir, 'train', 'images')
    train_lbl_dir = os.path.join(output_dir, 'train', 'labels')
    val_img_dir = os.path.join(output_dir, 'val', 'images')
    val_lbl_dir = os.path.join(output_dir, 'val', 'labels')

    train_img_count = len([f for f in os.listdir(train_img_dir) if not f.startswith('.')])
    train_lbl_count = len([f for f in os.listdir(train_lbl_dir) if f.endswith('.txt')])
    val_img_count = len([f for f in os.listdir(val_img_dir) if not f.startswith('.')])
    val_lbl_count = len([f for f in os.listdir(val_lbl_dir) if f.endswith('.txt')])

    print("\n📊 File Count Verification:")
    print(f"  Training Set:")
    print(f"    Images: {train_img_count}")
    print(f"    Labels: {train_lbl_count}")
    print(f"    Match: {'✅' if train_img_count == train_lbl_count else '❌ MISMATCH!'}")

    print(f"\n  Validation Set:")
    print(f"    Images: {val_img_count}")
    print(f"    Labels: {val_lbl_count}")
    print(f"    Match: {'✅' if val_img_count == val_lbl_count else '❌ MISMATCH!'}")

    total_images = train_img_count + val_img_count
    print(f"\n  Total: {total_images} images")
    if total_images > 0:  # avoid ZeroDivisionError on an empty split
        print(f"  Split: {train_img_count/total_images*100:.1f}% train / {val_img_count/total_images*100:.1f}% val")

    # 2. CRITICAL: Verify class distribution
    print("\n" + "=" * 80)
    print("🎯 CRITICAL: Class Distribution Analysis (prevents val/cls_loss = inf)")
    print("=" * 80)

    train_class_counts = Counter()
    val_class_counts = Counter()

    # Count classes in training
    for label_file in os.listdir(train_lbl_dir):
        if label_file.endswith('.txt'):
            with open(os.path.join(train_lbl_dir, label_file), 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if parts:
                        train_class_counts[int(parts[0])] += 1

    # Count classes in validation
    for label_file in os.listdir(val_lbl_dir):
        if label_file.endswith('.txt'):
            with open(os.path.join(val_lbl_dir, label_file), 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if parts:
                        val_class_counts[int(parts[0])] += 1

    all_classes = sorted(set(train_class_counts.keys()) | set(val_class_counts.keys()))

    print(f"\n📈 Class Distribution Summary:")
    print(f"  Classes in training: {len(train_class_counts)}")
    print(f"  Classes in validation: {len(val_class_counts)}")
    print(f"  Total unique classes: {len(all_classes)}")

    # CRITICAL CHECK: Classes missing from validation
    missing_from_val = set(train_class_counts.keys()) - set(val_class_counts.keys())
    missing_from_train = set(val_class_counts.keys()) - set(train_class_counts.keys())

    if missing_from_val:
        print(f"\n❌ CRITICAL ERROR: {len(missing_from_val)} classes MISSING from validation!")
        print(f"   This WILL cause val/cls_loss = inf!")
        print(f"   Missing classes: {sorted(missing_from_val)}")
        # No early return here: the per-class table and final verdict below
        # still run, and the function returns False at the end.
    else:
        print(f"\n✅ EXCELLENT: All {len(train_class_counts)} classes present in BOTH sets!")
        print(f"   val/cls_loss = inf will NOT occur! ✅")

    if missing_from_train:
        print(f"\n⚠️  WARNING: {len(missing_from_train)} classes only in validation: {sorted(missing_from_train)}")

    # 3. Detailed class breakdown
    print(f"\n📋 Detailed Class Distribution (All {len(all_classes)} classes):")
    print(f"{'Class':>6} {'Train':>8} {'Val':>8} {'Total':>8} {'Val%':>6} {'Status':>10}")
    print("-" * 60)

    class_names = {
        0: "Plates Special", 1: "Bars, Ladders and Fences", 2: "Bricks Special",
        3: "Plates", 4: "Technic Pins", 5: "Bricks Curved", 6: "Tiles",
        7: "Tiles Round and Curved", 8: "Technic Connectors", 9: "Technic Special",
        10: "Projectiles / Launchers", 11: "Tiles Special", 12: "Bricks Sloped",
        13: "Bricks", 14: "Hinges, Arms and Turntables", 15: "Plates Angled",
        16: "Plants and Animals", 17: "Plates Round Curved and Dishes",
        18: "Bricks Round and Cones", 19: "Technic Bricks", 20: "Technic Axles",
        21: "Technic Beams", 22: "Technic Bushes", 23: "Minifig Accessories",
        24: "Panels", 25: "Windows and Doors", 26: "Bricks Wedged",
        27: "Duplo, Quatro and Primo", 28: "Supports, Girders and Cranes",
        29: "Technic Beams Special", 30: "Transportation - Land", 31: "Technic Gears",
        32: "Technic Panels", 33: "Technic Steering, Suspension and Engine",
        34: "Wheels and Tyres", 35: "Large Buildable Figures", 36: "Pneumatics",
        37: "String, Bands and Reels", 38: "Transportation - Sea and Air",
        39: "Electronics", 40: "Energy Effects", 41: "Rock", 42: "Minifig Headwear",
        43: "Windscreens and Fuselage", 44: "Containers", 45: "Tools",
        46: "Minifigs", 47: "Minifig Lower Body", 48: "Baseplates",
        49: "Minifig Upper Body", 50: "Flags, Signs, Plastics and Cloth",
        51: "Tubes and Hoses"
    }

    for cls in all_classes:
        train_cnt = train_class_counts.get(cls, 0)
        val_cnt = val_class_counts.get(cls, 0)
        total = train_cnt + val_cnt
        val_pct = (val_cnt / total * 100) if total > 0 else 0

        if val_cnt == 0:
            status = "❌ NO VAL"
        elif val_cnt < 5:
            status = "⚠️ FEW VAL"
        else:
            status = "✅ OK"

        print(f"{cls:>6} {train_cnt:>8} {val_cnt:>8} {total:>8} {val_pct:>5.1f}% {status:>10}")

    # 4. Show data.yaml
    yaml_path = os.path.join(output_dir, 'data.yaml')
    print(f"\n📄 data.yaml preview:")
    print("=" * 60)
    with open(yaml_path, 'r') as f:
        lines = f.readlines()
        for i, line in enumerate(lines):
            # Show the first 10 and last 5 lines; summarize the rest once
            if i < 10 or i >= len(lines) - 5:
                print(line.rstrip())
            elif i == 10:
                print(f"  ... ({len(lines) - 15} lines omitted) ...")
    print("=" * 60)

    # 5. Final verdict
    print("\n" + "=" * 80)
    if missing_from_val:
        print("❌ VERIFICATION FAILED: Some classes missing from validation")
        print("   val/cls_loss = inf WILL occur with this split!")
        print("   DO NOT use this split for training!")
    else:
        print("✅ VERIFICATION PASSED: All classes present in both train and val")
        print("   val/cls_loss = inf will NOT occur! Safe to train! 🎉")
    print("=" * 80)

    return len(missing_from_val) == 0


# Usage:
verify_stratified_split('/content/dataset/lego_split')
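
The verification above covers file counts and class coverage. A complementary check, sketched here on the assumption that the output uses the same `train/images` and `val/images` layout, confirms on disk that no image stem appears in both splits. The helper name `find_split_overlap` is hypothetical.

```python
import os


def find_split_overlap(output_dir: str) -> set:
    """Return image stems present in both train/images and val/images.

    Hypothetical helper; assumes the layout produced by the split above.
    """
    def stems(split: str) -> set:
        img_dir = os.path.join(output_dir, split, "images")
        return {
            os.path.splitext(name)[0]
            for name in os.listdir(img_dir)
            if not name.startswith(".")  # skip hidden files
        }

    return stems("train") & stems("val")


# Only run when the split folder exists (path as used earlier in the notebook).
split_dir = "/content/dataset/lego_split"
if os.path.isdir(split_dir):
    overlap = find_split_overlap(split_dir)
    print(f"Overlapping stems: {len(overlap)}")
```

An empty result means the copied files agree with the in-memory overlap check done during the split.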

# Cell 1: Authenticate with Google Cloud

In [None]:
from google.colab import auth
auth.authenticate_user()

print("✅ Authenticated with Google Cloud")

# Cell 2: Configure Upload Settings

In [None]:
# Configure your GCS settings
PROJECT_ID = 'lego-training-123456'  # ← Change this to your GCP project ID
BUCKET_NAME = 'lego-dataset-split'  # ← Change this to your bucket name
LOCAL_FOLDER = '/content/dataset/lego_split'  # Your local folder to upload
GCS_DESTINATION = 'lego_split_2'  # Destination folder name in GCS bucket

print(f"📦 Local folder: {LOCAL_FOLDER}")
print(f"☁️  GCS bucket: gs://{BUCKET_NAME}/{GCS_DESTINATION}")

# Cell 3: Upload to GCS (Method 1 - Using gsutil)

In [None]:
# Upload using gsutil (fastest for large datasets)
import os

if os.path.exists(LOCAL_FOLDER):
    print(f"🚀 Starting upload to gs://{BUCKET_NAME}/{GCS_DESTINATION}...")
    print("This may take a while depending on dataset size...\n")

    # Upload with progress
    !gsutil -m cp -r {LOCAL_FOLDER} gs://{BUCKET_NAME}/{GCS_DESTINATION}

    print("\n✅ Upload complete!")
    print(f"📍 Your dataset is now at: gs://{BUCKET_NAME}/{GCS_DESTINATION}")
else:
    print(f"❌ Error: Folder {LOCAL_FOLDER} not found!")
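
The upload cell prints "Upload complete!" whether or not every file transferred. One way to sanity-check the result is to record how many files and bytes the local folder holds and compare those numbers against the bucket afterwards. The sketch below covers only the local side; `summarize_folder` is a hypothetical helper.

```python
import os


def summarize_folder(folder: str):
    """Return (file_count, total_bytes) for all files under `folder`.

    Hypothetical helper; walks the tree recursively with os.walk.
    """
    file_count = 0
    total_bytes = 0
    for root, _dirs, files in os.walk(folder):
        for name in files:
            file_count += 1
            total_bytes += os.path.getsize(os.path.join(root, name))
    return file_count, total_bytes


# Only run when the folder exists (path as configured earlier in the notebook).
local_folder = "/content/dataset/lego_split"
if os.path.isdir(local_folder):
    count, size = summarize_folder(local_folder)
    print(f"{count} files, {size / 1e6:.1f} MB to upload")
```

The byte total can then be compared with what the bucket reports for the destination prefix, for example via `gsutil du -s`.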