In [6]:
import os
import random
import shutil
from pathlib import Path
from typing import List, Tuple, Union
from tqdm import tqdm
from PIL import Image, ImageDraw
import json
from google import genai
from google.genai import types
import time



def _scale_to_pixels(value: float, extent: int) -> int:
    """Map a 0-1000 normalized coordinate onto the pixel range ``[0, extent]``."""
    return min(max(int(value / 1000 * extent), 0), extent)


def detect_objects(
    image: Union[str, Image.Image], 
    prompt: str, 
    client, 
    model: str = "gemini-2.5-flash",
    delay: float = 5.0
) -> List[List[int]]:
    """
    Detect objects with the Gemini API and return pixel-space bounding boxes.

    Args:
        image: Path to an image file, or an already opened PIL image.
        prompt: Text prompt sent to the model together with the image.
        client: Google GenAI client; only ``client.models.generate_content`` is used.
        model: Gemini model name.
        delay: Seconds to sleep after each request (simple rate limiting).
            Defaults to the 5 second pause that used to be hard-coded.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes in absolute pixel coordinates, with
        ``x1 <= x2`` and ``y1 <= y2`` and every value clamped to the image.
        Entries without a usable 4-element ``box_2d`` are skipped.

    Raises:
        ValueError: If the response is empty or is not a JSON list/object.
        json.JSONDecodeError: If the response text is not valid JSON.
    """
    opened_here = isinstance(image, str)
    img = Image.open(image) if opened_here else image

    try:
        config = types.GenerateContentConfig(
            response_mime_type="application/json"
        )
        response = client.models.generate_content(
            model=model,
            contents=[img, prompt],
            config=config
        )
        width, height = img.size
    finally:
        # Only close handles this function opened; a caller-supplied image stays usable.
        if opened_here:
            img.close()

    if delay > 0:
        time.sleep(delay)

    raw_text = response.text
    if not raw_text:
        raise ValueError("Gemini returned an empty response; no bounding boxes to parse")

    parsed = json.loads(raw_text)
    if isinstance(parsed, dict):
        # A single detection may come back as one object instead of a list.
        parsed = [parsed]
    if not isinstance(parsed, list):
        raise ValueError(
            f"Expected a JSON list of detections, got {type(parsed).__name__}"
        )

    converted_bounding_boxes = []
    for entry in parsed:
        box = entry.get("box_2d") if isinstance(entry, dict) else None
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        try:
            # box_2d is read as [y1, x1, y2, x2] on a 0-1000 scale.
            ys = sorted((_scale_to_pixels(float(box[0]), height),
                         _scale_to_pixels(float(box[2]), height)))
            xs = sorted((_scale_to_pixels(float(box[1]), width),
                         _scale_to_pixels(float(box[3]), width)))
        except (TypeError, ValueError, OverflowError):
            # Non-numeric or non-finite coordinate: drop this entry only.
            continue
        converted_bounding_boxes.append([xs[0], ys[0], xs[1], ys[1]])

    return converted_bounding_boxes


def visualize_bboxes(
    image: Union[str, Image.Image], 
    bboxes: List[List[int]], 
    save_path: str = None, 
    thickness: int = 3, 
    color: str = "red"
) -> Image.Image:
    """
    Draw rectangle outlines for the given boxes on an image.

    Args:
        image: Path to an image file, or a PIL image (which is copied, so the
            caller's object is left untouched).
        bboxes: Boxes as [x1, y1, x2, y2] in pixel coordinates.
        save_path: When truthy, the annotated image is also saved there.
        thickness: Outline width passed to the drawing call.
        color: Outline color passed to the drawing call.

    Returns:
        The image with the rectangles drawn on it.
    """
    canvas = Image.open(image) if isinstance(image, str) else image.copy()
    pen = ImageDraw.Draw(canvas)

    for left, top, right, bottom in bboxes:
        pen.rectangle([left, top, right, bottom], outline=color, width=thickness)

    if save_path:
        canvas.save(save_path)

    return canvas


def bbox_to_yolo(
    bbox: List[int], 
    img_width: int, 
    img_height: int
) -> Tuple[float, float, float, float]:
    """
    Convert absolute bbox coordinates to YOLO format
    
    The corners are put in ascending order and clamped to the image before
    conversion, so the result never has a negative size and every value lies
    in the 0-1 range even when the input box is reversed or spills outside
    the image.
    
    Args:
        bbox: [x1, y1, x2, y2] in absolute coordinates
        img_width: Image width (must be positive)
        img_height: Image height (must be positive)
    
    Returns:
        (x_center, y_center, width, height) normalized to 0-1
    
    Raises:
        ValueError: If img_width or img_height is not positive.
    """
    if img_width <= 0 or img_height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {img_width}x{img_height}"
        )
    
    x1, y1, x2, y2 = bbox
    
    # Order the corners, then clamp them to the image rectangle
    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))
    left = min(max(left, 0), img_width)
    right = min(max(right, 0), img_width)
    top = min(max(top, 0), img_height)
    bottom = min(max(bottom, 0), img_height)
    
    # Calculate center point
    x_center = ((left + right) / 2) / img_width
    y_center = ((top + bottom) / 2) / img_height
    
    # Calculate width and height
    width = (right - left) / img_width
    height = (bottom - top) / img_height
    
    return x_center, y_center, width, height


def create_yolo_annotation(
    bboxes: List[List[int]], 
    class_id: int, 
    img_width: int, 
    img_height: int, 
    save_path: str
):
    """
    Write one YOLO label file for a single image.
    
    Each box becomes one line: the class ID followed by the four values
    returned by ``bbox_to_yolo``, formatted with six decimals. The file is
    created (empty) even when there are no boxes.
    
    Args:
        bboxes: Boxes as [[x1, y1, x2, y2], ...] in absolute coordinates
        class_id: Class ID written at the start of every line
        img_width: Image width used for normalization
        img_height: Image height used for normalization
        save_path: Destination of the annotation .txt file
    """
    def _label_lines():
        # Lazily yields one formatted line per box, in input order.
        for box in bboxes:
            cx, cy, box_w, box_h = bbox_to_yolo(box, img_width, img_height)
            yield f"{class_id} {cx:.6f} {cy:.6f} {box_w:.6f} {box_h:.6f}\n"

    with open(save_path, 'w') as label_file:
        label_file.writelines(_label_lines())


def _process_image(
    img_path,
    class_id,
    selected_subfolder,
    detections_subfolder,
    annotation_path,
    detection_prompt,
    client,
    model
):
    """Copy one image, run detection on it, write its overlay and label file; return the box count."""
    shutil.copy2(img_path, selected_subfolder / img_path.name)

    # The context manager releases the file handle after every image, so a
    # long run does not accumulate open files.
    with Image.open(img_path) as img:
        img_width, img_height = img.size
        bboxes = detect_objects(img, detection_prompt, client, model)
        visualize_bboxes(img, bboxes, save_path=str(detections_subfolder / img_path.name))

    create_yolo_annotation(bboxes, class_id, img_width, img_height, str(annotation_path))
    return len(bboxes)


def _write_dataset_yaml(yaml_path, output_path, class_names):
    """Write the dataset.yaml file listing the output root and the class names."""
    # UTF-8 is forced and values are emitted as double-quoted strings
    # (json.dumps output), so non-ASCII folder names and characters such as
    # ':' or '#' cannot corrupt the file.
    # NOTE(review): train and val both point at the same 'Selected' folder, and
    # labels are written to 'annotations/' rather than next to the images.
    # Confirm this layout is what the trainer expects before training on it.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write("# YOLO Dataset Configuration\n")
        f.write(f"path: {json.dumps(str(output_path.absolute()), ensure_ascii=False)}\n")
        f.write("train: Selected\n")
        f.write("val: Selected\n\n")
        f.write("# Number of classes\n")
        f.write(f"nc: {len(class_names)}\n\n")
        f.write("# Class names\n")
        f.write("names:\n")
        for i, name in enumerate(class_names):
            f.write(f"  {i}: {json.dumps(name, ensure_ascii=False)}\n")


def process_dataset(
    root_folder: str,
    detection_prompt: str,
    client,
    output_root: str = "./output",
    images_per_folder: int = 40,
    model: str = "gemini-2.5-flash",
    seed: int = None,
    expected_subfolders: int = 11
):
    """
    Process entire dataset: select images, detect objects, create YOLO annotations
    
    Every subfolder of ``root_folder`` is one class; class IDs follow the
    alphabetical order of the folder names. For each class a random sample of
    images is copied to ``Selected/<class>``, sent to the detector, drawn to
    ``detections/<class>`` and labelled in ``annotations/<stem>.txt``. Every
    box of an image gets the class ID of the folder it came from.
    
    Args:
        root_folder: Root folder containing subfolders with images
        detection_prompt: Prompt for object detection
        client: Google GenAI client
        output_root: Root folder for output
        images_per_folder: Number of images to select from each subfolder
        model: Gemini model name
        seed: Random seed for reproducibility (None = unseeded)
        expected_subfolders: Number of class folders expected; a warning is
            printed on mismatch. Pass None to disable the check.
    
    Returns:
        Dict with 'total_processed', 'total_detections', 'errors' (list of
        messages for images that failed) and 'warnings' (list of messages for
        annotation files that were overwritten because of duplicate stems).
    """
    # A private generator keeps the seed from altering the global `random`
    # state used by other notebook cells.
    rng = random.Random(seed)
    
    root_path = Path(root_folder)
    
    # Get all subfolders and sort alphabetically
    subfolders = sorted(f for f in root_path.iterdir() if f.is_dir())
    
    print(f"Found {len(subfolders)} subfolders")
    if expected_subfolders is not None and len(subfolders) != expected_subfolders:
        print(f"⚠️  Warning: Expected {expected_subfolders} subfolders, found {len(subfolders)}")
    
    # Create output directories
    output_path = Path(output_root)
    selected_path = output_path / "Selected"
    detections_path = output_path / "detections"
    annotations_path = output_path / "annotations"
    
    for path in [selected_path, detections_path, annotations_path]:
        path.mkdir(parents=True, exist_ok=True)
    
    # Supported image extensions
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    
    # Collect the images of every folder once and size the progress bar.
    # Sorting matters: directory listing order is not guaranteed, and a seeded
    # sample is only reproducible when drawn from a stable sequence.
    total_images_to_process = 0
    folder_images = {}
    
    for subfolder in subfolders:
        images = sorted(
            f for f in subfolder.iterdir()
            if f.is_file() and f.suffix.lower() in image_extensions
        )
        folder_images[subfolder] = images
        total_images_to_process += min(len(images), images_per_folder)
    
    print(f"\nTotal images to process: {total_images_to_process}\n")
    
    # Create overall progress bar
    overall_pbar = tqdm(total=total_images_to_process, desc="Overall Progress", position=0)
    
    # Statistics
    stats = {
        'total_processed': 0,
        'total_detections': 0,
        'errors': [],
        'warnings': []
    }
    
    # Annotation path -> "<folder>/<file>" that last wrote it. Annotations are
    # stored flat by image stem, so two images with the same stem collide.
    annotation_owners = {}
    
    try:
        # Process each subfolder
        for class_id, subfolder in enumerate(subfolders):
            folder_name = subfolder.name
            print(f"\n📁 Processing folder {class_id}: {folder_name}")
            
            # Create output subfolders
            selected_subfolder = selected_path / folder_name
            detections_subfolder = detections_path / folder_name
            
            selected_subfolder.mkdir(exist_ok=True)
            detections_subfolder.mkdir(exist_ok=True)
            
            # Get images
            images = folder_images[subfolder]
            
            print(f"   Found {len(images)} images")
            
            # Randomly select images
            if len(images) < images_per_folder:
                print(f"   ⚠️  Only {len(images)} images available, selecting all")
                selected_images = images
            else:
                selected_images = rng.sample(images, images_per_folder)
            
            # Process each selected image
            for img_path in selected_images:
                source_label = f"{folder_name}/{img_path.name}"
                annotation_path = annotations_path / (img_path.stem + '.txt')
                try:
                    previous_owner = annotation_owners.get(annotation_path)
                    if previous_owner is not None:
                        warning_msg = (
                            f"Annotation {annotation_path.name} written for {previous_owner} "
                            f"is overwritten by {source_label}"
                        )
                        stats['warnings'].append(warning_msg)
                        tqdm.write(f"   ⚠️  {warning_msg}")
                    
                    detections = _process_image(
                        img_path,
                        class_id,
                        selected_subfolder,
                        detections_subfolder,
                        annotation_path,
                        detection_prompt,
                        client,
                        model
                    )
                    annotation_owners[annotation_path] = source_label
                    
                    # Update statistics
                    stats['total_processed'] += 1
                    stats['total_detections'] += detections
                    
                except Exception as e:
                    # Best effort: one failing image must not stop the run.
                    error_msg = f"Error processing {folder_name}/{img_path.name}: {str(e)}"
                    stats['errors'].append(error_msg)
                    tqdm.write(f"   ❌ {error_msg}")
                
                finally:
                    overall_pbar.update(1)
    finally:
        # Also runs on KeyboardInterrupt so the bar does not stay open.
        overall_pbar.close()
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"✅ Processing Complete!")
    print(f"{'='*60}")
    print(f"📊 Statistics:")
    print(f"   - Images processed: {stats['total_processed']}")
    print(f"   - Total detections: {stats['total_detections']}")
    print(f"   - Errors: {len(stats['errors'])}")
    print(f"\n📂 Output folders:")
    print(f"   - Selected images: {selected_path}")
    print(f"   - Detections: {detections_path}")
    print(f"   - YOLO annotations: {annotations_path}")
    
    # Create dataset.yaml for YOLO training
    yaml_path = output_path / "dataset.yaml"
    class_names = [f.name for f in subfolders]
    _write_dataset_yaml(yaml_path, output_path, class_names)
    
    print(f"   - Dataset config: {yaml_path}")
    
    if stats['errors']:
        print(f"\n⚠️  Errors encountered:")
        for error in stats['errors'][:10]:  # Show first 10 errors
            print(f"   - {error}")
        if len(stats['errors']) > 10:
            print(f"   ... and {len(stats['errors']) - 10} more")
    
    if stats['warnings']:
        print(f"\n⚠️  Overwritten annotations (duplicate image names): {len(stats['warnings'])}")
        for warning in stats['warnings'][:10]:  # Show first 10 warnings
            print(f"   - {warning}")
        if len(stats['warnings']) > 10:
            print(f"   ... and {len(stats['warnings']) - 10} more")
    
    print(f"\n{'='*60}\n")
    
    return stats


# Example usage
if __name__ == "__main__":
    # The API key is read from the environment so that it never ends up in the
    # notebook or its history. Any key that was previously committed in this
    # cell should be revoked and replaced.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "Set the GEMINI_API_KEY environment variable before running this cell"
        )
    client = genai.Client(api_key=api_key)
    
    # Process dataset
    stats = process_dataset(
        root_folder="./LCT_1280/separate",
        detection_prompt="Detect the instruments on the image",
        client=client,
        output_root="./gemini-labels-v4",
        images_per_folder=70,
        model="gemini-2.5-flash-lite",
        seed=1337  # For reproducibility
    )


Found 11 subfolders

Total images to process: 440



Overall Progress:   0%|                                                                                                                                         | 0/440 [00:00<?, ?it/s]


📁 Processing folder 0: 1 Отвертка «-»
   Found 247 images


Overall Progress:   9%|███████████▋                                                                                                                    | 40/440 [04:29<44:06,  6.62s/it]


📁 Processing folder 1: 10 Ключ рожковыйнакидной  ¾
   Found 425 images


Overall Progress:  18%|██████████████████████▉                                                                                                       | 80/440 [11:01<1:35:56, 15.99s/it]


📁 Processing folder 2: 11 Бокорезы
   Found 271 images


Overall Progress:  27%|██████████████████████████████████▋                                                                                            | 120/440 [15:21<34:55,  6.55s/it]


📁 Processing folder 3: 2 Отвертка «+»
   Found 302 images


Overall Progress:  36%|██████████████████████████████████████████████▏                                                                                | 160/440 [19:45<31:58,  6.85s/it]


📁 Processing folder 4: 3 Отвертка на смещенный крест
   Found 275 images


Overall Progress:  45%|█████████████████████████████████████████████████████████▋                                                                     | 200/440 [24:09<26:28,  6.62s/it]


📁 Processing folder 5: 4 Коловорот
   Found 848 images


Overall Progress:  55%|█████████████████████████████████████████████████████████████████████▎                                                         | 240/440 [30:49<54:04, 16.22s/it]


📁 Processing folder 6: 5 Пассатижи контровочные
   Found 819 images


Overall Progress:  64%|████████████████████████████████████████████████████████████████████████████████▊                                              | 280/440 [35:18<17:15,  6.47s/it]


📁 Processing folder 7: 6 Пассатижи
   Found 303 images


Overall Progress:  73%|████████████████████████████████████████████████████████████████████████████████████████████▎                                  | 320/440 [39:45<14:11,  7.10s/it]


📁 Processing folder 8: 7 Шэрница
   Found 627 images


Overall Progress:  82%|███████████████████████████████████████████████████████████████████████████████████████████████████████▉                       | 360/440 [44:14<08:49,  6.61s/it]


📁 Processing folder 9: 8 Разводной ключ
   Found 645 images


Overall Progress:  91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 400/440 [48:42<04:28,  6.71s/it]


📁 Processing folder 10: 9 Открывашка для банок с маслом
   Found 248 images


Overall Progress: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 440/440 [53:32<00:00,  7.30s/it]


✅ Processing Complete!
📊 Statistics:
   - Images processed: 440
   - Total detections: 440
   - Errors: 0

📂 Output folders:
   - Selected images: gemini-labels-v4/Selected
   - Detections: gemini-labels-v4/detections
   - YOLO annotations: gemini-labels-v4/annotations
   - Dataset config: gemini-labels-v4/dataset.yaml







In [7]:
# Flatten the per-class "Selected" folders into a single images/ folder.
# Written as plain Python instead of a shell one-liner so that it does not
# depend on shell quoting, and so that duplicate file names are reported
# instead of silently overwriting each other.
import shutil
from pathlib import Path

flatten_src = Path('./gemini-labels-v4/Selected/')
flatten_dst = Path('./gemini-labels-v4/images/')
flatten_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tiff'}

flatten_dst.mkdir(parents=True, exist_ok=True)

flattened_from = {}  # file name -> source path copied during this run
for source_file in sorted(flatten_src.rglob('*')):
    if not source_file.is_file() or source_file.suffix.lower() not in flatten_extensions:
        continue
    if source_file.name in flattened_from:
        print(
            f"⚠️  {source_file} overwrites {flattened_from[source_file.name]} "
            f"in {flatten_dst} (same file name)"
        )
    shutil.copy2(source_file, flatten_dst / source_file.name)
    flattened_from[source_file.name] = source_file

print(f"Copied {len(flattened_from)} unique file names to {flatten_dst}")


In [9]:
import os
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Configuration
# Folders holding the flat image and label files to split; train/ and val/
# subfolders are created inside each of them later in this cell.
# NOTE(review): no visible cell creates or fills these yolo-dataset/ folders —
# presumably they are populated by hand from ./gemini-labels-v4/images and
# ./gemini-labels-v4/annotations; confirm before a Restart & Run All.
IMAGE_DIR = "./gemini-labels-v4/yolo-dataset/images/"
LABEL_DIR = "./gemini-labels-v4/yolo-dataset/labels/"
TRAIN_RATIO = 0.8  # share kept for training; the split uses test_size = 1 - TRAIN_RATIO
RANDOM_STATE = 42  # passed as random_state to train_test_split

def get_classes_from_label(label_path):
    """
    Collect the distinct class IDs found in a YOLO label file.

    The first whitespace-separated token of every non-blank line is parsed as
    an integer. Any error (unreadable file, non-integer token) is printed and
    stops the scan; IDs gathered before the error are still returned.

    Args:
        label_path: Path of the label file to read.

    Returns:
        Set of integer class IDs (empty when nothing could be read).
    """
    found_ids = set()
    try:
        with open(label_path, 'r') as handle:
            for raw_line in handle:
                tokens = raw_line.split()
                if not tokens:
                    # blank or whitespace-only line
                    continue
                found_ids.add(int(tokens[0]))
    except Exception as e:
        print(f"Error reading {label_path}: {e}")
    return found_ids

# Create train/val directories
for split in ['train', 'val']:
    os.makedirs(os.path.join(IMAGE_DIR, split), exist_ok=True)
    os.makedirs(os.path.join(LABEL_DIR, split), exist_ok=True)

# Get all label files (excluding subdirectories). Sorted, because the split is
# only reproducible for a given RANDOM_STATE when the input order is stable.
label_files = sorted(f for f in Path(LABEL_DIR).glob('*.txt') if f.is_file())

# Index the images once by stem. Extensions are compared case-insensitively
# and cover the same set the image-selection step accepts, so no labelled
# image is dropped because of its extension.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
images_by_stem = {}
for candidate in sorted(Path(IMAGE_DIR).iterdir()):
    if candidate.is_file() and candidate.suffix.lower() in image_extensions:
        # First match in sorted order wins when several files share a stem
        images_by_stem.setdefault(candidate.stem, candidate)

# Build dataset with class information
dataset = []
skipped_labels = []

for label_file in label_files:
    stem = label_file.stem
    image_file = images_by_stem.get(stem)
    if image_file is None:
        skipped_labels.append(f"{label_file.name}: no matching image")
        continue

    label_classes = get_classes_from_label(label_file)
    if not label_classes:
        skipped_labels.append(f"{label_file.name}: no class IDs")
        continue

    # Use minimum class ID as primary class for stratification
    dataset.append({
        'stem': stem,
        'image': image_file,
        'label': label_file,
        'class': min(label_classes)
    })

if skipped_labels:
    print(f"⚠️  Skipped {len(skipped_labels)} label files:")
    for reason in skipped_labels[:10]:
        print(f"   - {reason}")
    if len(skipped_labels) > 10:
        print(f"   ... and {len(skipped_labels) - 10} more")

if not dataset:
    # Raise instead of calling exit(), which would shut down the notebook kernel.
    raise RuntimeError(
        "No matching image-label pairs found! "
        "(files already moved into train/ and val/ are not picked up again)"
    )

# Extract data for stratification
stems = [item['stem'] for item in dataset]
primary_classes = [item['class'] for item in dataset]

# Stratification needs at least two samples per class; fall back to a plain
# random split instead of failing when a class is that small.
samples_per_class = defaultdict(int)
for primary_class in primary_classes:
    samples_per_class[primary_class] += 1

stratify_labels = primary_classes
if min(samples_per_class.values()) < 2:
    print("⚠️  A class has fewer than 2 samples; splitting without stratification")
    stratify_labels = None

# Perform stratified split
train_stems, val_stems = train_test_split(
    stems,
    test_size=1-TRAIN_RATIO,
    random_state=RANDOM_STATE,
    stratify=stratify_labels
)

train_set = set(train_stems)
val_set = set(val_stems)

# Move files to appropriate directories
for item in dataset:
    stem = item['stem']
    split = 'train' if stem in train_set else 'val'
    
    # Move image
    dest_image = os.path.join(IMAGE_DIR, split, item['image'].name)
    shutil.move(str(item['image']), dest_image)
    
    # Move label
    dest_label = os.path.join(LABEL_DIR, split, item['label'].name)
    shutil.move(str(item['label']), dest_label)

# Print statistics
print(f"✓ Dataset split complete!")
print(f"  Training samples: {len(train_stems)}")
print(f"  Validation samples: {len(val_stems)}")
print(f"\nClass distribution:")

class_stats = defaultdict(lambda: {'train': 0, 'val': 0})
for item in dataset:
    split = 'train' if item['stem'] in train_set else 'val'
    class_stats[item['class']][split] += 1

for class_id in sorted(class_stats.keys()):
    # Named split_counts so the `stats` result of the labelling cell is not overwritten
    split_counts = class_stats[class_id]
    print(f"  Class {class_id}: {split_counts['train']} train, {split_counts['val']} val")


✓ Dataset split complete!
  Training samples: 348
  Validation samples: 87

Class distribution:
  Class 0: 28 train, 7 val
  Class 1: 32 train, 8 val
  Class 2: 32 train, 8 val
  Class 3: 32 train, 8 val
  Class 4: 32 train, 8 val
  Class 5: 32 train, 8 val
  Class 6: 32 train, 8 val
  Class 7: 32 train, 8 val
  Class 8: 32 train, 8 val
  Class 9: 32 train, 8 val
  Class 10: 32 train, 8 val


In [2]:
from ultralytics import YOLO

# Load the pretrained checkpoint named below (yolo11s.pt, not the "n" variant)
model = YOLO("yolo11s.pt")

# NOTE(review): the dataset.yaml written by process_dataset points both train
# and val at the "Selected" folder — confirm it matches the train/val split
# created in yolo-dataset/ before relying on the validation metrics.
core_args = dict(
    data="./gemini-labels-v4/dataset.yaml",
    cache='ram',
    batch=4,
    epochs=300,
    imgsz=720,
)

# Geometric augmentations
geometric_args = dict(
    degrees=20.0,       # rotation setting
    translate=0.2,      # translation setting
    scale=0.8,          # scale setting
    shear=10.0,         # shear setting
    perspective=0.001,  # perspective setting
    flipud=0.5,         # vertical flip setting
    fliplr=0.5,         # horizontal flip setting
)

# Mosaic and mixing augmentations
# NOTE(review): check the Ultralytics docs for which tasks copy_paste and
# auto_augment actually apply to; they may have no effect on detection.
mixing_args = dict(
    mosaic=1.0,
    copy_paste=0.4,
    cutmix=0.5,
    close_mosaic=10,    # mosaic is switched off for the final epochs
)

# Color augmentations
color_args = dict(
    hsv_h=0.02,         # hue setting
    hsv_s=0.8,          # saturation setting
    hsv_v=0.5,          # value (brightness) setting
    auto_augment='randaugment',
)

# Train with aggressive augmentations
results = model.train(
    **core_args,
    **geometric_args,
    **mixing_args,
    **color_args,
)


Ultralytics 8.3.204 🚀 Python-3.11.0 torch-2.7.0+cu126 CUDA:0 (NVIDIA GeForce RTX 3060 Laptop GPU, 5804MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=4, bgr=0.0, box=7.5, cache=ram, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.4, copy_paste_mode=flip, cos_lr=False, cutmix=0.5, data=./gemini-labels-v4/dataset.yaml, degrees=20.0, deterministic=True, device=None, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=300, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.5, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.02, hsv_s=0.8, hsv_v=0.5, imgsz=720, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=yolo11s.pt, momentum=0.937, mosaic=1.0, multi_scale=False, name=train18, nbs=64, nms=False, opset=None, optimize=False, optimizer=auto, overlap_mask=True, patience=

In [5]:
# %pip (instead of !pip) installs into the environment of the running kernel.
# NOTE(review): consider pinning versions here for reproducible runs.
%pip install -U ultralytics sahi

Collecting sahi
  Downloading sahi-0.11.36-py3-none-any.whl.metadata (19 kB)
Collecting fire (from sahi)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting pybboxes==0.1.6 (from sahi)
  Downloading pybboxes-0.1.6-py3-none-any.whl.metadata (9.9 kB)
Collecting shapely>=2.0.0 (from sahi)
  Downloading shapely-2.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (6.8 kB)
Collecting terminaltables (from sahi)
  Downloading terminaltables-3.1.10-py2.py3-none-any.whl.metadata (3.5 kB)
Downloading sahi-0.11.36-py3-none-any.whl (111 kB)
Downloading pybboxes-0.1.6-py3-none-any.whl (24 kB)
Downloading shapely-2.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.1 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
Downloading fire-0.7.1-py3-none-any.whl (115 kB)
Downloading terminaltables-3.1.10-py2.py3-none-any.whl (15 kB)
Installing collected packages: termin