# SAM2 Baseline Pipeline Demonstration

This notebook demonstrates the SAM2 (Segment Anything Model 2) baseline pipeline for automatic image segmentation of people and vehicles using text prompts.

## Project Overview
- **Goal**: Segment people and vehicles in CamVid dataset images
- **Baseline Method**: Text prompt-based segmentation with SAM2
- **Evaluation Metric**: Dice coefficient
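For a predicted binary mask $P$ and a ground-truth mask $G$, the Dice coefficient is $\mathrm{Dice}(P, G) = \dfrac{2\,|P \cap G|}{|P| + |G|}$. It ranges from 0 (no overlap) to 1 (identical masks). The following is a minimal NumPy sketch. `dice_score` is a hypothetical name. Scoring two empty masks as 1.0 is an assumed convention and may differ from what `src.utils.dice_coefficient` does.

```python
import numpy as np


def dice_score(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice = 2|P ∩ G| / (|P| + |G|) for binary masks (any nonzero pixel is foreground).

    Hypothetical helper; the empty-vs-empty convention below is an assumption.
    """
    pred = np.asarray(pred) > 0
    target = np.asarray(target) > 0
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        # Both masks empty: count as perfect agreement (assumed convention).
        return 1.0
    return float(2.0 * intersection / total)


# Two 4-pixel squares that share 2 pixels: 2 * 2 / (4 + 4) = 0.5
a = np.zeros((4, 4), dtype=np.uint8)
a[:2, :2] = 1
b = np.zeros((4, 4), dtype=np.uint8)
b[1:3, :2] = 1
print(dice_score(a, b))  # 0.5
```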

---

## 1. Install and Import Dependencies

First, we import the required libraries. We also add the project root and the project's virtual-environment `site-packages` directory to the Python path.

In [None]:
import sys
import os
from pathlib import Path

# Add project root to path
project_root = Path("/anvil/projects/x-soc250046/x-sishraq/SegmentAnythingModel")
sys.path.append(str(project_root))
# Make packages installed in the project's virtual environment importable
sys.path.append(str(project_root / "sam_env" / "lib" / "python3.12" / "site-packages"))

# Core libraries
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import json
from tqdm import tqdm

# SAM2 imports
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    print("✓ SAM2 imported successfully!")
except ImportError as e:
    print(f"✗ SAM2 import failed: {e}")

# Project modules
from src.baseline import SAM2BaselineSegmenter
from src.utils import dice_coefficient, load_image, load_mask, visualize_results
from src.evaluation import SegmentationEvaluator

print("✓ Project modules imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")

## 2. Download SAM2 Model Checkpoints

The SAM2 model checkpoints should already be downloaded during setup. Let's verify they exist.

In [None]:
# Check if checkpoint exists
checkpoint_path = project_root / "checkpoints" / "sam2_hiera_large.pt"

if checkpoint_path.exists():
    print(f"✓ SAM2 checkpoint found at: {checkpoint_path}")
    print(f"File size: {checkpoint_path.stat().st_size / (1024**3):.2f} GB")
else:
    print(f"✗ Checkpoint not found at: {checkpoint_path}")
    print("Please download it from: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt")

# List available config files (searched recursively in case configs live in subdirectories)
config_dir = project_root / "sam_env" / "lib" / "python3.12" / "site-packages" / "sam2"
config_files = sorted(config_dir.rglob("*.yaml"))
print("\nAvailable SAM2 configs:")
for config in config_files:
    print(f"  - {config.relative_to(config_dir)}")
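If the checkpoint is missing, you can fetch it from the URL printed above. Below is a minimal sketch that uses only the standard library. `ensure_checkpoint` is a hypothetical helper and is not part of the project's `src` package. It skips the download when a non-empty file already exists.

```python
import urllib.request
from pathlib import Path

# URL copied from the message printed above (SAM2 Hiera-Large checkpoint).
SAM2_LARGE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"


def ensure_checkpoint(path: Path, url: str = SAM2_LARGE_URL) -> Path:
    """Download `url` to `path` only if the file is not already present (hypothetical helper)."""
    path = Path(path)
    if path.exists() and path.stat().st_size > 0:
        return path  # already downloaded; nothing to do
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / (path.name + ".part")
    # Download to a temporary file first so an interrupted download never looks complete.
    urllib.request.urlretrieve(url, tmp_path)
    tmp_path.replace(path)
    return path
```

In this notebook it would be called as `ensure_checkpoint(checkpoint_path)`.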

## 3. Initialize SAM2 Model

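Next, we initialize the baseline SAM2 segmenter with the large Hiera config (`sam2_hiera_l.yaml`). This config corresponds to the `sam2_hiera_large.pt` checkpoint checked above.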

In [None]:
# Initialize the baseline segmenter
print("Initializing SAM2 baseline segmenter...")
segmenter = SAM2BaselineSegmenter(
    model_cfg="sam2_hiera_l.yaml",
    device="cuda" if torch.cuda.is_available() else "cpu"
)

print("✓ SAM2 baseline segmenter initialized successfully!")
print(f"Device: {segmenter.device}")
print(f"Model config: {segmenter.model_cfg}")
print("Text prompts configured for:")
for class_name, prompts in segmenter.text_prompts.items():
    print(f"  {class_name}: {prompts}")
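For reference, the public `sam2` package builds an image predictor from a config name and a checkpoint with `build_sam2` and `SAM2ImagePredictor`. This notebook does not show how `SAM2BaselineSegmenter` works internally. The sketch below assumes it does something similar, and `build_sam2_predictor` is a hypothetical wrapper.

```python
from typing import Optional


def build_sam2_predictor(model_cfg: str, checkpoint: str, device: Optional[str] = None):
    """Build a SAM2ImagePredictor from a config name and checkpoint path.

    Assumption: SAM2BaselineSegmenter does something similar internally.
    """
    # Lazy imports: defining this function does not require sam2/torch to be installed.
    import torch
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # build_sam2 resolves the config name relative to the installed sam2 package
    sam2_model = build_sam2(model_cfg, checkpoint, device=device)
    return SAM2ImagePredictor(sam2_model)
```

Here it would be called as `build_sam2_predictor("sam2_hiera_l.yaml", str(checkpoint_path))`. Newer sam2 releases may expect the config's path inside the package instead of a bare file name. Check the names listed in section 2.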

## 4. Load and Prepare Input Image

Next, we load a sample image from the CamVid validation split to test the segmentation pipeline. If the dataset is not available, we fall back to a synthetic image.

In [None]:
# Check if dataset is available
dataset_path = project_root / "data" / "camvid"
demo_image_path = None

if dataset_path.exists():
    # Look for sample images in the dataset
    val_images_dir = dataset_path / "val" / "images"
    if val_images_dir.exists():
        # Sorted so the same demo image is picked on every run
        image_files = sorted(list(val_images_dir.glob("*.png")) + list(val_images_dir.glob("*.jpg")))
        if image_files:
            demo_image_path = image_files[0]
            print(f"Using dataset image: {demo_image_path.name}")
        else:
            print("No images found in val/images directory")
    else:
        print("val/images directory not found")
else:
    print("Dataset not found - will create a synthetic demo image")

# Load or create demo image
if demo_image_path and demo_image_path.exists():
    # Load real image from dataset
    demo_image = load_image(demo_image_path)
    print(f"Loaded image shape: {demo_image.shape}")
else:
    # Create a synthetic demo image
    print("Creating synthetic demo image...")
    demo_image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # upper bound is exclusive
    # Add simple filled rectangles as stand-ins for objects (colors are in RGB order)
    cv2.rectangle(demo_image, (100, 200), (200, 400), (0, 255, 0), -1)  # Green rectangle (person)
    cv2.rectangle(demo_image, (300, 300), (500, 450), (0, 0, 255), -1)  # Blue rectangle (vehicle)
    print(f"Created synthetic image shape: {demo_image.shape}")

# Display the demo image
plt.figure(figsize=(10, 6))
plt.imshow(demo_image)
plt.title("Demo Image for Segmentation")
plt.axis('off')
plt.show()
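`SAM2ImagePredictor.set_image` expects an RGB image: either an HWC `np.ndarray` with values in [0, 255] or a PIL image. The project's `load_image` is not shown here. As a sketch under that assumption, a loader that guarantees this format could look like the one below; `load_rgb` is a hypothetical name. If you load images with OpenCV instead, note that `cv2.imread` returns BGR. Convert with `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` before passing the image to the predictor.

```python
import numpy as np
from PIL import Image


def load_rgb(path) -> np.ndarray:
    """Load an image as an HxWx3 uint8 RGB array (hypothetical helper).

    convert("RGB") drops any alpha channel and expands grayscale or palette images.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"), dtype=np.uint8)
```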

## 5. Create Image Predictor

Now let's set up the SAM2 predictor with our demo image.

In [None]:
# Set the image in the predictor
if segmenter.predictor is not None:
    print("Setting image in SAM2 predictor...")
    segmenter.predictor.set_image(demo_image)
    print("✓ Image set successfully in predictor")
    
    # set_image() computes and caches the image embedding; report the input size it used
    print(f"Image shape processed: {demo_image.shape}")
    print("Predictor ready for inference")
else:
    print("✗ SAM2 predictor not available")
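After `set_image` caches the embedding, `SAM2ImagePredictor.predict` can be called with point prompts. With `multimask_output=True`, it returns several candidate masks, their predicted quality scores, and low-resolution logits. The sketch below keeps the highest-scoring candidate. `segment_from_points` is a hypothetical helper, and whether `segment_image` does this internally is not shown in this notebook.

```python
import numpy as np


def segment_from_points(predictor, image: np.ndarray, points_xy, labels=None) -> np.ndarray:
    """Run one point-prompted SAM2 prediction and keep the highest-scoring mask.

    `predictor` is assumed to expose SAM2ImagePredictor's set_image/predict interface.
    Points are (x, y) pixel coordinates.
    """
    points_xy = np.asarray(points_xy, dtype=np.float32).reshape(-1, 2)
    if labels is None:
        labels = np.ones(len(points_xy), dtype=np.int32)  # 1 = foreground, 0 = background
    predictor.set_image(image)  # computes the image embedding once
    masks, scores, _ = predictor.predict(
        point_coords=points_xy,
        point_labels=np.asarray(labels, dtype=np.int32),
        multimask_output=True,  # several candidates for an ambiguous prompt
    )
    best = int(np.argmax(scores))
    return masks[best] > 0
```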

## 6. Generate Segmentation Masks

Let's test segmentation for both people and vehicles using our baseline text prompt method.

In [None]:
# Test segmentation for different classes
target_classes = ['person', 'vehicle']

print("Running baseline segmentation...")
segmentation_results = segmenter.segment_image(demo_image, target_classes)

print("\nSegmentation Results:")
for class_name, mask in segmentation_results.items():
    mask_area = np.count_nonzero(mask) / mask.size * 100  # works for bool, 0/1 and 0/255 masks
    print(f"{class_name}:")
    print(f"  Mask shape: {mask.shape}")
    print(f"  Mask area: {mask_area:.2f}% of image")
    print(f"  Unique values: {np.unique(mask)}")
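Area percentages should count non-zero pixels rather than sum pixel values. Summing a 0/255 mask inflates the area 255-fold, and adding per-class areas double-counts pixels claimed by both classes. Two small helpers with hypothetical names avoid both pitfalls:

```python
import numpy as np


def mask_area_percent(mask: np.ndarray) -> float:
    """Percentage of foreground pixels, for bool, 0/1 or 0/255 masks."""
    mask = np.asarray(mask)
    return float(np.count_nonzero(mask) / mask.size * 100.0)


def union_area_percent(*masks: np.ndarray) -> float:
    """Coverage of the union of several masks, so overlapping pixels count once."""
    union = np.zeros(np.asarray(masks[0]).shape, dtype=bool)
    for m in masks:
        union |= np.asarray(m) > 0
    return float(union.mean() * 100.0)
```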

## 7. Visualize Results

Let's create comprehensive visualizations of our segmentation results.

In [None]:
# Create detailed visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Original image
axes[0, 0].imshow(demo_image)
axes[0, 0].set_title("Original Image", fontsize=14)
axes[0, 0].axis('off')

# Person mask
person_mask = segmentation_results.get('person', np.zeros_like(demo_image[:,:,0]))
axes[0, 1].imshow(person_mask, cmap='gray')
axes[0, 1].set_title("Person Mask (Baseline)", fontsize=14)
axes[0, 1].axis('off')

# Vehicle mask
vehicle_mask = segmentation_results.get('vehicle', np.zeros_like(demo_image[:,:,0]))
axes[0, 2].imshow(vehicle_mask, cmap='gray')
axes[0, 2].set_title("Vehicle Mask (Baseline)", fontsize=14)
axes[0, 2].axis('off')

# Combined mask (pixels claimed by both classes get the value 3)
combined_mask = (person_mask > 0).astype(np.float32) + 2 * (vehicle_mask > 0).astype(np.float32)
axes[1, 0].imshow(combined_mask, cmap='viridis')
axes[1, 0].set_title("Combined Masks\n(Person=1, Vehicle=2, Both=3)", fontsize=14)
axes[1, 0].axis('off')

# Overlay on original
overlay = demo_image.copy()
overlay[person_mask > 0] = [255, 0, 0]  # Red for person
overlay[vehicle_mask > 0] = [0, 0, 255]  # Blue for vehicle
axes[1, 1].imshow(overlay)
axes[1, 1].set_title("Overlay\n(Person=Red, Vehicle=Blue)", fontsize=14)
axes[1, 1].axis('off')

# Mask statistics (count non-zero pixels so bool, 0/1 and 0/255 masks all work)
n_pixels = person_mask.shape[0] * person_mask.shape[1]
person_px = np.count_nonzero(person_mask)
vehicle_px = np.count_nonzero(vehicle_mask)
union_px = np.count_nonzero((person_mask > 0) | (vehicle_mask > 0))
stats_text = f"""Segmentation Statistics:

Person Mask:
• Area: {person_px / n_pixels * 100:.2f}%
• Non-zero pixels: {person_px}

Vehicle Mask:
• Area: {vehicle_px / n_pixels * 100:.2f}%
• Non-zero pixels: {vehicle_px}

Total Segmented (union):
• {union_px / n_pixels * 100:.2f}% of image"""

axes[1, 2].text(0.05, 0.95, stats_text, transform=axes[1, 2].transAxes, 
                fontsize=10, verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
axes[1, 2].set_xlim(0, 1)
axes[1, 2].set_ylim(0, 1)
axes[1, 2].axis('off')

plt.tight_layout()
plt.suptitle("SAM2 Baseline Segmentation Results", fontsize=16, y=1.02)
plt.show()
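The overlay above replaces masked pixels with a solid color, which hides the underlying image. An alternative is to alpha-blend the color so that object boundaries and texture stay visible. The sketch below uses a hypothetical `blend_overlay` helper:

```python
import numpy as np


def blend_overlay(image: np.ndarray, mask: np.ndarray, color=(255, 0, 0), alpha: float = 0.5) -> np.ndarray:
    """Tint masked pixels with `color` (RGB) at opacity `alpha`; other pixels are unchanged."""
    out = image.astype(np.float32)  # astype returns a copy, so `image` is not modified
    sel = np.asarray(mask) > 0
    out[sel] = (1.0 - alpha) * out[sel] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(out, 0, 255).astype(np.uint8)
```

For example, `blend_overlay(blend_overlay(demo_image, person_mask, (255, 0, 0)), vehicle_mask, (0, 0, 255))` tints people red and vehicles blue at 50% opacity.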

## 8. Batch Processing Multiple Images

Next, we run the pipeline over several images. When the dataset is available, we also evaluate the results with Dice scores.

In [None]:
# Setup evaluation if dataset is available
if dataset_path.exists():
    print("Setting up evaluation on CamVid dataset...")
    
    # Initialize evaluator
    evaluator = SegmentationEvaluator(str(dataset_path))
    
    # Define segmentation function for evaluator
    def baseline_segmentation_function(image):
        """Wrapper function for evaluation"""
        return segmenter.segment_image(image, target_classes=['person', 'vehicle'])
    
    # Run evaluation on a small subset for demo
    print("Running baseline evaluation on 5 images...")
    results = evaluator.evaluate_method(
        baseline_segmentation_function,
        method_name="SAM2_Baseline_TextPrompt",
        target_classes=['person', 'vehicle'],
        max_images=5,  # Limit for demo
        save_visualizations=True,
        viz_dir=str(project_root / "results" / "baseline" / "demo_viz")
    )
    
    # Display results
    print("\n" + "="*60)
    print("BASELINE EVALUATION RESULTS (5 images)")
    print("="*60)
    
    for class_name in ['person', 'vehicle', 'overall']:
        if class_name in results:
            metrics = results[class_name]
            print(f"\n{class_name.upper()}:")
            print(f"  Mean Dice Score: {metrics['mean_dice']:.4f} ± {metrics['std_dice']:.4f}")
            print(f"  Median Dice Score: {metrics['median_dice']:.4f}")
            print(f"  Range: [{metrics['min_dice']:.4f}, {metrics['max_dice']:.4f}]")
            print(f"  Images processed: {metrics['num_images']}")
    
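    # Save results (create the output directory first in case it does not exist yet)
    results_file = project_root / "results" / "baseline" / "demo_results.json"
    results_file.parent.mkdir(parents=True, exist_ok=True)
    evaluator.save_results(results, str(results_file))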
    
else:
    print("Dataset not available - creating synthetic batch processing demo...")
    
    # Create multiple synthetic images
    demo_images = []
    for i in range(3):
        # Create varied synthetic images
        img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # upper bound is exclusive
        # Add rectangles at shifted positions (RGB colors: green = person, blue = vehicle)
        cv2.rectangle(img, (50+i*50, 100+i*30), (150+i*50, 300+i*30), (0, 255, 0), -1)
        cv2.rectangle(img, (200+i*80, 250), (400+i*80, 400), (0, 0, 255), -1)
        demo_images.append(img)
    
    print(f"Processing {len(demo_images)} synthetic images...")
    
    # Process each image
    batch_results = []
    for i, img in enumerate(demo_images):
        print(f"Processing image {i+1}/{len(demo_images)}...")
        result = segmenter.segment_image(img, target_classes=['person', 'vehicle'])
        batch_results.append(result)
    
    # Display batch processing summary
    print("\nBatch Processing Summary:")
    for i, result in enumerate(batch_results):
        person_area = np.count_nonzero(result['person']) / result['person'].size * 100
        vehicle_area = np.count_nonzero(result['vehicle']) / result['vehicle'].size * 100
        print(f"Image {i+1}: Person={person_area:.2f}%, Vehicle={vehicle_area:.2f}%")
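The evaluator's summary fields (`mean_dice`, `std_dice`, `median_dice`, `min_dice`, `max_dice`, `num_images`) can be reproduced from a list of per-image Dice scores. The following is only a sketch, because `SegmentationEvaluator`'s implementation is not shown. In particular, whether it uses the population or sample standard deviation is an assumption. `summarize_dice` is a hypothetical name.

```python
import numpy as np


def summarize_dice(scores) -> dict:
    """Aggregate per-image Dice scores into summary statistics."""
    s = np.asarray(scores, dtype=np.float64)
    if s.size == 0:
        raise ValueError("no Dice scores to summarize")
    return {
        "mean_dice": float(s.mean()),
        "std_dice": float(s.std()),  # population std (ddof=0); the evaluator may differ
        "median_dice": float(np.median(s)),
        "min_dice": float(s.min()),
        "max_dice": float(s.max()),
        "num_images": int(s.size),
    }
```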

## Summary and Next Steps

This notebook demonstrated the SAM2 baseline pipeline for automatic segmentation of people and vehicles. 

### Key Features Demonstrated:
1. ✅ SAM2 model initialization and configuration
2. ✅ Text prompt-based segmentation (baseline method)
3. ✅ Comprehensive visualization of results
4. ✅ Batch processing capabilities
5. ✅ Evaluation framework with Dice coefficient

### Baseline Method Characteristics:
- **Approach**: Uses predefined text prompts to generate point coordinates
- **Limitations**: Point prompts come from a simple grid, and no vision-language model is used, so the text prompts do not actually ground the points on people or vehicles
- **Evaluation**: Measured using Dice coefficient against ground truth
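One plausible form of grid-based point generation is sketched below. The exact grid used by `src/baseline.py` is not shown, so the centered-cell layout and the `grid_point_prompts` name are assumptions. The output uses the `(x, y)` pixel order and 1-for-foreground labels that `SAM2ImagePredictor.predict` takes as `point_coords` and `point_labels`. Nothing in such a grid depends on the class name, which is the limitation noted above.

```python
import numpy as np


def grid_point_prompts(height: int, width: int, rows: int = 3, cols: int = 3):
    """Evenly spaced foreground point prompts in (x, y) pixel order.

    Points sit at the centers of a rows x cols grid of cells, so none lie on the border.
    """
    ys = (np.arange(rows) + 0.5) * height / rows
    xs = (np.arange(cols) + 0.5) * width / cols
    grid_x, grid_y = np.meshgrid(xs, ys)  # each has shape (rows, cols)
    coords = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
    labels = np.ones(len(coords), dtype=np.int32)  # 1 marks a foreground point
    return coords, labels
```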

### Next Steps:
1. **Download CamVid Dataset**: Run `python src/data_preparation.py`
2. **Full Baseline Evaluation**: Run `python src/baseline.py`
3. **Develop Novel Method**: Create an improved automatic prompting strategy
4. **Compare Methods**: Evaluate novel method against baseline

### Novel Method Ideas:
- Object detection + SAM2 (YOLO → bounding boxes → SAM2)
- Semantic segmentation + SAM2 refinement
- Attention-based smart point generation
- Multi-scale prompting strategies
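The first idea, using detector boxes as SAM2 prompts, could be wired up roughly as shown below. This is a sketch with two assumptions. First, the boxes come from an object detector (for example YOLO, not shown) in XYXY pixel coordinates. Second, `predictor` exposes `SAM2ImagePredictor`'s `set_image`/`predict` interface, including its `box` argument. `masks_from_boxes` is a hypothetical helper.

```python
import numpy as np


def masks_from_boxes(predictor, image: np.ndarray, boxes_xyxy) -> np.ndarray:
    """Union of SAM2 masks prompted by detector boxes (XYXY pixel coordinates)."""
    predictor.set_image(image)  # embed the image once, then prompt it per box
    union = np.zeros(image.shape[:2], dtype=bool)
    for box in np.asarray(boxes_xyxy, dtype=np.float32).reshape(-1, 4):
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
        union |= masks[0] > 0  # a single mask per box when multimask_output=False
    return union
```

Running it once per class (boxes labeled "person", then boxes labeled as vehicles) would yield per-class masks in the same form that `segment_image` returns.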

**Goal**: Beat the baseline Dice score with a fully automatic method! 🎯