In [None]:
# ============================================
# CELL 1: Install Dependencies
# ============================================
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q opencv-python pillow matplotlib numpy
!pip install -q supervision transformers

# Clone and install SAM 2
!git clone https://github.com/facebookresearch/segment-anything-2.git
%cd segment-anything-2
!pip install -q -e .
%cd ..

# Clone and install GroundingDINO
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd GroundingDINO
!pip install -q -e .
%cd ..

# Download SAM 2 checkpoints
!mkdir -p checkpoints
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P checkpoints/
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt -P checkpoints/

# Download GroundingDINO weights
!mkdir -p weights
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/


In [None]:
# ============================================
# CELL 2: Import Libraries and Setup
# ============================================
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import os
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add paths
sys.path.append('./segment-anything-2')
sys.path.append('./GroundingDINO')

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2_video_predictor

# Import GroundingDINO
from groundingdino.util.inference import load_model, load_image, predict, annotate
import groundingdino.datasets.transforms as T
from torchvision.ops import box_convert

# Select the compute device: the CUDA GPU when PyTorch can see one, else CPU.
# The models built later in the notebook are created with `device=device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
 ============================================
# CELL 3: Helper Functions
# ============================================
def download_sample_image():
    """Download a sample image for testing"""
    !wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg -O sample_image.jpg
    return "sample_image.jpg"

def download_sample_video():
    """Download a sample video for testing"""
    # Download a short sample video
    !wget -q https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4 -O sample_video.mp4
    return "sample_video.mp4"

def show_mask(mask, ax, color=[30/255, 144/255, 255/255, 0.6]):
    """Overlay one mask on an axis as a translucent colour layer.

    Args:
        mask: Array whose last two dimensions are (H, W); any leading
            dimensions must have size 1 so it can be reshaped to (H, W, 1).
        ax: Object with an ``imshow`` method (e.g. a matplotlib Axes).
        color: RGBA values in [0, 1]; defaults to a semi-transparent blue.
    """
    height, width = mask.shape[-2:]
    rgba = np.array(color).reshape(1, 1, -1)
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image whose alpha
    # is 0 wherever the mask is 0, so only masked pixels are tinted.
    ax.imshow(mask.reshape(height, width, 1) * rgba)

def show_box(box, ax, label):
    """Draw an XYXY bounding box and a text tag on a matplotlib axis.

    Args:
        box: Indexable ``(x0, y0, x1, y1)`` in pixel coordinates.
        ax: Matplotlib axis to draw on.
        label: Text placed 5 px above the box's (x0, y0) corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top), right - left, bottom - top,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
    tag_style = dict(boxstyle="round,pad=0.3", facecolor='green', alpha=0.7)
    ax.text(left, top - 5, label, fontsize=12, color='white', bbox=tag_style)

def process_grounding_dino_output(boxes, logits, phrases, image_shape):
    """Convert GroundingDINO detections into pixel-space XYXY boxes for SAM.

    Args:
        boxes: Tensor of shape (N, 4) with normalised (cx, cy, w, h) boxes.
        logits: Tensor of per-box confidence scores.
        phrases: Matched phrase for each box; returned unchanged.
        image_shape: Shape whose first two entries are (height, width).

    Returns:
        tuple: (XYXY boxes as a NumPy array, scores as a NumPy array, phrases).
    """
    img_h, img_w = image_shape[:2]
    scaled = boxes * torch.tensor([img_w, img_h, img_w, img_h])
    # Centre/size -> corner form: the same arithmetic torchvision's
    # box_convert(in_fmt="cxcywh", out_fmt="xyxy") performs.
    cx, cy, bw, bh = scaled.unbind(-1)
    corners = torch.stack(
        (cx - 0.5 * bw, cy - 0.5 * bh, cx + 0.5 * bw, cy + 0.5 * bh), dim=-1
    )
    return corners.numpy(), logits.numpy(), phrases


In [None]:
# ============================================
# CELL 4: Initialize Models
# ============================================
class TextPromptedSegmentor:
    """Text-prompted image segmentation: GroundingDINO boxes -> SAM 2 masks.

    ``detect_objects`` turns a free-text prompt into pixel-space XYXY boxes
    with GroundingDINO; ``segment_image`` then prompts the SAM 2 image
    predictor with each box to get one mask per detection.
    """

    def __init__(self, sam_checkpoint="checkpoints/sam2_hiera_large.pt",
                 grounding_dino_checkpoint="weights/groundingdino_swint_ogc.pth",
                 sam_config="sam2_hiera_l.yaml",
                 grounding_dino_config="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"):
        """Load SAM 2 and GroundingDINO onto the notebook-level ``device``.

        Args:
            sam_checkpoint: Path to the SAM 2 checkpoint file.
            grounding_dino_checkpoint: Path to the GroundingDINO weights file.
            sam_config: SAM 2 model config *name*. ``build_sam2`` resolves it
                through Hydra inside the installed ``sam2`` package, so it must
                be a config name such as ``"sam2_hiera_l.yaml"``, not a
                filesystem path (a path raises a missing-config error).
            grounding_dino_config: Path to the GroundingDINO Python config.
        """
        # Initialize SAM 2
        self.sam_model = build_sam2(
            config_file=sam_config,
            ckpt_path=sam_checkpoint,
            device=device
        )
        self.sam_predictor = SAM2ImagePredictor(self.sam_model)

        # Initialize GroundingDINO
        self.grounding_model = load_model(
            grounding_dino_config,
            grounding_dino_checkpoint,
            device=device
        )

        print("✓ Models loaded successfully!")

    def detect_objects(self, image_path, text_prompt, box_threshold=0.35, text_threshold=0.25):
        """Detect objects matching ``text_prompt`` with GroundingDINO.

        Args:
            image_path: Path to the input image.
            text_prompt: Free-text description of the objects to find.
            box_threshold: Box-confidence threshold passed to ``predict``.
            text_threshold: Text-confidence threshold passed to ``predict``.

        Returns:
            tuple: (image array from ``load_image``, XYXY pixel boxes as a
            NumPy array, scores as a NumPy array, matched phrases).
        """
        image_source, image = load_image(image_path)

        boxes, logits, phrases = predict(
            model=self.grounding_model,
            image=image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=device
        )

        # Scale normalised cx/cy/w/h boxes to this image's pixel XYXY.
        h, w, _ = image_source.shape
        boxes_xyxy, scores, labels = process_grounding_dino_output(
            boxes, logits, phrases, (h, w)
        )

        return image_source, boxes_xyxy, scores, labels

    def segment_image(self, image_path, text_prompt, box_threshold=0.35, text_threshold=0.25):
        """Complete pipeline: text prompt -> detection -> segmentation.

        Args:
            image_path: Path to the input image.
            text_prompt: Free-text description of the objects to segment.
            box_threshold: GroundingDINO box-confidence threshold.
            text_threshold: GroundingDINO text-confidence threshold.

        Returns:
            tuple: (image, boxes, labels, masks) with one mask per box, or
            (image, None, None, None) when nothing is detected.
        """
        image, boxes, scores, labels = self.detect_objects(
            image_path, text_prompt, box_threshold, text_threshold
        )

        if len(boxes) == 0:
            print(f"No objects detected for prompt: '{text_prompt}'")
            return image, None, None, None

        print(f"✓ Detected {len(boxes)} object(s): {labels}")

        # Compute the image embedding once; every box prompt reuses it.
        self.sam_predictor.set_image(image)

        masks = []
        for box in boxes:
            box_masks, _, _ = self.sam_predictor.predict(
                box=box,
                multimask_output=False
            )
            # multimask_output=False yields a single mask; keep it.
            masks.append(box_masks[0])

        return image, boxes, labels, masks

    def visualize_results(self, image, boxes, labels, masks, text_prompt):
        """Show detections (left) and segmentation masks (right) side by side.

        Args:
            image: RGB image array.
            boxes: XYXY boxes, or None when nothing was detected.
            labels: Phrase per box (used when ``boxes`` is not None).
            masks: One mask per box, or None.
            text_prompt: Prompt shown in the panel titles.
        """
        fig, axes = plt.subplots(1, 2, figsize=(15, 8))

        # Original image with boxes
        axes[0].imshow(image)
        axes[0].set_title(f"Detection: '{text_prompt}'")
        axes[0].axis('off')

        if boxes is not None:
            for box, label in zip(boxes, labels):
                show_box(box, axes[0], label)

        # Segmentation masks
        axes[1].imshow(image)
        axes[1].set_title(f"Segmentation: '{text_prompt}'")
        axes[1].axis('off')

        if masks is not None:
            for i, mask in enumerate(masks):
                # A distinct colour per object keeps adjacent/overlapping
                # instances distinguishable.
                rgb = plt.cm.tab10(i % 10)[:3]
                show_mask(mask, axes[1], color=[*rgb, 0.6])

        plt.tight_layout()
        plt.show()

# Initialize the segmentor
segmentor = TextPromptedSegmentor()

In [None]:
 ============================================
# CELL 5: Single Image Segmentation Demo
# ============================================
# Download sample image
image_path = download_sample_image()

# Test different text prompts
test_prompts = [
    "dog",
    "dog face",
    "animal",
]

for prompt in test_prompts:
    print(f"\n📝 Processing prompt: '{prompt}'")
    image, boxes, labels, masks = segmentor.segment_image(
        image_path, 
        prompt,
        box_threshold=0.3
    )
    segmentor.visualize_results(image, boxes, labels, masks, prompt)


In [None]:
# ============================================
# CELL 6: Upload Custom Image (Optional)
# ============================================
from google.colab import files
import io

def upload_and_segment():
    """Upload image(s) through the Colab picker and segment each with a text prompt.

    For every uploaded file, the image is re-encoded as RGB JPEG to
    'uploaded_image.jpg' and a prompt is read with ``input``. The results are
    plotted, and when masks are found, a red-tinted overlay is written to
    'segmentation_result.jpg'.
    """
    print("Upload an image:")
    uploaded = files.upload()

    for filename, image_data in uploaded.items():
        # JPEG cannot store alpha or palette images (e.g. RGBA / "P" PNGs);
        # saving them directly raises, so normalise to RGB first.
        Image.open(io.BytesIO(image_data)).convert('RGB').save('uploaded_image.jpg')

        # Get text prompt
        text_prompt = input("Enter text prompt for segmentation: ")

        # Process
        image, boxes, labels, masks = segmentor.segment_image(
            'uploaded_image.jpg',
            text_prompt,
            box_threshold=0.25
        )

        # Visualize
        segmentor.visualize_results(image, boxes, labels, masks, text_prompt)

        # Save results
        if masks is not None:
            result = image.copy()
            red = np.zeros_like(image)
            red[:, :, 0] = 255  # image is RGB, so channel 0 is red
            for mask in masks:
                region = np.asarray(mask).astype(bool)
                blended = cv2.addWeighted(result, 0.7, red, 0.3, 0)
                # Tint only masked pixels; blending the whole frame would
                # darken the background by 30% for every detected object.
                result[region] = blended[region]

            cv2.imwrite('segmentation_result.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
            print("✓ Result saved as 'segmentation_result.jpg'")

# Uncomment to use:
# upload_and_segment()


In [None]:
# ============================================
# CELL 7: Video Object Segmentation
# ============================================
class VideoSegmentor:
    """Text-prompted video object segmentation.

    GroundingDINO (via an image ``TextPromptedSegmentor``) detects the prompted
    objects in the first sampled frame. The SAM 2 video predictor then
    propagates one mask per detected box through the sampled frames.
    """

    def __init__(self, sam_checkpoint="checkpoints/sam2_hiera_large.pt",
                 sam_config="sam2_hiera_l.yaml", image_segmentor=None):
        """Build the SAM 2 video predictor.

        Args:
            sam_checkpoint: Path to the SAM 2 checkpoint file.
            sam_config: SAM 2 config *name* for ``build_sam2_video_predictor``.
                It is resolved through Hydra inside the installed ``sam2``
                package, so it must not be a filesystem path.
            image_segmentor: Object providing ``grounding_model`` and
                ``detect_objects``. Defaults to the notebook-level ``segmentor``.
        """
        self.video_predictor = build_sam2_video_predictor(
            config_file=sam_config,
            ckpt_path=sam_checkpoint,
            device=device
        )

        # Reuse GroundingDINO from the image segmentor instead of loading it twice.
        self.image_segmentor = image_segmentor if image_segmentor is not None else segmentor
        self.grounding_model = self.image_segmentor.grounding_model
        print("✓ Video segmentor ready!")

    def extract_frames(self, video_path, max_frames=30):
        """Read up to ``max_frames`` RGB frames sampled evenly across a video.

        Args:
            video_path: Path to a video file readable by OpenCV.
            max_frames: Maximum number of frames to return; must be >= 1.

        Returns:
            tuple: (list of RGB frame arrays, FPS reported by OpenCV for the
            source video).

        Raises:
            ValueError: If ``max_frames`` is smaller than 1.
        """
        if max_frames < 1:
            raise ValueError(f"max_frames must be >= 1, got {max_frames}")

        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Keep every `sample_rate`-th frame so the samples span the clip.
            sample_rate = max(1, total_frames // max_frames)

            frame_count = 0
            while len(frames) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % sample_rate == 0:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frame_count += 1
        finally:
            cap.release()

        print(f"✓ Extracted {len(frames)} frames")
        return frames, fps

    def detect_in_first_frame(self, frame, text_prompt, box_threshold=0.3):
        """Run GroundingDINO on a single RGB frame.

        The frame is written to 'temp_frame.jpg' because ``detect_objects``
        takes an image path.

        Returns:
            tuple: (XYXY pixel boxes, matched phrases).
        """
        cv2.imwrite('temp_frame.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        _, boxes, _, labels = self.image_segmentor.detect_objects(
            'temp_frame.jpg',
            text_prompt,
            box_threshold
        )

        return boxes, labels

    def segment_video(self, video_path, text_prompt, max_frames=30, box_threshold=0.3):
        """Detect ``text_prompt`` in the first sampled frame and track it.

        Args:
            video_path: Path to the input video.
            text_prompt: Free-text description of the objects to track.
            max_frames: Maximum number of sampled frames to process.
            box_threshold: GroundingDINO box-confidence threshold.

        Returns:
            tuple: (frames, video_segments, labels). ``video_segments`` maps an
            index into ``frames`` to ``{obj_id: mask_logits}``. Returns
            ``(None, None, None)`` when no frames are read or nothing is detected.
        """
        import shutil
        import tempfile

        frames, fps = self.extract_frames(video_path, max_frames)

        if len(frames) == 0:
            print("Failed to extract frames")
            return None, None, None

        boxes, labels = self.detect_in_first_frame(frames[0], text_prompt, box_threshold)

        if len(boxes) == 0:
            print(f"No objects detected for: '{text_prompt}'")
            return None, None, None

        print(f"✓ Found {len(boxes)} object(s): {labels}")

        # init_state() reads a directory of JPEG frames named by integer index;
        # older SAM 2 releases accept nothing else. Even where a video file is
        # accepted, every original frame would be decoded, so the mask indices
        # would not line up with the subsampled `frames`. Writing exactly the
        # sampled frames keeps the indices aligned and bounds memory.
        frame_dir = tempfile.mkdtemp(prefix="sam2_frames_")
        try:
            for i, frame in enumerate(frames):
                cv2.imwrite(os.path.join(frame_dir, f"{i:05d}.jpg"),
                            cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            with torch.inference_mode():
                inference_state = self.video_predictor.init_state(video_path=frame_dir)

                # Add detected objects as box prompts on the first frame.
                for obj_id, box in enumerate(boxes):
                    self.video_predictor.add_new_points_or_box(
                        inference_state=inference_state,
                        frame_idx=0,
                        obj_id=obj_id,
                        box=box
                    )

                # Propagate through the sampled frames.
                video_segments = {}
                for out_frame_idx, out_obj_ids, out_mask_logits in self.video_predictor.propagate_in_video(inference_state):
                    video_segments[out_frame_idx] = {
                        out_obj_id: out_mask_logits[i].cpu().numpy()
                        for i, out_obj_id in enumerate(out_obj_ids)
                    }
        finally:
            shutil.rmtree(frame_dir, ignore_errors=True)

        return frames, video_segments, labels

    def visualize_video_results(self, frames, segments, labels, text_prompt, sample_frames=5):
        """Plot evenly spaced frames (top row) above their mask overlays (bottom row).

        Args:
            frames: Sequence of RGB frames.
            segments: ``{frame_idx: {obj_id: mask_logits}}`` from ``segment_video``.
            labels: Detected phrases (not drawn; kept for interface compatibility).
            text_prompt: Prompt shown in the figure title.
            sample_frames: Number of columns; clamped to ``[1, len(frames)]``.
        """
        if frames is None or len(frames) == 0:
            return

        total_frames = len(frames)
        sample_frames = max(1, min(sample_frames, total_frames))
        indices = np.linspace(0, total_frames - 1, sample_frames, dtype=int)

        # squeeze=False keeps `axes` 2-D even when only one column is drawn.
        fig, axes = plt.subplots(2, sample_frames, figsize=(20, 8), squeeze=False)

        for col, idx in enumerate(indices):
            # Original frame
            axes[0, col].imshow(frames[idx])
            axes[0, col].set_title(f"Frame {idx}")
            axes[0, col].axis('off')

            # Segmented frame
            axes[1, col].imshow(frames[idx])
            for obj_id, mask in segments.get(idx, {}).items():
                # Logits > 0 are foreground; colour by object id so tracks stay
                # visually consistent across columns.
                rgb = plt.cm.tab10(int(obj_id) % 10)[:3]
                show_mask((mask > 0.0).squeeze(), axes[1, col], color=[*rgb, 0.6])

            axes[1, col].set_title("Segmented")
            axes[1, col].axis('off')

        fig.suptitle(f"Video Segmentation: '{text_prompt}'", fontsize=16)
        plt.tight_layout()
        plt.show()

    def save_video_with_masks(self, frames, segments, output_path='output_video.mp4', fps=30):
        """Write ``frames`` to ``output_path`` with a red tint over every mask.

        Args:
            frames: Sequence of RGB frames of equal size.
            segments: ``{frame_idx: {obj_id: mask_logits}}`` from ``segment_video``.
            output_path: Destination file (encoded with the 'mp4v' FourCC).
            fps: Frame rate written into the output file.
        """
        if not frames:
            return

        h, w = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        red = np.zeros_like(frames[0])
        red[:, :, 0] = 255  # frames are RGB, so channel 0 is red

        try:
            for idx, frame in enumerate(frames):
                result = frame.copy()
                for obj_id, mask in segments.get(idx, {}).items():
                    region = (mask > 0.0).squeeze()
                    blended = cv2.addWeighted(result, 0.7, red, 0.3, 0)
                    # Tint only masked pixels; blending the full frame would
                    # darken the background by 30% per tracked object.
                    result[region] = blended[region]

                out.write(cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
        finally:
            out.release()

        print(f"✓ Video saved as '{output_path}'")

# Initialize video segmentor
video_seg = VideoSegmentor()

In [None]:
# ============================================
# CELL 8: Video Segmentation Demo
# ============================================
# Download sample video
video_path = download_sample_video()

# Segment video with text prompt
text_prompt = "rabbit"
print(f"\n🎥 Processing video with prompt: '{text_prompt}'")

video_result = video_seg.segment_video(
    video_path,
    text_prompt,
    max_frames=20  # Limit frames for Colab memory
)
# segment_video() may signal failure with a bare None, which cannot be
# unpacked into three names; normalise it before unpacking.
frames, segments, labels = video_result if video_result is not None else (None, None, None)

if frames is not None:
    # Visualize results
    video_seg.visualize_video_results(
        frames, segments, labels, text_prompt,
        sample_frames=5
    )

    # Save output video
    video_seg.save_video_with_masks(
        frames, segments,
        output_path='segmented_video.mp4',
        fps=15
    )

In [None]:
 ============================================
# CELL 9: Interactive Demo Functions
# ============================================
def interactive_image_segmentation():
    """Interactive function for custom prompts"""
    print("=== Interactive Image Segmentation ===")
    
    # Option to upload or use sample
    use_sample = input("Use sample image? (y/n): ").lower() == 'y'
    
    if use_sample:
        image_path = download_sample_image()
    else:
        print("Upload your image:")
        uploaded = files.upload()
        image_path = list(uploaded.keys())[0]
    
    while True:
        text_prompt = input("\nEnter text prompt (or 'quit' to exit): ")
        if text_prompt.lower() == 'quit':
            break
        
        image, boxes, labels, masks = segmentor.segment_image(
            image_path,
            text_prompt,
            box_threshold=0.25
        )
        
        segmentor.visualize_results(image, boxes, labels, masks, text_prompt)

def interactive_video_segmentation():
    """Interactively segment a sample or uploaded video from one text prompt.

    Asks for the prompt and a frame budget, plots sampled results, and can
    optionally save an overlay video to 'interactive_output.mp4'.
    """
    print("=== Interactive Video Segmentation ===")

    # Option to upload or use sample
    use_sample = input("Use sample video? (y/n): ").lower() == 'y'

    if use_sample:
        video_path = download_sample_video()
    else:
        print("Upload your video:")
        uploaded = files.upload()
        if not uploaded:
            # Cancelled picker: indexing the empty result would raise IndexError.
            print("No file uploaded.")
            return
        video_path = next(iter(uploaded))

    text_prompt = input("Enter text prompt for video segmentation: ")

    # Accept only positive integers; anything else falls back to the default
    # instead of crashing int() or dividing by zero during frame sampling.
    raw_max = input("Max frames to process (default 20): ").strip()
    if raw_max.isdigit() and int(raw_max) > 0:
        max_frames = int(raw_max)
    else:
        if raw_max:
            print(f"Invalid frame count '{raw_max}'; using the default of 20.")
        max_frames = 20

    print("\nProcessing video...")
    video_result = video_seg.segment_video(
        video_path,
        text_prompt,
        max_frames=max_frames
    )
    # segment_video() may signal failure with a bare None; normalise it.
    frames, segments, labels = video_result if video_result is not None else (None, None, None)

    if frames is not None:
        video_seg.visualize_video_results(
            frames, segments, labels, text_prompt,
            sample_frames=min(5, len(frames))
        )

        save_output = input("Save output video? (y/n): ").lower() == 'y'
        if save_output:
            video_seg.save_video_with_masks(
                frames, segments,
                output_path='interactive_output.mp4'
            )

# Uncomment to run interactive demos:
# interactive_image_segmentation()
# interactive_video_segmentation()

In [None]:
# ============================================
# CELL 10: Quick Test Suite
# ============================================
# Print a short usage summary; each entry below is emitted as one output line.
banner = "=" * 50
summary_lines = [
    "\n" + banner,
    "🎉 Setup Complete! SAM 2 Text-Prompted Segmentation Ready",
    banner,
    "\nQuick Start Options:",
    "1. Run Cell 5 for automatic image segmentation demo",
    "2. Run Cell 8 for automatic video segmentation demo",
    "3. Uncomment functions in Cell 9 for interactive mode",
    "\nExample usage:",
    "  image, boxes, labels, masks = segmentor.segment_image('image.jpg', 'cat')",
    "  frames, segments, labels = video_seg.segment_video('video.mp4', 'person')",
]
print("\n".join(summary_lines))