# Object Detection V3: 3-Model Ensemble (Colab Ready)

This notebook implements the **3-Model Ensemble** evaluated on the PRISM-H dataset:
1. **Grounding DINO Swin-T**
2. **Grounding DINO Swin-B**
3. **OWLv2**

It pools the boxes from all three models and fuses them with class-agnostic **Non-Maximum Suppression (NMS)** (IoU threshold 0.5).

### Instructions
1. Run the **Setup** cell to install dependencies and download weights.
2. **IMPORTANT:** If you see an import error, use **Runtime > Restart session** (called "Restart runtime" in older Colab versions) and run the cells again. You can skip the installation steps if they already completed.
3. Run the **Model Loading** cell to initialize the ensemble.
4. Run the **Inference** cell to upload an image and see the result.

In [None]:
# --- SETUP & INSTALLATION ---
# This cell is designed to be robust for Google Colab.

import os
import sys

# 1. Install System Dependencies
!apt-get install -y libgl1-mesa-glx

# 2. Install Python Packages
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers accelerate scipy safetensors
!pip install -q opencv-python matplotlib pycocotools

# 3. Install Grounding DINO (Robust Method)
if not os.path.exists("GroundingDINO"):
    !git clone https://github.com/IDEA-Research/GroundingDINO.git
    %cd GroundingDINO
    !pip install -q -e .
    %cd ..
else:
    print("GroundingDINO repo already exists. Ensuring it is installed...")
    %cd GroundingDINO
    !pip install -q -e .
    %cd ..

# 4. Download Weights
def download_weight(url, filename):
    """Download ``url`` to ``filename`` unless a non-empty copy already exists.

    The payload is streamed into ``<filename>.part`` and renamed onto
    ``filename`` only after the transfer completes. An interrupted or failed
    download therefore never leaves a truncated file behind that a later run
    would mistake for a finished one. Errors are raised instead of being
    silently ignored.

    Args:
        url: Source URL (HTTP redirects are followed by urllib).
        filename: Destination path.

    Raises:
        OSError: (including ``urllib.error.URLError`` / ``HTTPError``) if the
            download or the file write fails.
    """
    import shutil
    import urllib.request

    # A zero-byte file is treated as a leftover from an earlier failed attempt.
    if os.path.exists(filename) and os.path.getsize(filename) > 0:
        return

    print(f"Downloading {filename}...")
    part_path = filename + ".part"
    try:
        with urllib.request.urlopen(url) as response, open(part_path, "wb") as out:
            shutil.copyfileobj(response, out)
        # Atomic on the same filesystem: readers see either nothing or the full file.
        os.replace(part_path, filename)
    finally:
        # Still present only if the transfer failed before the rename.
        if os.path.exists(part_path):
            os.remove(part_path)

# Local checkpoint name -> release URL, fetched in this order.
_GDINO_WEIGHT_URLS = {
    "groundingdino_swint_ogc.pth": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
    "groundingdino_swinb_cogcoor.pth": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
}
for weight_file, weight_url in _GDINO_WEIGHT_URLS.items():
    download_weight(weight_url, weight_file)

print("Setup Complete!")

In [None]:
# --- IMPORTS & CONFIGURATION ---

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.ops import nms
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Grounding DINO imports.
# Deliberately not wrapped in try/except so that an import failure surfaces
# with its real traceback. The cloned repo root is put on sys.path, presumably
# so `groundingdino` resolves even when the editable install is not yet
# visible to this kernel. The path is made absolute so the entry keeps
# resolving if the working directory changes later (e.g. another `%cd`).
# `os` is imported here because the setup cell, which also imports it, may
# be skipped after a runtime restart.
import os
import sys

_GDINO_REPO_DIR = os.path.abspath("GroundingDINO")
if _GDINO_REPO_DIR not in sys.path:
    sys.path.append(_GDINO_REPO_DIR)

from groundingdino.util.inference import load_model, predict as groundingdino_predict
import groundingdino.datasets.transforms as T

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Target classes (PRISM-H). OWLv2 gets one text query per entry and its integer
# labels are mapped back through this list, so keep the order stable.
# NOTE(review): both "Cement Tank" and "Cement Tanks" are listed — confirm this
# mirrors the PRISM-H label set rather than being an accidental duplicate.
TARGET_CLASSES = [
    "Sump",
    "Cement Tank",
    "Plastic Barrel",
    "Metal Drum",
    "Mud Pot",
    "Plastic Bucket",
    "Stone Cistern",
    "Grinding-stone",
    "Cement Tanks",
    "Water Puddle",
    "Plant-holder",
    "Tyre",
    "Solid Waste",
    "Other Container",
]
PROMPT_LIST = TARGET_CLASSES  # same list object; used for the OWLv2 queries
# Grounding DINO caption: each class followed by " ." -> "Sump . Cement Tank . ... ."
PROMPT_STRING = " ".join(f"{class_name} ." for class_name in TARGET_CLASSES)

# BOX_THRESHOLD is the score threshold handed to every detector;
# TEXT_THRESHOLD is consumed only by Grounding DINO.
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25

In [None]:
# --- MODEL LOADING ---

def init_detector(backend, device):
    """Load the detector(s) named by ``backend`` onto ``device``.

    Args:
        backend: "gdino_swint", "gdino_swinb", "owlv2", or "ensemble_3"
            (all three of the former, loaded in that order).
        device: torch device string passed to the loaders.

    Returns:
        - Grounding DINO backends: the model returned by ``load_model``.
        - "owlv2": a ``(processor, model)`` tuple.
        - "ensemble_3": dict mapping each backend name to its loaded detector.

    Raises:
        ValueError: for any other backend name.
    """
    print(f"Initializing {backend}...")

    # (config file, checkpoint file) for each Grounding DINO variant.
    gdino_files = {
        "gdino_swint": (
            "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
            "groundingdino_swint_ogc.pth",
        ),
        "gdino_swinb": (
            "GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py",
            "groundingdino_swinb_cogcoor.pth",
        ),
    }
    if backend in gdino_files:
        config_path, checkpoint_path = gdino_files[backend]
        return load_model(config_path, checkpoint_path, device=device)

    if backend == "owlv2":
        hub_id = "google/owlv2-base-patch16-ensemble"
        owl_processor = Owlv2Processor.from_pretrained(hub_id)
        owl_model = Owlv2ForObjectDetection.from_pretrained(hub_id).to(device)
        return (owl_processor, owl_model)

    if backend == "ensemble_3":
        # Recursive calls, in the same fixed order as the keys.
        return {
            member: init_detector(member, device)
            for member in ("gdino_swint", "gdino_swinb", "owlv2")
        }

    raise ValueError(f"Unknown backend: {backend}")

# Load all three detectors once; later cells reuse `ensemble_models`.
print("Loading Ensemble Models (this may take a minute)...")
ensemble_models = init_detector(backend="ensemble_3", device=device)
print("Models Loaded!")

In [None]:
# --- INFERENCE LOGIC ---

def _empty_detections():
    """Return the (boxes, scores, labels) triple used when nothing is detected."""
    return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.float32), []


def _cxcywh_to_xyxy_pixels(boxes, w, h):
    """Convert normalized (cx, cy, bw, bh) boxes of shape (N, 4) to pixel (x1, y1, x2, y2)."""
    # Build the scale on the boxes' own device/dtype so this also works if
    # the predictor ever returns non-CPU tensors.
    scale = torch.tensor([w, h, w, h], dtype=boxes.dtype, device=boxes.device)
    cx, cy, bw, bh = (boxes * scale).unbind(-1)
    return torch.stack([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], dim=-1)


def _run_gdino(model, img_rgb, prompt_text, threshold, device):
    """Run one Grounding DINO model; returns pixel xyxy boxes, scores, phrases."""
    h, w = img_rgb.shape[:2]
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor, _ = transform(Image.fromarray(img_rgb), None)

    boxes, logits, phrases = groundingdino_predict(
        model=model,
        image=img_tensor,
        caption=prompt_text,
        box_threshold=threshold,
        text_threshold=TEXT_THRESHOLD,
        device=device,
    )
    if len(boxes) == 0:
        return _empty_detections()
    # Boxes arrive normalized cxcywh; NMS later needs absolute xyxy.
    boxes_xyxy = _cxcywh_to_xyxy_pixels(boxes, w, h)
    return boxes_xyxy.cpu().numpy(), logits.cpu().numpy(), phrases


def _run_owlv2(model, img_rgb, prompt_list, threshold, device):
    """Run OWLv2 with one text query per class; returns pixel xyxy boxes, scores, class names."""
    processor, owl_model = model
    h, w = img_rgb.shape[:2]
    texts = [[f"a photo of a {t}" for t in prompt_list]]
    inputs = processor(text=texts, images=Image.fromarray(img_rgb), return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = owl_model(**inputs)

    # OWLv2 pads the input to a square before resizing, so its normalized
    # boxes are relative to a (side x side) canvas, not (h x w). Passing the
    # padded size is correct for both older post-processing (which scales by
    # the given w/h) and newer versions (which scale by max(h, w)).
    side = max(h, w)
    target_sizes = torch.tensor([[side, side]], device=device)
    results = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    # Boxes may reach into the padding area; clip them back onto the image.
    boxes = results["boxes"]
    limits = torch.tensor([w, h, w, h], dtype=boxes.dtype, device=boxes.device)
    boxes = torch.maximum(torch.minimum(boxes, limits), torch.zeros_like(limits))

    labels = [prompt_list[i] for i in results["labels"].tolist()]
    return boxes.cpu().numpy(), results["scores"].cpu().numpy(), labels


def run_single_detector(backend, model, img_rgb, prompt_text, prompt_list, threshold, device):
    """Run one detector on an RGB image.

    Args:
        backend: "gdino_swint"/"gdino_swinb" (any name starting with "gdino")
            or "owlv2".
        model: the object returned by ``init_detector`` for that backend.
        img_rgb: RGB image array; its first two dims are (height, width).
        prompt_text: Grounding DINO caption.
        prompt_list: class names; OWLv2 queries and label lookup use this list.
        threshold: minimum score for a detection to be kept.
        device: torch device string.

    Returns:
        ``(boxes, scores, labels)``. ``boxes`` is an (N, 4) array of pixel
        (x1, y1, x2, y2) coordinates (shape (0, 4) when Grounding DINO finds
        nothing), ``scores`` an (N,) array, ``labels`` a list of N strings.

    Raises:
        ValueError: for an unknown backend (previously this silently returned
            no detections).
    """
    if backend.startswith("gdino"):
        return _run_gdino(model, img_rgb, prompt_text, threshold, device)
    if backend == "owlv2":
        return _run_owlv2(model, img_rgb, prompt_list, threshold, device)
    raise ValueError(f"Unknown backend: {backend}")

def run_ensemble(models, img_rgb, prompt_text, prompt_list, threshold, device, iou_threshold=0.5):
    """Run the three detectors on one image and fuse their boxes with NMS.

    Args:
        models: dict with keys "gdino_swint", "gdino_swinb" and "owlv2", as
            returned by ``init_detector("ensemble_3", ...)``.
        img_rgb: RGB image array.
        prompt_text: Grounding DINO caption.
        prompt_list: class names used for the OWLv2 queries.
        threshold: score threshold passed to every detector.
        device: torch device string.
        iou_threshold: boxes overlapping a higher-scoring kept box by more than
            this IoU are discarded (default 0.5, the previously fixed value).

    Returns:
        ``(boxes, scores, labels)`` for the surviving detections, ordered by
        decreasing score, or ``([], [], [])`` when no detector found anything.
        NMS here is class-agnostic: labels are not consulted, so overlapping
        boxes from different classes or models suppress each other.
    """
    all_boxes, all_scores, all_labels = [], [], []
    for backend in ("gdino_swint", "gdino_swinb", "owlv2"):
        boxes, scores, labels = run_single_detector(
            backend, models[backend], img_rgb, prompt_text, prompt_list, threshold, device
        )
        if len(boxes) == 0:
            continue
        all_boxes.append(boxes)
        all_scores.append(scores)
        all_labels.extend(labels)

    if not all_boxes:
        return [], [], []

    boxes_tensor = torch.tensor(np.concatenate(all_boxes, axis=0), dtype=torch.float32)
    scores_tensor = torch.tensor(np.concatenate(all_scores, axis=0), dtype=torch.float32)

    keep = nms(boxes_tensor, scores_tensor, iou_threshold=iou_threshold)

    return (
        boxes_tensor[keep].numpy(),
        scores_tensor[keep].numpy(),
        [all_labels[i] for i in keep.tolist()],
    )

def visualize(img_rgb, boxes, scores, labels):
    """Display an RGB image with red, labelled detection boxes drawn on it.

    Args:
        img_rgb: RGB image to show.
        boxes: iterable of pixel (x1, y1, x2, y2) boxes.
        scores: per-box confidence, printed with two decimals.
        labels: per-box label strings.
    """
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(img_rgb)
    ax.set_axis_off()

    for (x1, y1, x2, y2), score, label in zip(boxes, scores, labels):
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='red', linewidth=2
        )
        ax.add_patch(outline)
        # Caption is placed 5 px above the box's top-left corner.
        caption = f"{label}: {score:.2f}"
        ax.text(
            x1, y1 - 5, caption,
            color='white', fontsize=10, bbox=dict(facecolor='red', alpha=0.5),
        )
    plt.show()

In [None]:
# --- RUN ON IMAGE ---
# Upload an image in Colab, run the ensemble on it and plot the detections.
# Only the Colab import is guarded, so an ImportError raised during inference
# is no longer misreported as "not running in Colab".
import traceback

try:
    from google.colab import files
except ImportError:
    files = None
    print("Not running in Colab or no file uploaded.")

if files is not None:
    uploaded = files.upload()
    if not uploaded:
        # Presumably the upload dialog was dismissed without choosing a file.
        print("Not running in Colab or no file uploaded.")
    else:
        filename = next(iter(uploaded))

        # cv2.imread returns None (it does not raise) for unreadable files.
        img = cv2.imread(filename)
        if img is None:
            print(f"Error: could not read {filename!r} as an image.")
        else:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            print("Running Ensemble Detection...")
            try:
                boxes, scores, labels = run_ensemble(
                    ensemble_models, img_rgb, PROMPT_STRING, PROMPT_LIST, BOX_THRESHOLD, device
                )
                print(f"Found {len(boxes)} objects.")
                visualize(img_rgb, boxes, scores, labels)
            except Exception:
                # Best-effort cell: don't abort, but keep the full traceback visible.
                traceback.print_exc()