In [None]:
import cv2
import os
import shutil
from ultralytics import YOLO
from glob import glob
from tqdm import tqdm

In [None]:

# === Configuration: detector, data locations, class folders ===
# NOTE(review): per the original note, "yolov8n.pt" is a placeholder — point this at the
# fine-tuned eye detector before running the pipeline.
model = YOLO("yolov8n.pt")

# Source images are read from <input_base>/<class>/ and eye crops are written to
# <output_base>/<class>/.
input_base = "../data/filtered"
output_base = "../data/temp_split"
classes = ["healthy_eye", "infected_eye"]

# Ensure each per-class output directory exists (exist_ok makes this safe to re-run).
for class_dir in (os.path.join(output_base, name) for name in classes):
    os.makedirs(class_dir, exist_ok=True)

In [None]:
# === Function to process and move/split images ===
def process_image(image_path, out_dir, delete_original=True):
    """Detect eyes in one image, write each crop to ``out_dir``, then (optionally) delete the source.

    One detection is saved as ``<stem>.jpg``. Two detections are saved as ``<stem>_left.jpg`` and
    ``<stem>_right.jpg``, ordered by the boxes' left x-coordinate (image left/right). Images with 0
    or more than 2 detections are reported, skipped, and left in place.

    Args:
        image_path: Path of the input image.
        out_dir: Existing directory that receives the crops.
        delete_original: When True (default, the original behaviour), remove ``image_path`` once
            every crop has been written successfully.

    Returns:
        None.

    Raises:
        ValueError: If the image cannot be read or a detected box yields an empty crop.
        OSError: If ``cv2.imwrite`` reports failure for a crop (the source is then kept).
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None instead of raising.
        raise ValueError(f"could not read image: {image_path}")
    height, width = img.shape[:2]
    filename = os.path.splitext(os.path.basename(image_path))[0]

    results = model(img)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    if len(boxes) == 1:
        named_boxes = [(f"{filename}.jpg", boxes[0])]
    elif len(boxes) == 2:
        left, right = sorted(boxes, key=lambda b: b[0])  # sort left to right
        named_boxes = [(f"{filename}_left.jpg", left), (f"{filename}_right.jpg", right)]
    else:
        print(f"⚠️ Skipped (0 or >2 eyes): {image_path}")
        return

    # Build every crop before writing anything, so a bad box leaves no partial output behind.
    crops = []
    for name, box in named_boxes:
        x1, y1, x2, y2 = map(int, box)
        # Clamp to the image: a negative index would silently wrap around in numpy slicing.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"empty crop {box.tolist()} in {image_path}")
        crops.append((os.path.join(out_dir, name), img[y1:y2, x1:x2]))

    for out_path, eye in crops:
        # cv2.imwrite returns False (rather than raising) on many failures; never delete the
        # source unless every crop actually reached disk.
        if not cv2.imwrite(out_path, eye):
            raise OSError(f"failed to write crop: {out_path}")

    # Remove original only after all crops were written successfully
    if delete_original:
        os.remove(image_path)

# === Run the pipeline ===
# Extensions are matched case-insensitively so files such as "IMG_01.JPG" or "x.jpeg" are not
# silently ignored (glob patterns are case-sensitive on most filesystems).
image_exts = (".jpg", ".jpeg", ".png")

for cls in classes:
    input_folder = os.path.join(input_base, cls)
    output_folder = os.path.join(output_base, cls)
    # sorted() makes processing order (and the log) reproducible; raw glob order is filesystem-dependent.
    images = sorted(
        path
        for path in glob(os.path.join(input_folder, "*"))
        if os.path.isfile(path) and path.lower().endswith(image_exts)
    )

    print(f"\n📂 Processing {len(images)} images from: {input_folder}")

    failures = 0
    for img_path in tqdm(images, desc=f"Splitting {cls}"):
        try:
            process_image(img_path, output_folder)
        except Exception as e:
            # Best-effort batch run: report and continue. process_image deletes the source only
            # as its final step, so a failed image stays in the input folder.
            failures += 1
            print(f"❌ Error with {img_path}: {e}")

    if failures:
        print(f"❌ {failures} of {len(images)} images in {cls} failed")

In [None]:
# === Inspect raw detections for one sample image ===
# `boxes`/`results` inside process_image are function locals, so this cell re-runs the model
# itself instead of relying on leftover kernel state.
# NOTE: the pipeline above deletes originals it split successfully, so only skipped/failed
# (or newly added) images remain in the input folders.
sample_paths = sorted(
    path
    for class_name in classes
    for path in glob(os.path.join(input_base, class_name, "*"))
    if os.path.isfile(path) and path.lower().endswith((".jpg", ".jpeg", ".png"))
)
sample_img = cv2.imread(sample_paths[0]) if sample_paths else None

if sample_img is None:
    print("No readable images left in the input folders to inspect.")
else:
    sample_results = model(sample_img)[0]
    sample_boxes = sample_results.boxes.xyxy.cpu().numpy()
    sample_class_ids = sample_results.boxes.cls.tolist()
    # Pair each box with its own class id (the original always printed the first box's class).
    for box, class_id in zip(sample_boxes, sample_class_ids):
        print(box)  # Bounding box
        print(sample_results.names[int(class_id)])  # Class name