In [7]:
import os
import random
import shutil
from pathlib import Path
from typing import List, Tuple, Union
from tqdm import tqdm
from PIL import Image, ImageDraw
import json
from google import genai
from google.genai import types
import time

# Read the API key from the environment instead of embedding it in the source.
# The key that used to be hard-coded here must be treated as leaked and revoked.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])


def detect_objects(
    image: Union[str, Image.Image],
    prompt: str,
    client,
    model: str = "gemini-2.5-flash",
    delay: float = 5.0,
) -> List[List[int]]:
    """Detect objects in an image using the Gemini API.

    Args:
        image: Path to an image file, or an already opened PIL image.
        prompt: Text instruction sent to the model together with the image.
        client: Google GenAI client used to issue the request.
        model: Gemini model name.
        delay: Seconds to pause after a successful request (simple rate
            limiting). Values <= 0 disable the pause.

    Returns:
        A list of ``[x1, y1, x2, y2]`` boxes in absolute pixel coordinates.
        Each box is clamped to the image and ordered so that ``x1 <= x2``
        and ``y1 <= y2``.

    Raises:
        ValueError: If the response text is empty, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or does not
            decode to a list/object of detections.
        KeyError: If a detection has no ``"box_2d"`` entry.
    """
    opened_here = isinstance(image, str)
    img = Image.open(image) if opened_here else image

    try:
        config = types.GenerateContentConfig(
            response_mime_type="application/json"
        )
        response = client.models.generate_content(
            model=model,
            contents=[img, prompt],
            config=config,
        )
        width, height = img.size
    finally:
        # Only close what this function opened; a caller-supplied image
        # stays usable after the call.
        if opened_here:
            img.close()

    if delay > 0:
        time.sleep(delay)

    raw_text = response.text
    if raw_text is None or not raw_text.strip():
        # Fail with an explicit message instead of an opaque
        # "Expecting value: line 1 column 1 (char 0)" / TypeError.
        raise ValueError(
            "Gemini returned an empty response; expected JSON with bounding boxes"
        )

    detections = json.loads(raw_text)
    if isinstance(detections, dict):
        # A single detection returned as a bare object instead of a list.
        detections = [detections]
    if not isinstance(detections, list):
        raise ValueError(
            f"Unexpected response payload of type {type(detections).__name__}; "
            "expected a JSON list of detections"
        )

    def to_pixels(value, size):
        # "box_2d" values are on a 0-1000 scale; convert to pixels and keep
        # the result inside the image.
        return min(max(int(value / 1000 * size), 0), size)

    converted_bounding_boxes = []
    for detection in detections:
        # "box_2d" is read as [y1, x1, y2, x2].
        y1, x1, y2, x2 = detection["box_2d"][:4]
        left, right = sorted((to_pixels(x1, width), to_pixels(x2, width)))
        top, bottom = sorted((to_pixels(y1, height), to_pixels(y2, height)))
        converted_bounding_boxes.append([left, top, right, bottom])

    return converted_bounding_boxes


def visualize_bboxes(
    image: Union[str, Image.Image], 
    bboxes: List[List[int]], 
    save_path: str = None, 
    thickness: int = 3, 
    color: str = "red"
) -> Image.Image:
    """Draw bounding boxes on an image and return the annotated image.

    Args:
        image: Path to an image file, or a PIL image. A PIL image is copied
            first, so the object passed in is not drawn on.
        bboxes: Boxes as ``[x1, y1, x2, y2]`` in pixel coordinates.
        save_path: If truthy, the annotated image is also saved there.
        thickness: Outline width of each rectangle.
        color: Outline color of each rectangle.

    Returns:
        The image with all rectangles drawn on it.
    """
    canvas = Image.open(image) if isinstance(image, str) else image.copy()
    pen = ImageDraw.Draw(canvas)

    for left, top, right, bottom in bboxes:
        pen.rectangle([left, top, right, bottom], outline=color, width=thickness)

    if not save_path:
        return canvas

    canvas.save(save_path)
    return canvas


def bbox_to_yolo(
    bbox: List[int],
    img_width: int,
    img_height: int
) -> Tuple[float, float, float, float]:
    """
    Convert absolute bbox coordinates to YOLO format

    The box is first ordered (so swapped corners are accepted) and clamped
    to the image, which guarantees that every returned value lies in 0-1.
    Boxes that are already ordered and inside the image are unaffected.

    Args:
        bbox: [x1, y1, x2, y2] in absolute coordinates
        img_width: Image width
        img_height: Image height

    Returns:
        (x_center, y_center, width, height) normalized to 0-1

    Raises:
        ValueError: If img_width or img_height is not positive
    """
    if img_width <= 0 or img_height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {img_width}x{img_height}"
        )

    x1, y1, x2, y2 = bbox

    # Order the corners, then clamp them to the image bounds
    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))
    left, right = max(left, 0), min(right, img_width)
    top, bottom = max(top, 0), min(bottom, img_height)
    # A box lying completely outside the image collapses to zero size
    right = max(right, left) if left <= img_width else left
    bottom = max(bottom, top) if top <= img_height else top
    left, right = min(left, img_width), min(right, img_width)
    top, bottom = min(top, img_height), min(bottom, img_height)

    # Calculate center point
    x_center = ((left + right) / 2) / img_width
    y_center = ((top + bottom) / 2) / img_height

    # Calculate width and height
    width = (right - left) / img_width
    height = (bottom - top) / img_height

    return x_center, y_center, width, height


def create_yolo_annotation(
    bboxes: List[List[int]], 
    class_id: int, 
    img_width: int, 
    img_height: int, 
    save_path: str
):
    """
    Write a YOLO annotation file with one line per bounding box

    Each line has the form ``class_id x_center y_center width height``,
    with the four box values produced by ``bbox_to_yolo`` and written with
    six decimal places.

    Args:
        bboxes: List of bounding boxes [[x1, y1, x2, y2], ...]
        class_id: Class ID written at the start of every line
        img_width: Image width
        img_height: Image height
        save_path: Path of the .txt file to write (overwritten if present)
    """
    def label_rows():
        # Produced lazily so rows are formatted while the file is open.
        for box in bboxes:
            cx, cy, box_w, box_h = bbox_to_yolo(box, img_width, img_height)
            yield f"{class_id} {cx:.6f} {cy:.6f} {box_w:.6f} {box_h:.6f}\n"

    with open(save_path, 'w') as label_file:
        label_file.writelines(label_rows())


def _list_images(folder: Path, extensions) -> List[Path]:
    """Return the image files directly inside *folder*, sorted by path."""
    # iterdir() yields entries in arbitrary order; sorting makes a seeded
    # random selection independent of the directory listing order.
    return sorted(
        entry for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in extensions
    )


def _write_dataset_yaml(yaml_path: Path, dataset_root: Path, class_names: List[str]) -> None:
    """Write the YOLO dataset configuration file for the given classes."""
    lines = [
        "# YOLO Dataset Configuration\n",
        f"path: {dataset_root.absolute()}\n",
        "train: Selected\n",
        "val: Selected\n\n",
        "# Number of classes\n",
        f"nc: {len(class_names)}\n\n",
        "# Class names\n",
        "names:\n",
    ]
    # Names are written as double-quoted strings so that characters such as
    # ':' or '#' in a folder name cannot break the YAML.
    lines.extend(
        f"  {index}: {json.dumps(name, ensure_ascii=False)}\n"
        for index, name in enumerate(class_names)
    )
    # Explicit UTF-8: folder names may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to handle them.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def process_dataset(
    root_folder: str,
    detection_prompt: str,
    client,
    output_root: str = "./output",
    images_per_folder: int = 40,
    model: str = "gemini-2.5-flash",
    seed: int = None,
    expected_folders: Union[int, None] = 11
):
    """
    Process entire dataset: select images, detect objects, create YOLO annotations

    Every subfolder of ``root_folder`` is treated as one class; the class ID
    is the subfolder's index in alphabetical (string) order. For each class a
    random subset of images is copied to ``<output_root>/Selected/<class>``,
    a visualization is saved to ``<output_root>/detections/<class>`` and a
    YOLO annotation is written to ``<output_root>/annotations``. Finally a
    ``dataset.yaml`` is written to ``output_root``.

    Failures on individual images are recorded and do not stop the run.

    Args:
        root_folder: Root folder containing subfolders with images
        detection_prompt: Prompt for object detection
        client: Google GenAI client
        output_root: Root folder for output
        images_per_folder: Number of images to select from each subfolder
        model: Gemini model name
        seed: Random seed for reproducibility (None = non-deterministic)
        expected_folders: Number of subfolders expected; a warning is printed
            on mismatch. Pass None to skip the check.

    Returns:
        Dict with 'total_processed', 'total_detections' and 'errors'
        (a list of error message strings).
    """
    # A private generator keeps the selection reproducible without
    # re-seeding the global `random` state used by other code.
    rng = random.Random(seed)

    root_path = Path(root_folder)

    # Get all subfolders and sort alphabetically
    subfolders = sorted(f for f in root_path.iterdir() if f.is_dir())

    print(f"Found {len(subfolders)} subfolders")
    if expected_folders is not None and len(subfolders) != expected_folders:
        print(f"⚠️  Warning: Expected {expected_folders} subfolders, found {len(subfolders)}")

    # Create output directories
    output_path = Path(output_root)
    selected_path = output_path / "Selected"
    detections_path = output_path / "detections"
    annotations_path = output_path / "annotations"

    for path in [selected_path, detections_path, annotations_path]:
        path.mkdir(parents=True, exist_ok=True)

    # Supported image extensions
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}

    # Collect the candidate images once; reused for the total and the loop
    folder_images = {
        subfolder: _list_images(subfolder, image_extensions)
        for subfolder in subfolders
    }
    total_images_to_process = sum(
        min(len(images), images_per_folder) for images in folder_images.values()
    )

    print(f"\nTotal images to process: {total_images_to_process}\n")

    # Statistics
    stats = {
        'total_processed': 0,
        'total_detections': 0,
        'errors': []
    }

    # The context manager closes the progress bar even if the run is
    # interrupted (e.g. KeyboardInterrupt).
    with tqdm(total=total_images_to_process, desc="Overall Progress", position=0) as overall_pbar:
        # Process each subfolder
        for class_id, subfolder in enumerate(subfolders):
            folder_name = subfolder.name
            tqdm.write(f"\n📁 Processing folder {class_id}: {folder_name}")

            # Create output subfolders
            selected_subfolder = selected_path / folder_name
            detections_subfolder = detections_path / folder_name

            selected_subfolder.mkdir(exist_ok=True)
            detections_subfolder.mkdir(exist_ok=True)

            # Get images
            images = folder_images[subfolder]

            tqdm.write(f"   Found {len(images)} images")

            # Randomly select images
            if len(images) < images_per_folder:
                tqdm.write(f"   ⚠️  Only {len(images)} images available, selecting all")
                selected_images = images
            else:
                selected_images = rng.sample(images, images_per_folder)

            # Process each selected image
            for img_path in selected_images:
                try:
                    # Copy to Selected folder
                    shutil.copy2(img_path, selected_subfolder / img_path.name)

                    # Load image; the handle is released once the image
                    # has been analysed and visualized
                    with Image.open(img_path) as img:
                        img_width, img_height = img.size

                        # Detect objects
                        bboxes = detect_objects(img, detection_prompt, client, model)

                        # Create visualization
                        detection_img_path = detections_subfolder / img_path.name
                        visualize_bboxes(img, bboxes, save_path=str(detection_img_path))

                    # Create YOLO annotation
                    # NOTE(review): all annotations share one flat directory,
                    # so images with the same stem in different class folders
                    # overwrite each other's annotation file.
                    annotation_path = annotations_path / (img_path.stem + '.txt')
                    create_yolo_annotation(bboxes, class_id, img_width, img_height, str(annotation_path))

                    # Update statistics
                    stats['total_processed'] += 1
                    stats['total_detections'] += len(bboxes)

                except Exception as e:
                    # Best effort: record the failure and continue with the
                    # next image instead of aborting the whole run.
                    error_msg = f"Error processing {folder_name}/{img_path.name}: {str(e)}"
                    stats['errors'].append(error_msg)
                    tqdm.write(f"   ❌ {error_msg}")

                finally:
                    overall_pbar.update(1)

    # Print summary
    print(f"\n{'='*60}")
    print("✅ Processing Complete!")
    print(f"{'='*60}")
    print("📊 Statistics:")
    print(f"   - Images processed: {stats['total_processed']}")
    print(f"   - Total detections: {stats['total_detections']}")
    print(f"   - Errors: {len(stats['errors'])}")
    print("\n📂 Output folders:")
    print(f"   - Selected images: {selected_path}")
    print(f"   - Detections: {detections_path}")
    print(f"   - YOLO annotations: {annotations_path}")

    # Create dataset.yaml for YOLO training
    yaml_path = output_path / "dataset.yaml"
    _write_dataset_yaml(yaml_path, output_path, [f.name for f in subfolders])

    print(f"   - Dataset config: {yaml_path}")

    if stats['errors']:
        print("\n⚠️  Errors encountered:")
        for error in stats['errors'][:10]:  # Show first 10 errors
            print(f"   - {error}")
        if len(stats['errors']) > 10:
            print(f"   ... and {len(stats['errors']) - 10} more")

    print(f"\n{'='*60}\n")

    return stats


# Example usage
if __name__ == "__main__":
    # The API key is taken from the environment so it never ends up in the
    # source or in version control. The key that used to be hard-coded here
    # must be treated as leaked and revoked.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise SystemExit("Set the GEMINI_API_KEY environment variable before running.")
    client = genai.Client(api_key=api_key)

    # Process dataset
    stats = process_dataset(
        root_folder="./LCT_1280/separate",
        detection_prompt="Detect the instrument on the image",
        client=client,
        output_root="./gemini-labels-v6",
        images_per_folder=80,
        model="gemini-2.5-flash-lite",
        seed=1337  # For reproducibility
    )


Found 11 subfolders

Total images to process: 880



Overall Progress:   0%|                                 | 0/880 [00:00<?, ?it/s]


📁 Processing folder 0: 1 Отвертка «-»
   Found 247 images


Overall Progress:   9%|██                    | 80/880 [11:40<1:37:54,  7.34s/it]


📁 Processing folder 1: 10 Ключ рожковыйнакидной  ¾
   Found 425 images


Overall Progress:  18%|███▊                 | 160/880 [23:42<1:44:00,  8.67s/it]


📁 Processing folder 2: 11 Бокорезы
   Found 271 images


Overall Progress:  21%|████▍                | 186/880 [26:47<1:15:13,  6.50s/it]

   ❌ Error processing 11 Бокорезы/DSCN0950.JPG: [Errno 104] Connection reset by peer


Overall Progress:  27%|█████▋               | 240/880 [33:16<1:16:21,  7.16s/it]


📁 Processing folder 3: 2 Отвертка «+»
   Found 302 images


Overall Progress:  36%|███████▋             | 320/880 [43:12<1:08:43,  7.36s/it]


📁 Processing folder 4: 3 Отвертка на смещенный крест
   Found 275 images


Overall Progress:  45%|██████████▍            | 400/880 [52:44<57:12,  7.15s/it]


📁 Processing folder 5: 4 Коловорот
   Found 848 images


Overall Progress:  55%|██████████▎        | 480/880 [1:09:16<1:32:38, 13.90s/it]


📁 Processing folder 6: 5 Пассатижи контровочные
   Found 819 images


Overall Progress:  64%|█████████████▎       | 560/880 [1:21:23<39:19,  7.37s/it]


📁 Processing folder 7: 6 Пассатижи
   Found 303 images


Overall Progress:  73%|███████████████▎     | 640/880 [1:30:53<27:14,  6.81s/it]


📁 Processing folder 8: 7 Шэрница
   Found 627 images


Overall Progress:  82%|█████████████████▏   | 720/880 [1:40:19<18:04,  6.78s/it]


📁 Processing folder 9: 8 Разводной ключ
   Found 645 images


Overall Progress:  91%|███████████████████  | 800/880 [1:51:53<08:51,  6.65s/it]


📁 Processing folder 10: 9 Открывашка для банок с маслом
   Found 248 images


Overall Progress:  95%|███████████████████▉ | 838/880 [1:58:19<24:09, 34.52s/it]

   ❌ Error processing 9 Открывашка для банок с маслом/DSCN4645.JPG: Expecting value: line 1 column 1 (char 0)


Overall Progress:  98%|████████████████████▋| 865/880 [2:01:21<01:38,  6.57s/it]

   ❌ Error processing 9 Открывашка для банок с маслом/DSCN2813.JPG: Expecting value: line 1 column 1 (char 0)


Overall Progress: 100%|█████████████████████| 880/880 [2:03:07<00:00,  8.39s/it]


✅ Processing Complete!
📊 Statistics:
   - Images processed: 877
   - Total detections: 877
   - Errors: 3

📂 Output folders:
   - Selected images: gemini-labels-v6/Selected
   - Detections: gemini-labels-v6/detections
   - YOLO annotations: gemini-labels-v6/annotations
   - Dataset config: gemini-labels-v6/dataset.yaml

⚠️  Errors encountered:
   - Error processing 11 Бокорезы/DSCN0950.JPG: [Errno 104] Connection reset by peer
   - Error processing 9 Открывашка для банок с маслом/DSCN4645.JPG: Expecting value: line 1 column 1 (char 0)
   - Error processing 9 Открывашка для банок с маслом/DSCN2813.JPG: Expecting value: line 1 column 1 (char 0)





