In [None]:
!pip install -q ultralytics pillow matplotlib scikit-learn

import os
import cv2
import numpy as np
from ultralytics import YOLOWorld
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import DBSCAN

In [None]:
# STEP 1: Upload an image
# In Google Colab the upload widget is used and the first uploaded file is
# taken; anywhere else the script falls back to a fixed local path.
try:
    from google.colab import files
except ImportError:
    # Not running inside Colab: only this case should trigger the fallback,
    # so real upload errors (or Ctrl-C) are no longer silently swallowed.
    image_path = "your_image.jpg"  # fallback for local runs
    if not os.path.exists(image_path):
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise FileNotFoundError("Image not found!") from None
else:
    uploaded = files.upload()
    if not uploaded:
        # Upload dialog was cancelled or returned nothing.
        raise RuntimeError("No image was uploaded.")
    image_path = next(iter(uploaded))

#  STEP 2: Load YOLO-World model
model = YOLOWorld("yolov8m-worldv2.pt")  # Change to smaller model if needed

#  STEP 3: Get prompt
prompt = input("Enter your search prompt (e.g., 'person eating banana', 'car driving on road'  :  ").strip()
if not prompt:
    # An empty prompt would be handed to the detector as a blank class name;
    # fail early with a clear message instead.
    raise ValueError("Search prompt must not be empty.")

# - STEP 4: Detect objects with YOLO-World
# The prompt is registered as the single class to search for
# (YOLO-World generates feature vectors based on the prompt).
model.set_classes([prompt])
results = model.predict(source=image_path, imgsz=736, conf=0.25, iou=0.6, verbose=False)

# Pull the detections of the first result into NumPy arrays:
#   boxes   -> one (x1, y1, x2, y2) row per detection
#   confs   -> one confidence score per detection
#   centers -> one (cx, cy) box midpoint per detection
if len(results):
    boxes = results[0].boxes.xyxy.cpu().numpy()
    confs = results[0].boxes.conf.cpu().numpy()
else:
    # No result objects at all: use empty arrays of the matching width.
    boxes = np.empty((0, 4))
    confs = np.empty((0,))

if len(boxes):
    centers = np.column_stack((
        (boxes[:, 0] + boxes[:, 2]) / 2,  # horizontal midpoint
        (boxes[:, 1] + boxes[:, 3]) / 2,  # vertical midpoint
    ))
else:
    centers = np.empty((0, 2))

#  STEP 5: Run DBSCAN to group detections
# Detections whose box centers lie within `eps_value` pixels of each other
# end up in the same group; `labels` holds one group id per detection.
if len(centers) == 0:
    # Nothing was detected, so there is nothing to cluster.
    labels = np.array([])
else:
    eps_value = 100  # distance threshold in pixels; adjust for your image size
    clustering = DBSCAN(eps=eps_value, min_samples=1)
    labels = clustering.fit_predict(centers)

# STEP 6: Draw results
# Each detection gets a rectangle and a caption with the prompt, its
# confidence and its group id.
img = cv2.imread(image_path)
if img is None:
    # cv2.imread reports failure by returning None rather than raising;
    # stop here with a readable message instead of failing inside OpenCV.
    raise ValueError(f"Could not read image: {image_path}")

for idx, (box, conf) in enumerate(zip(boxes, confs)):
    x1, y1, x2, y2 = map(int, box)
    group_id = labels[idx] if len(labels) > 0 else -1
    # Green for grouped detections, red for group id -1 (colors are BGR).
    color = (0, 255, 0) if group_id != -1 else (0, 0, 255)
    label_text = f"{prompt} {conf:.2f} | Group {group_id}"
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
    # Keep the caption baseline inside the image when the box touches the
    # top edge; otherwise the text would be drawn off-canvas.
    text_y = max(y1 - 10, 15)
    cv2.putText(img, label_text, (x1, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

#  STEP 7: Save and display
# The annotated image is written to search_results/grouped.jpg and then
# shown inline with matplotlib.
os.makedirs("search_results", exist_ok=True)
output_img_path = os.path.join("search_results", "grouped.jpg")
if not cv2.imwrite(output_img_path, img):
    # cv2.imwrite returns False when the file could not be written; report
    # it instead of ignoring it, but still show the figure below.
    print(f"Warning: could not save annotated image to {output_img_path}")

plt.figure(figsize=(12, 8))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV is BGR, matplotlib expects RGB
plt.axis("off")
plt.title(f"Detections & Groups for: {prompt}")
plt.show()

# STEP 8: Print group summary
# One line per group id, in ascending order, with the number of detections
# assigned to that group.
if len(labels) == 0:
    print("No objects detected.") # for false queries
else:
    # np.unique returns the sorted distinct ids together with their counts.
    group_ids, group_sizes = np.unique(labels, return_counts=True)
    for g_id, members in zip(group_ids, group_sizes):
        print(f"Group {g_id}: {members} objects")
