<a href="https://colab.research.google.com/github/manva-soni-3rd/AIMS---project/blob/main/new_pipeline_khud_se_SAM%2BNMS%2BCLIP_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ===================================================================
# Part 1: Setup and Installation (Run Once Per Session)
# ===================================================================
print(" PART 1: INSTALLING LIBRARIES & DOWNLOADING MODELS ")

# Install the necessary libraries: segment-anything (straight from its GitHub
# repository) and Hugging Face transformers (imported for CLIP in Part 2).
!pip install -q git+https://github.com/facebookresearch/segment-anything.git
!pip install -q transformers

# Download the SAM 'vit_h' checkpoint file (sam_vit_h_4b8939.pth) into the
# current working directory. Part 2 loads this exact file name with
# model_type = "vit_h".
# NOTE(review): this comment previously described the download as the smaller
# 'vit_b' checkpoint chosen for memory efficiency, but the URL below fetches
# the 'vit_h' file. If the smaller model is wanted, both this URL and the
# checkpoint/model_type in Part 2 have to change together.
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

print("\n All libraries and the SAM model downloaded.")
print("You can now proceed to the next cell to load the models.")


We use pre-trained SAM and CLIP models, together with non-max suppression (NMS).

In [None]:
# ===================================================================
# Part 2: Load Models into Memory (Run Once Per Session)
# ===================================================================
print("\n PART 2: LOADING MODELS ONTO GPU ")

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
from IPython.display import display
import numpy as np
from torchvision.ops import nms

# Choose the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Segment Anything Model (SAM) ---
# The checkpoint file name must match the file downloaded in Part 1, and
# model_type must name the matching entry in sam_model_registry.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Sampling settings for the automatic mask generator, kept together so they
# are easy to tune (chosen by the original author with memory use in mind).
mask_generator_settings = {
    "points_per_side": 16,
    "crop_n_layers": 1,
    "crop_n_points_downscale_factor": 2,
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **mask_generator_settings)
print(" SAM model loaded.")

# --- CLIP ---
# The model is moved to the selected device; the processor stays as returned
# by from_pretrained and is used in Part 3 to prepare text and image inputs.
clip_model_name = "openai/clip-vit-base-patch32"
language_model = CLIPModel.from_pretrained(clip_model_name)
language_model = language_model.to(device)
language_processor = CLIPProcessor.from_pretrained(clip_model_name)
print(" CLIP model loaded.")
print("\nModels are ready. You can now run the inference cell.")


Part 3 (the implementation) works as a five-step process:
1. Load the models and libraries, then define the text query and the image.
2. Resize the image to avoid CUDA out-of-memory errors. Note: the higher the resolution, the better the results.
3. Hierarchy analysis: in simple terms, this takes all the segmented parts into account and checks which mask is part of a bigger parent mask for the given query.
4. CLIP then assigns a score to each of the masks that SAM proposes and picks the best match for the query.
5. Finally, a confidence gate lets an answer pass only if its score meets or exceeds the gate's threshold. This ensures that negative cases do not return anything.

In [None]:
# ===================================================================
# Part 3: Run the Full Pipeline with Hierarchical Analysis
# ===================================================================
print("\nPART 3: RUNNING THE PIPELINE ")

# --- Install supervision library for drawing annotations ---
# NOTE(review): `sv` is imported here but is not referenced anywhere else in
# the visible part of this cell; confirm whether the install is still needed.
!pip install -q supervision

import supervision as sv

#  1. Define Image and Text Query ---
# image_url is downloaded over HTTP below; text_query is the free-text prompt
# that CLIP compares against every cropped SAM proposal.
# NOTE(review): the trailing comments on the next two lines may be left over
# from an earlier experiment; confirm they still describe this URL and query.
image_url = 'https://i.ytimg.com/vi/qMnw30lgRNw/maxresdefault.jpg' # Person on a bike
text_query = "mickey mouse" # A query that might focus on the head
print(f"Processing image: {image_url}")
print(f"Searching for: '{text_query}'")

def build_proposals(sam_results, image_area, max_area_fraction=0.95):
    """Turn SAM mask records into proposal dicts with XYXY boxes.

    Args:
        sam_results: Iterable of dicts providing 'bbox' (x, y, width, height),
            'segmentation' and 'predicted_iou'.
        image_area: Pixel area of the full image (height * width).
        max_area_fraction: Boxes covering more than this fraction of the
            image are dropped (they are effectively the whole frame).

    Returns:
        A list of dicts with keys "box" (np.ndarray [x0, y0, x1, y1]),
        "mask", "score" and "parent" (initialised to None).
    """
    proposals = []
    for mask_data in sam_results:
        x, y, w, h = mask_data['bbox']
        # A zero-width or zero-height box would produce an empty crop, which
        # cannot be scored downstream, so skip it here.
        if w <= 0 or h <= 0:
            continue
        if (w * h) / image_area > max_area_fraction:
            continue
        proposals.append({
            "box": np.array([x, y, x + w, y + h]),
            "mask": mask_data['segmentation'],
            "score": mask_data['predicted_iou'],
            "parent": None,
        })
    return proposals


def assign_parents(proposals, tolerance=5):
    """Record, for every proposal, the index of its tightest enclosing box.

    A proposal j is a parent candidate for proposal i when i's box lies inside
    j's box (allowing `tolerance` pixels of slack on every side) AND j's box
    is strictly larger in area. The area requirement stops near-duplicate
    boxes from becoming each other's parent and stops a slightly smaller box
    from being reported as a "parent". Among all candidates the one with the
    smallest area wins; ties keep the first one found.

    Args:
        proposals: List of dicts with a "box" entry ([x0, y0, x1, y1]).
            Each dict's "parent" entry is overwritten in place.
        tolerance: Slack in pixels for the containment test.

    Returns:
        The same `proposals` list, for convenience.
    """
    # Areas are computed once instead of inside the O(n^2) comparison loop.
    areas = [float(np.prod(p["box"][2:] - p["box"][:2])) for p in proposals]

    for i, child in enumerate(proposals):
        child_box = child["box"]
        best_parent = None
        for j, candidate in enumerate(proposals):
            if i == j or areas[j] <= areas[i]:
                continue
            parent_box = candidate["box"]
            is_inside = (
                child_box[0] >= parent_box[0] - tolerance
                and child_box[1] >= parent_box[1] - tolerance
                and child_box[2] <= parent_box[2] + tolerance
                and child_box[3] <= parent_box[3] + tolerance
            )
            if is_inside and (best_parent is None or areas[best_parent] > areas[j]):
                best_parent = j
        child["parent"] = best_parent
    return proposals


def score_crops_with_clip(crops, query, batch_size=64):
    """Return the CLIP cosine similarity between `query` and every crop.

    Uses the notebook-level `language_model`, `language_processor` and
    `device` set up in Part 2. Crops are encoded in batches of `batch_size`
    so that a large number of SAM proposals does not have to be pushed
    through the model in one forward pass.

    Args:
        crops: Non-empty list of PIL images.
        query: Text prompt to compare against.
        batch_size: Number of crops encoded per forward pass.

    Returns:
        A 1-D tensor with one similarity score per crop, in input order.
    """
    with torch.no_grad():
        text_inputs = language_processor(text=[query], return_tensors="pt", padding=True).to(device)
        text_features = language_model.get_text_features(**text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        feature_batches = []
        for start in range(0, len(crops), batch_size):
            batch = crops[start:start + batch_size]
            inputs_images = language_processor(images=batch, return_tensors="pt", padding=True).to(device)
            batch_features = language_model.get_image_features(**inputs_images)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            feature_batches.append(batch_features)
        features_images = torch.cat(feature_batches, dim=0)

        # Both sides are unit-normalised, so the dot product is cosine similarity.
        return (text_features @ features_images.T).squeeze(0)


try:
    # The timeout keeps the cell from hanging on an unresponsive host, and
    # raise_for_status() reports HTTP errors directly instead of letting an
    # error page reach the image decoder.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image_pil = Image.open(BytesIO(response.content)).convert("RGB")
    print(f"\n Image loaded successfully. Original size: {image_pil.size}")

    # resolution - cap the longest side; thumbnail() resizes in place and
    # keeps the aspect ratio.
    max_dimension = 2000
    if max(image_pil.size) > max_dimension:
        image_pil.thumbnail((max_dimension, max_dimension))
        print(f"Resized image to: {image_pil.size} for processing.")

    # Check if the loaded image has valid dimensions before proceeding
    if image_pil.width > 0 and image_pil.height > 0:
        image_np = np.array(image_pil)

        #  2.SAM
        print("🔎 Finding all potential objects with SAM...")
        sam_results = mask_generator.generate(image_np)
        print(f"Found {len(sam_results)} potential object masks.")

        # 3.Hierarchy
        print(" Analyzing part-whole relationships...")
        total_image_area = image_np.shape[0] * image_np.shape[1]
        proposals = build_proposals(sam_results, total_image_area)
        print(f"Kept {len(proposals)} proposals after filtering.")
        assign_parents(proposals)

        # 4. CLIP
        print(" Scoring all proposals with CLIP...")
        if not proposals:
            print("\nNo initial proposals found by SAM.")
        else:
            cropped_images = [image_pil.crop(p['box']) for p in proposals]
            similarity_scores = score_crops_with_clip(cropped_images, text_query)

            # 5. Confidence Gate
            best_detection_idx = torch.argmax(similarity_scores).item()
            best_score = similarity_scores[best_detection_idx].item()

            final_confidence_threshold = 0.25 # You can tune this value

            if best_score >= final_confidence_threshold:
                final_proposal_idx = best_detection_idx

                # When the best-scoring crop is a part of a larger object,
                # show the enclosing object instead of the part.
                if proposals[best_detection_idx]["parent"] is not None:
                    print("Best match was a part of a larger object. Selecting the parent object.")
                    final_proposal_idx = proposals[best_detection_idx]["parent"]

                final_proposal = proposals[final_proposal_idx]
                final_cropped_image = image_pil.crop(final_proposal["box"])

                print(f"\n✅ Detection complete. Displaying the cropped region for '{text_query}':")
                print(f"(Confidence Score: {best_score:.2f})")
                display(final_cropped_image)
            else:
                print(f"\n No confident match found for the query.")
                print(f"(Highest confidence score was {best_score:.2f}, which is below the threshold of {final_confidence_threshold})")

    else:
        print(f" ERROR: Image from {image_url} is invalid or has zero dimensions.")

except Exception as e:
    # Top-level notebook boundary: report the failure instead of dumping a
    # traceback into the cell output.
    print(f" An error occurred: {e}")
