## Section 1: Environment Setup

In [None]:
# Detect environment and configure paths
import sys
import os
from pathlib import Path

# Detect Google Colab: the `google.colab` package is only importable inside a
# Colab runtime, so a failed import means the notebook runs somewhere else.
try:
    import google.colab  # noqa: F401 -- imported only as a presence probe
except ImportError:
    # Catch ImportError only: a bare `except:` would also trap
    # KeyboardInterrupt/SystemExit and hide unrelated failures.
    IN_COLAB = False
    print("✓ Running locally")
else:
    IN_COLAB = True
    print("✓ Running on Google Colab")

# Set up paths
if IN_COLAB:
    # On Colab the data directory lives on the mounted Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT = '/content/AMLProject'
    DATA_ROOT = '/content/drive/MyDrive/AMLProject/data'
else:
    # Locally everything is resolved relative to the current working directory
    PROJECT_ROOT = os.getcwd()
    DATA_ROOT = os.path.join(PROJECT_ROOT, 'data')

# Create necessary directories
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints', 'sam')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs', 'sam')
MODEL_DIR = os.path.join(PROJECT_ROOT, 'models')

# exist_ok=True makes re-running this cell harmless
for directory in [CHECKPOINT_DIR, OUTPUT_DIR, MODEL_DIR, DATA_ROOT]:
    os.makedirs(directory, exist_ok=True)

print(f"\nProject root: {PROJECT_ROOT}")
print(f"Data root: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Install dependencies (robust, platform-aware with clear errors)
import subprocess, sys, platform
print("Installing required packages...")

# Errors `subprocess.check_call` raises for a failed pip run: a non-zero exit
# status (SubprocessError family) or an interpreter that cannot be launched
# (OSError).
_PIP_ERRORS = (subprocess.SubprocessError, OSError)

# Names of the install steps that failed, reported in the final summary
_failed_installs = []


def _pip_install(*pip_args):
    """Run ``<current interpreter> -m pip install <pip_args>``; raise on failure."""
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', *pip_args])


# Install Segment Anything first (from git)
SAM_GIT_URL = 'git+https://github.com/facebookresearch/segment-anything.git'
try:
    print("Installing Segment Anything (segment-anything)...")
    _pip_install(SAM_GIT_URL)
    print("✓ segment-anything installed")
except _PIP_ERRORS as e:
    _failed_installs.append('segment-anything')
    print(f"✗ segment-anything installation failed: {e}")
    print("You can install it manually with:")
    print(f"  pip install {SAM_GIT_URL}")

# Common Python packages (install separately to isolate failures)
common_packages = [
    'numpy',
    'matplotlib',
    'opencv-python',
    'pillow',
    'scipy',
    'tqdm',
    'pandas',
    'scikit-learn'
]
try:
    print("Installing common Python packages...")
    _pip_install('--upgrade', *common_packages)
    print("✓ Common packages installed")
except _PIP_ERRORS as e:
    _failed_installs.append('common packages')
    print(f"✗ Common package installation failed: {e}")
    print("You can install them manually, e.g.:")
    print(f"  pip install {' '.join(common_packages)}")

torch_packages = ['torch', 'torchvision', 'torchaudio']
print("Installing PyTorch (torch, torchvision, torchaudio)...")
try:
    if platform.system() == 'Darwin':
        # macOS: pip usually installs the correct (CPU/MPS) wheel or user can use conda
        _pip_install('--upgrade', *torch_packages)
    else:
        # Linux/Windows: prefer official CUDA wheel index (adjust if you need a different CUDA version)
        _pip_install('--index-url', 'https://download.pytorch.org/whl/cu121', *torch_packages)
    print("✓ PyTorch packages installed")
except _PIP_ERRORS as e:
    print(f"✗ PyTorch installation failed: {e}")
    print("Attempting CPU-only PyTorch installation as a fallback...")
    try:
        _pip_install('--index-url', 'https://download.pytorch.org/whl/cpu', *torch_packages)
        print("✓ CPU-only PyTorch installed")
    except _PIP_ERRORS as e2:
        _failed_installs.append('PyTorch')
        print(f"✗ CPU-only PyTorch installation failed: {e2}")
        print("Please follow the official instructions at https://pytorch.org/get-started/locally/ to install a compatible wheel for your system.")

# scikit-learn is installed once more on its own, so it is still attempted
# even when the batched install of the common packages above failed.
try:
    _pip_install("--upgrade", "scikit-learn")
    print("✓ scikit-learn installed (or already up-to-date)")
except _PIP_ERRORS as e:
    _failed_installs.append('scikit-learn')
    print("✗ pip install failed:", e)

if _failed_installs:
    print(f"\nInstallation step finished with failures: {', '.join(_failed_installs)}.")
    print("See the messages above and install the failed packages manually.")
else:
    print("\nInstallation step finished.")

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import json
import pandas as pd
from pathlib import Path
from sklearn.neighbors import NearestNeighbors

# Segment Anything is a hard requirement: print a readable hint on failure and
# then let the original ImportError propagate.
try:
    from segment_anything import sam_model_registry, SamPredictor
    from segment_anything.utils.transforms import ResizeLongestSide
except ImportError as e:
    print(f"✗ Error importing SAM: {e}")
    print("  Please ensure segment-anything is installed")
    raise
else:
    print("✓ SAM imported successfully")

# Default figure geometry for the plots in this notebook
plt.rcParams.update({'figure.figsize': (12, 6), 'figure.dpi': 100})

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    _device_name = 'cuda'
    _device_msg = f"✓ Using CUDA GPU: {torch.cuda.get_device_name(0)}"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    _device_name = 'mps'
    _device_msg = "✓ Using Apple Silicon GPU (MPS)"
else:
    _device_name = 'cpu'
    _device_msg = "✓ Using CPU"

device = torch.device(_device_name)
print(_device_msg)

print(f"PyTorch version: {torch.__version__}")

## Section 2: Download and Load SAM Model

In [None]:
# Download SAM checkpoint
import urllib.request

SAM_CHECKPOINT_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
SAM_CHECKPOINT_NAME = 'sam_vit_b_01ec64.pth'
sam_checkpoint_path = os.path.join(CHECKPOINT_DIR, SAM_CHECKPOINT_NAME)

print("Checking SAM checkpoint...")
if os.path.exists(sam_checkpoint_path):
    print(f"✓ Checkpoint already exists: {sam_checkpoint_path}")
else:
    print(f"Downloading SAM ViT-B checkpoint...")
    print(f"URL: {SAM_CHECKPOINT_URL}")
    print(f"This may take a few minutes (~375 MB)...")

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Download to a sibling ".part" file and rename it only after the transfer
    # completed. Writing directly to the final path would leave a truncated
    # file behind on an interrupted download, and the existence check above
    # would then treat it as a valid checkpoint on the next run.
    partial_path = sam_checkpoint_path + '.part'
    try:
        urllib.request.urlretrieve(SAM_CHECKPOINT_URL, partial_path)
        os.replace(partial_path, sam_checkpoint_path)
        print(f"✓ Downloaded successfully to: {sam_checkpoint_path}")
    except Exception as e:
        # Drop the incomplete file, explain the manual route, then re-raise so
        # the notebook stops instead of continuing without a checkpoint.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        print(f"✗ Download failed: {e}")
        print("\nPlease download manually:")
        print(f"  URL: {SAM_CHECKPOINT_URL}")
        print(f"  Save to: {sam_checkpoint_path}")
        raise

In [None]:
# Load SAM model
print("Loading SAM ViT-B model...")

try:
    # Build the ViT-B variant from the checkpoint, move it to the selected
    # device and switch it to eval mode.
    sam_model = sam_model_registry['vit_b'](checkpoint=sam_checkpoint_path).to(device)
    sam_model.eval()

    # Predictor wrapper (optional, for segmentation tasks)
    sam_predictor = SamPredictor(sam_model)

    print("\n".join([
        "✓ SAM model loaded successfully!",
        "  - Architecture: ViT-B (Base)",
        "  - Image encoder output: 64×64 feature map",
        "  - Feature dimension: 256",
        "  - Input size: 1024×1024 (longest side)",
        f"  - Device: {device}",
    ]))
except Exception as e:
    # Report the failure, then re-raise so later cells do not run without a model
    print(f"✗ Error loading SAM: {e}")
    raise

## Section 3: Dense Feature Extraction

### SAM Feature Extraction Strategy

SAM's image encoder produces high-quality dense features:

1. **Preprocessing**:
   - Resize longest side to 1024 pixels
   - Pad to square (1024×1024)
   - Normalize with ImageNet statistics

2. **Feature Extraction**:
   - Extract from image encoder (ViT backbone)
   - Output: 64×64×256 feature map
   - Features encode both semantic and spatial information

3. **Key Differences from DINO**:
   - Larger input resolution (1024 vs 224)
   - Lower feature dimensionality (256 vs 768)
   - Designed for segmentation (may be better for spatial tasks)

In [None]:
class SAMFeatureExtractor:
    """
    Extract dense spatial features from SAM's image encoder.

    An image is resized so that its longest side equals the encoder input
    size, normalized per channel, zero-padded on the right/bottom to a square
    and passed through ``model.image_encoder``. Because of the padding, only
    the top-left part of the feature map corresponds to image content when the
    image is not square; ``extract_features`` reports that region and the
    image-to-feature scale in its ``info`` dictionary.
    """

    # SAM's default per-channel RGB statistics (0-255 pixel range). Used only
    # when the wrapped model does not expose `pixel_mean` / `pixel_std`.
    _DEFAULT_PIXEL_MEAN = (123.675, 116.28, 103.53)
    _DEFAULT_PIXEL_STD = (58.395, 57.12, 57.375)

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: SAM model exposing ``image_encoder`` (with an ``img_size``
                attribute) and, optionally, ``pixel_mean`` / ``pixel_std``.
            device: Device the encoder input is moved to.
        """
        self.model = model
        self.device = device
        self.image_encoder = model.image_encoder
        self.img_size = self.image_encoder.img_size  # 1024
        self.feat_dim = 256  # SAM feature dimension

        # SAM's preprocessing transform
        self.transform = ResizeLongestSide(self.img_size)

        # Normalization statistics, kept on CPU because preprocessing runs on
        # CPU before the tensor is moved to `device`.
        self.pixel_mean = self._channel_stats(
            getattr(model, 'pixel_mean', None), self._DEFAULT_PIXEL_MEAN)
        self.pixel_std = self._channel_stats(
            getattr(model, 'pixel_std', None), self._DEFAULT_PIXEL_STD)

    @staticmethod
    def _channel_stats(values, default):
        """Return per-channel statistics as a float32 CPU tensor shaped [C, 1, 1]."""
        if values is None:
            values = default
        stats = torch.as_tensor(values, dtype=torch.float32)
        return stats.detach().cpu().reshape(-1, 1, 1)

    def _prepare_input(self, image):
        """
        Resize, normalize and pad an image for the encoder.

        Returns:
            input_tensor: [1, 3, img_size, img_size] tensor on ``self.device``
            original_size: (H, W) of the input image
            resized_size: (H, W) after the longest-side resize, before padding
        """
        # Convert to numpy (PIL images and array-likes go through np.array)
        image_np = image if isinstance(image, np.ndarray) else np.array(image)

        # The encoder expects three channels: replicate grayscale, drop alpha
        if image_np.ndim == 2:
            image_np = np.stack([image_np] * 3, axis=-1)
        elif image_np.ndim == 3 and image_np.shape[2] == 4:
            image_np = np.ascontiguousarray(image_np[:, :, :3])

        original_size = (image_np.shape[0], image_np.shape[1])

        # Apply SAM's transform (resize longest side to img_size)
        resized = self.transform.apply_image(image_np)  # HxWxC
        resized_size = (resized.shape[0], resized.shape[1])

        # Convert to tensor [C, H, W]
        input_tensor = torch.as_tensor(resized, dtype=torch.float32).permute(2, 0, 1).contiguous()

        # Normalize BEFORE padding so the padded area is zero in normalized
        # space, which is the input distribution the encoder was trained on.
        input_tensor = (input_tensor - self.pixel_mean) / self.pixel_std

        # Pad to square (img_size x img_size) on the right and bottom
        h, w = input_tensor.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        input_tensor = F.pad(input_tensor, (0, padw, 0, padh))

        # Add batch dimension [1, 3, img_size, img_size]
        input_tensor = input_tensor.unsqueeze(0)

        # Move to device and match the encoder's parameter dtype
        try:
            param_dtype = next(self.image_encoder.parameters()).dtype
        except StopIteration:
            param_dtype = torch.float32

        input_tensor = input_tensor.to(self.device).to(param_dtype)

        return input_tensor, original_size, resized_size

    def preprocess_image(self, image):
        """
        Preprocess image following SAM's requirements: resize longest side,
        normalize with the pixel mean/std, then pad to square.

        Returns:
            input_tensor: torch.Tensor shaped [1, 3, img_size, img_size] on
                device, with the encoder's parameter dtype
            original_size: (H, W) of original image
        """
        input_tensor, original_size, _ = self._prepare_input(image)
        return input_tensor, original_size

    def extract_features(self, image, normalize=True):
        """
        Extract dense feature map from image.

        Args:
            image: PIL Image or numpy array (H x W x C)
            normalize: Apply L2 normalization along the feature dimension

        Returns:
            features: [H, W, D] numpy array (64×64×256 for SAM ViT-B)
            info: Metadata dictionary with keys
                - 'original_size': (W, H) of the input image
                - 'feature_size': (W, H) of the full (padded) feature map
                - 'processed_size': (H, W) of the input image; kept under
                  this name for backward compatibility
                - 'resized_size': (W, H) after resizing, before padding
                - 'valid_feature_size': (W, H) of the top-left feature region
                  that covers image content rather than padding
                - 'scale_x', 'scale_y': factors mapping image pixel
                  coordinates to feature-map coordinates
        """
        img_tensor, (orig_h, orig_w), (resized_h, resized_w) = self._prepare_input(image)

        # Extract features from image encoder
        with torch.no_grad():
            image_embedding = self.image_encoder(img_tensor)  # [1, D, h, w]

        # Rearrange to [H, W, D]
        features = image_embedding[0].permute(1, 2, 0)

        # L2 normalize
        if normalize:
            features = F.normalize(features, p=2, dim=-1)

        features = features.cpu().numpy()

        feat_h, feat_w = features.shape[0], features.shape[1]

        # The feature map spans the padded img_size x img_size square, not the
        # image alone. The image-to-feature scale is therefore
        # (resize factor) x (feature cells per padded pixel), which is the
        # same along x and y up to resize rounding.
        scale_x = (resized_w / orig_w) * (feat_w / self.img_size)
        scale_y = (resized_h / orig_h) * (feat_h / self.img_size)

        # Number of feature cells touched by real image content per axis
        valid_w = int(np.clip(np.ceil(resized_w * feat_w / self.img_size), 1, feat_w))
        valid_h = int(np.clip(np.ceil(resized_h * feat_h / self.img_size), 1, feat_h))

        info = {
            'original_size': (orig_w, orig_h),
            'feature_size': (feat_w, feat_h),
            'processed_size': (orig_h, orig_w),
            'resized_size': (resized_w, resized_h),
            'valid_feature_size': (valid_w, valid_h),
            'scale_x': scale_x,
            'scale_y': scale_y
        }

        return features, info

    def map_coords_to_features(self, coords, info):
        """
        Map image coordinates to feature space.

        Args:
            coords: [N, 2] array-like of (x, y) pixel coordinates; a single
                (x, y) pair is also accepted
            info: Metadata dictionary providing 'scale_x' and 'scale_y'

        Returns:
            [N, 2] float array of feature-map coordinates (input is not modified)
        """
        feat_coords = np.array(coords, dtype=float)
        if feat_coords.size == 0:
            return feat_coords.reshape(0, 2)
        feat_coords = np.atleast_2d(feat_coords)
        feat_coords[:, 0] *= info['scale_x']
        feat_coords[:, 1] *= info['scale_y']
        return feat_coords

    def extract_keypoint_features(self, image, keypoints):
        """
        Extract features at specific keypoint locations.

        Args:
            image: PIL Image or numpy array
            keypoints: [N, 2] array of (x, y) coordinates

        Returns:
            kp_features: [N, D] feature vectors
        """
        features, info = self.extract_features(image, normalize=True)
        valid_w, valid_h = info['valid_feature_size']

        # Map to feature space
        feat_kps = self.map_coords_to_features(keypoints, info)
        if feat_kps.shape[0] == 0:
            return np.zeros((0, features.shape[2]), dtype=features.dtype)

        # Clip to the region backed by image content (never sample padding),
        # then round to the nearest feature cell
        feat_kps[:, 0] = np.clip(feat_kps[:, 0], 0, valid_w - 1)
        feat_kps[:, 1] = np.clip(feat_kps[:, 1], 0, valid_h - 1)
        feat_kps = np.round(feat_kps).astype(int)

        # Extract features (row index is y, column index is x)
        kp_features = features[feat_kps[:, 1], feat_kps[:, 0], :]

        return kp_features

# Build the shared extractor around the loaded SAM model
feature_extractor = SAMFeatureExtractor(model=sam_model, device=device)
print("✓ SAM feature extractor initialized")

In [None]:
# Smoke-test the extractor on a synthetic uniform-grey RGB image
# (PIL size is (width, height): 480 px wide, 640 px tall)
print("Testing SAM feature extraction...")
test_image = Image.new('RGB', (480, 640), color=(128, 128, 128))

features, info = feature_extractor.extract_features(test_image)
# Spot-check L2 normalization on the first feature cell only
unit_norm = np.allclose(np.linalg.norm(features[0, 0, :]), 1.0)
print("\n✓ Feature extraction successful!")
print(f"  Input image size: {info['original_size']}")
print(f"  Feature map size: {info['feature_size']} = {features.shape[0]}×{features.shape[1]}")
print(f"  Feature dimension: {features.shape[2]}")
print(f"  Features normalized: {unit_norm}")
print("\n  Note: SAM uses 64×64 feature grid (higher resolution than DINO's 16×16)")

# Sample features at a few hand-picked (x, y) keypoints
test_kps = np.array([[100, 150], [200, 300], [400, 500]])
kp_features = feature_extractor.extract_keypoint_features(test_image, test_kps)
print("\n✓ Keypoint feature extraction successful!")
print(f"  Number of keypoints: {len(test_kps)}")
print(f"  Feature shape: {kp_features.shape}")

## Section 4: Correspondence Matching

In [None]:
class CorrespondenceMatcher:
    """
    Match keypoints between images using dense feature similarity.

    Each source feature is assigned to the target feature-map cell with the
    highest dot-product similarity (cosine similarity when both sides are
    L2-normalized). Optional ratio and mutual-nearest-neighbour tests reject
    ambiguous matches, which are reported as NaN coordinates.
    """

    def __init__(self, mutual_nn=False, ratio_threshold=None):
        """
        Args:
            mutual_nn: Keep a match only if the source feature is also the
                best match of the chosen target cell.
            ratio_threshold: If set, keep a match only when
                best / second-best similarity exceeds this value.
        """
        self.mutual_nn = mutual_nn
        self.ratio_threshold = ratio_threshold

    def match(self, src_features, tgt_features_map, return_scores=True):
        """
        Match source features to target feature map.

        Args:
            src_features: [N, D] source feature vectors
            tgt_features_map: [H, W, D] target feature map
            return_scores: Also return the best similarity per source feature

        Returns:
            matched_coords: [N, 2] float array of (x, y) feature-map cells,
                NaN for rejected matches
            best_scores: [N] best similarity values (only if return_scores)
        """
        h, w, d = tgt_features_map.shape
        tgt_flat = tgt_features_map.reshape(-1, d)
        src = np.asarray(src_features).reshape(-1, d)

        # Similarity of every source feature to every target cell: [N, H*W]
        similarity = src @ tgt_flat.T

        # Find best matches
        best_indices = np.argmax(similarity, axis=1)
        best_scores = np.max(similarity, axis=1)

        # Ratio test (needs at least two candidates to have a runner-up)
        if self.ratio_threshold is not None and similarity.shape[1] >= 2:
            # After partitioning, the last two columns hold the runner-up and
            # the best similarity, in that order.
            top_two = np.partition(similarity, -2, axis=1)[:, -2:]
            ratios = top_two[:, 1] / (top_two[:, 0] + 1e-8)
            best_indices[~(ratios > self.ratio_threshold)] = -1

        # Mutual nearest neighbor
        if self.mutual_nn:
            # Best source per target cell, read from the same similarity matrix
            reverse_best = np.argmax(similarity, axis=0)
            candidates = np.flatnonzero(best_indices >= 0)
            not_mutual = reverse_best[best_indices[candidates]] != candidates
            best_indices[candidates[not_mutual]] = -1

        # Convert flat indices to (x, y); rejected matches stay NaN
        matched_coords = np.full((len(best_indices), 2), np.nan)
        valid = best_indices >= 0
        matched_coords[valid, 0] = best_indices[valid] % w
        matched_coords[valid, 1] = best_indices[valid] // w

        if return_scores:
            return matched_coords, best_scores
        return matched_coords

    def match_keypoints(self, src_image, tgt_image, src_keypoints, feature_extractor):
        """
        End-to-end keypoint matching.

        Args:
            src_image: Source image
            tgt_image: Target image
            src_keypoints: [N, 2] (x, y) keypoints in the source image
            feature_extractor: Object providing ``extract_keypoint_features``
                and ``extract_features`` (the latter returning an ``info``
                dict with 'feature_size', 'scale_x' and 'scale_y')

        Returns:
            tgt_keypoints: [N, 2] predicted (x, y) in target image pixels,
                NaN for rejected matches
            confidence: [N] best similarity per keypoint
        """
        # Extract features
        src_features = feature_extractor.extract_keypoint_features(src_image, src_keypoints)
        tgt_features_map, tgt_info = feature_extractor.extract_features(tgt_image, normalize=True)

        # If the extractor reports which top-left region of the map is backed
        # by image content, search only there so matches never land in padding.
        # Cropping from the bottom/right leaves the (x, y) cell indices unchanged.
        valid_w, valid_h = tgt_info.get('valid_feature_size', tgt_info['feature_size'])
        tgt_features_map = tgt_features_map[:valid_h, :valid_w]

        # Match
        matched_coords_feat, confidence = self.match(src_features, tgt_features_map, return_scores=True)

        # Map back to original coordinates by inverting the extractor's own
        # image-to-feature scale, so both directions always agree.
        tgt_keypoints = matched_coords_feat.copy()
        tgt_keypoints[:, 0] = matched_coords_feat[:, 0] / tgt_info['scale_x']
        tgt_keypoints[:, 1] = matched_coords_feat[:, 1] / tgt_info['scale_y']

        return tgt_keypoints, confidence

matcher = CorrespondenceMatcher(mutual_nn=False, ratio_threshold=None)
print("✓ Correspondence matcher initialized")
print(f"  - Method: Nearest Neighbor")
print(f"  - Mutual NN: {matcher.mutual_nn}")
print(f"  - Ratio test: {matcher.ratio_threshold}")

## Section 5: Evaluation Metrics

In [None]:
class PCKEvaluator:
    """
    PCK (Percentage of Correct Keypoints) evaluator.

    A predicted keypoint counts as correct when its Euclidean distance to the
    ground truth is at most ``alpha * norm_factor``. The normalization factor
    is the diagonal of the object bounding box (when enabled and provided),
    otherwise the diagonal of the image, otherwise 1.0.
    """

    def __init__(self, alpha_values=(0.05, 0.10, 0.15), use_bbox=True):
        """
        Args:
            alpha_values: Iterable of PCK thresholds (fractions of the
                normalization factor)
            use_bbox: Prefer the bounding box over the image size for
                normalization when a bounding box is supplied
        """
        # Copy into a fresh list so instances never share (or mutate) the
        # caller's sequence or the default value.
        self.alpha_values = list(alpha_values)
        self.use_bbox = use_bbox

    def _norm_factor(self, image_size, bbox):
        """Return the distance normalization factor for one image pair."""
        if self.use_bbox and bbox is not None and len(bbox) >= 4:
            # Boxes are (xmin, ymin, xmax, ymax), as stored by the SPair-71k
            # loader in this notebook, so the extents are coordinate differences.
            # NOTE(review): published SPair-71k protocols commonly normalize by
            # max(width, height) rather than the diagonal -- confirm intent.
            box_w = float(bbox[2]) - float(bbox[0])
            box_h = float(bbox[3]) - float(bbox[1])
            return float(np.hypot(box_w, box_h))
        if image_size is not None:
            return float(np.hypot(image_size[0], image_size[1]))
        return 1.0

    @staticmethod
    def _pick(sequence, index):
        """Return ``sequence[index]``, or None when no sequence was supplied."""
        if sequence is None or len(sequence) == 0:
            return None
        return sequence[index]

    def compute_pck(self, predicted_kps, gt_kps, image_size=None, bbox=None):
        """
        Compute PCK for single image pair.

        Args:
            predicted_kps: [N, 2] predicted (x, y); NaN rows are ignored
            gt_kps: [N, 2] ground-truth (x, y); NaN rows are ignored
            image_size: Optional (width, height) used for normalization
            bbox: Optional (xmin, ymin, xmax, ymax) used for normalization

        Returns:
            Dict mapping 'PCK@<alpha>' to the fraction of correct keypoints
            (0.0 when no keypoint is valid in both arrays).

        Raises:
            ValueError: If the two keypoint arrays differ in shape.
        """
        pred = np.asarray(predicted_kps, dtype=float)
        gt = np.asarray(gt_kps, dtype=float)
        if pred.shape != gt.shape:
            raise ValueError(
                f"predicted_kps shape {pred.shape} does not match gt_kps shape {gt.shape}")

        zero_scores = {f'PCK@{alpha:.2f}': 0.0 for alpha in self.alpha_values}
        if pred.size == 0:
            return zero_scores

        # Keep only keypoints that are defined in both arrays
        valid_mask = ~(np.isnan(pred).any(axis=1) | np.isnan(gt).any(axis=1))
        if not valid_mask.any():
            return zero_scores

        distances = np.linalg.norm(pred[valid_mask] - gt[valid_mask], axis=1)
        norm_factor = self._norm_factor(image_size, bbox)

        return {
            f'PCK@{alpha:.2f}': float(np.mean(distances <= alpha * norm_factor))
            for alpha in self.alpha_values
        }

    def evaluate_batch(self, predictions, ground_truths, image_sizes=None, bboxes=None):
        """
        Evaluate multiple image pairs.

        Args:
            predictions: Sequence of [N_i, 2] predicted keypoint arrays
            ground_truths: Sequence of [N_i, 2] ground-truth keypoint arrays
            image_sizes: Optional per-pair (width, height)
            bboxes: Optional per-pair bounding boxes (entries may be None)

        Returns:
            Dict with 'mean' (per-metric mean over pairs, 0.0 for an empty
            batch), 'per_sample' (list of per-pair score dicts) and
            'num_samples'.

        Raises:
            ValueError: If predictions and ground_truths differ in length.
        """
        if len(predictions) != len(ground_truths):
            raise ValueError(
                f"Got {len(predictions)} predictions but {len(ground_truths)} ground truths")

        per_sample = [
            self.compute_pck(pred, gt, self._pick(image_sizes, i), self._pick(bboxes, i))
            for i, (pred, gt) in enumerate(zip(predictions, ground_truths))
        ]

        # An empty batch yields 0.0 instead of NaN plus a RuntimeWarning
        mean_pck = {}
        for alpha in self.alpha_values:
            key = f'PCK@{alpha:.2f}'
            scores = [sample[key] for sample in per_sample]
            mean_pck[key] = float(np.mean(scores)) if scores else 0.0

        return {
            'mean': mean_pck,
            'per_sample': per_sample,
            'num_samples': len(per_sample)
        }

evaluator = PCKEvaluator(alpha_values=[0.05, 0.10, 0.15], use_bbox=True)
print("✓ PCK evaluator initialized")

## Dataset Loaders

In [None]:
# Dataset setup
def setup_datasets(data_root):
    """
    Create the data directory and print manual download instructions.

    Nothing is downloaded; the function only ensures ``data_root`` exists and
    tells the user where to extract each benchmark.

    Args:
        data_root: Directory that will contain the benchmark datasets
    """
    print("="*60)
    print("DATASET SETUP")
    print("="*60)

    os.makedirs(data_root, exist_ok=True)

    print("\n⚠️  Please download datasets manually:")
    print("\n1. PF-Pascal: https://www.di.ens.fr/willow/research/proposalflow/")
    print(f"   → Extract to: {data_root}/pf-pascal/")
    print("\n2. SPair-71k: http://cvlab.postech.ac.kr/research/SPair-71k/")
    # The directory name must match the one the evaluation cell reads
    # ('SPair-71k'); the case matters on case-sensitive filesystems.
    print(f"   → Extract to: {data_root}/SPair-71k/")
    print("\n" + "="*60)

# Ensure the data directory exists and show the download instructions
setup_datasets(data_root=DATA_ROOT)

In [None]:
# SPair-71k dataset loader
from torch.utils.data import Dataset

class SPairDataset(Dataset):
    """
    SPair-71k dataset loader.

    Reads every ``PairAnnotation/<split>/*.json`` file under ``root_dir`` at
    construction time and keeps the pair metadata in memory; images are
    opened lazily in ``__getitem__``.
    """

    def __init__(self, root_dir, split='test', category=None):
        """
        Args:
            root_dir: Dataset root containing 'PairAnnotation' and 'JPEGImages'
            split: Annotation sub-directory to load
            category: If set, keep only pairs of this category
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.category = category
        self.pairs = []
        self._load_annotations()

    @staticmethod
    def _as_xy_rows(kps):
        """
        Return keypoints as an [N, 2] array of (x, y) rows.

        Accepts either an N×2 list of (x, y) pairs or a 2×N list of x/y rows,
        because the rest of the pipeline indexes keypoints as ``kps[:, 0]``
        (x) and ``kps[:, 1]`` (y). A 2×2 input is taken to be two (x, y) rows.
        """
        arr = np.array(kps)
        if arr.size == 0:
            return arr.reshape(0, 2)
        if arr.ndim == 2 and arr.shape[0] == 2 and arr.shape[1] != 2:
            arr = arr.T
        return arr

    def _load_annotations(self):
        """Populate ``self.pairs`` from the split's annotation files."""
        anno_dir = self.root_dir / 'PairAnnotation' / self.split

        if not anno_dir.exists():
            print(f"⚠️  Annotations not found: {anno_dir}")
            return

        # Sorted so that sample indices are stable across runs
        for anno_file in sorted(anno_dir.glob('*.json')):
            with open(anno_file, 'r') as f:
                data = json.load(f)

            if self.category and data.get('category') != self.category:
                continue

            # Image paths are: JPEGImages/<category>/<image_name>
            cat = data.get('category', 'unknown')
            image_dir = self.root_dir / 'JPEGImages' / cat
            pair = {
                'src_img': str(image_dir / data['src_imname']),
                'tgt_img': str(image_dir / data['trg_imname']),
                'src_kps': self._as_xy_rows(data['src_kps']),
                'tgt_kps': self._as_xy_rows(data['trg_kps']),
                'src_bbox': np.array(data.get('src_bndbox', [])),
                'tgt_bbox': np.array(data.get('trg_bndbox', [])),
                'category': cat
            }
            self.pairs.append(pair)

        print(f"✓ Loaded {len(self.pairs)} pairs from SPair-71k {self.split} split")

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """
        Return one image pair.

        Returns:
            Dict with 'src_image' / 'tgt_image' (RGB PIL images),
            'src_keypoints' / 'tgt_keypoints' ([N, 2] arrays),
            'src_bbox' / 'tgt_bbox' and 'category'.
        """
        pair = self.pairs[idx]

        return {
            'src_image': self._load_rgb(pair['src_img']),
            'tgt_image': self._load_rgb(pair['tgt_img']),
            'src_keypoints': pair['src_kps'],
            'tgt_keypoints': pair['tgt_kps'],
            'src_bbox': pair['src_bbox'],
            'tgt_bbox': pair['tgt_bbox'],
            'category': pair['category']
        }

print("✓ Dataset loaders defined")

## Visualization and Evaluation Pipeline

In [None]:
def visualize_correspondences(src_img, tgt_img, src_kps, pred_kps, gt_kps=None, 
                              max_points=15, save_path=None):
    """
    Plot source keypoints, predicted target keypoints and, when ground truth
    is given, a GT-vs-prediction comparison panel.

    Args:
        src_img, tgt_img: PIL images or arrays accepted by ``imshow``
        src_kps, pred_kps, gt_kps: Keypoint arrays; a (2, M) array with M != 2
            is transposed to (M, 2). All arrays are truncated to the shortest one.
        max_points: Randomly subsample down to this many keypoints
        save_path: If set, the figure is also written to this path

    Returns:
        The matplotlib figure.
    """
    def _as_point_rows(points):
        # At least 2-D; flip (2, M) layouts (M != 2) into (M, 2) rows
        rows = np.atleast_2d(points)
        if rows.shape[0] == 2 and rows.shape[1] != 2:
            return rows.T
        return rows

    has_gt = gt_kps is not None

    if isinstance(src_img, Image.Image):
        src_img = np.array(src_img)
    if isinstance(tgt_img, Image.Image):
        tgt_img = np.array(tgt_img)

    src_kps = _as_point_rows(src_kps)
    pred_kps = _as_point_rows(pred_kps)
    if has_gt:
        gt_kps = _as_point_rows(gt_kps)

    # Truncate every array to the number of points they all share
    n_common = min(len(src_kps), len(pred_kps))
    if has_gt:
        n_common = min(n_common, len(gt_kps))
    src_kps, pred_kps = src_kps[:n_common], pred_kps[:n_common]
    if has_gt:
        gt_kps = gt_kps[:n_common]

    # Random subsample (same indices for every array) to keep the plot readable
    if len(src_kps) > max_points:
        keep = np.random.choice(len(src_kps), max_points, replace=False)
        src_kps, pred_kps = src_kps[keep], pred_kps[keep]
        if has_gt:
            gt_kps = gt_kps[keep]

    ncols = 3 if has_gt else 2
    fig, axes = plt.subplots(1, ncols, figsize=(6*ncols, 6))
    src_ax, pred_ax = axes[0], axes[1]

    # Panel 1: source image with source keypoints
    src_ax.imshow(src_img)
    if len(src_kps) > 0:
        src_ax.scatter(src_kps[:, 0], src_kps[:, 1], c='red', s=100,
                       edgecolors='white', linewidths=2, marker='o')
    src_ax.set_title('Source Image', fontsize=12, fontweight='bold')
    src_ax.axis('off')

    # Panel 2: target image with predicted keypoints (NaN rows are skipped)
    pred_ax.imshow(tgt_img)
    valid_pred = ~np.isnan(pred_kps).any(axis=1)
    if valid_pred.any():
        pred_ax.scatter(pred_kps[valid_pred, 0], pred_kps[valid_pred, 1], c='blue', s=100,
                        marker='x', linewidths=3)
    pred_ax.set_title('Target (Predictions)', fontsize=12, fontweight='bold')
    pred_ax.axis('off')

    # Panel 3: target image with ground truth overlaid on the predictions
    if has_gt:
        cmp_ax = axes[2]
        cmp_ax.imshow(tgt_img)
        valid_gt = ~np.isnan(gt_kps).any(axis=1)

        if valid_gt.any():
            cmp_ax.scatter(gt_kps[valid_gt, 0], gt_kps[valid_gt, 1], c='green', s=100,
                           edgecolors='white', linewidths=2, marker='o', label='GT')

        if valid_pred.any():
            cmp_ax.scatter(pred_kps[valid_pred, 0], pred_kps[valid_pred, 1], c='blue', s=50,
                           marker='x', linewidths=2, alpha=0.7, label='Pred')

        # Error lines and mean error use only points valid in BOTH arrays
        mean_error = 0
        valid_both = valid_pred & valid_gt
        if valid_both.any():
            for idx in np.flatnonzero(valid_both):
                cmp_ax.plot([gt_kps[idx, 0], pred_kps[idx, 0]],
                            [gt_kps[idx, 1], pred_kps[idx, 1]],
                            'r--', alpha=0.3, linewidth=1)

            try:
                offsets = pred_kps[valid_both] - gt_kps[valid_both]
                mean_error = np.linalg.norm(offsets, axis=1).mean()
            except ValueError:
                mean_error = 0

        cmp_ax.set_title(f'GT vs Pred (Mean Error: {mean_error:.1f}px)',
                         fontsize=12, fontweight='bold')
        cmp_ax.legend(loc='upper right')
        cmp_ax.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig

print("✓ Visualization utilities ready")

In [None]:
def evaluate_on_dataset(dataset, feature_extractor, matcher, evaluator, 
                       max_samples=None, save_visualizations=True, output_dir=None):
    """
    Complete evaluation pipeline.

    Matches the source keypoints of each pair into the target image, scores
    the predictions with ``evaluator`` and writes a JSON summary.

    Args:
        dataset: Indexable dataset yielding dicts with 'src_image',
            'tgt_image', 'src_keypoints', 'tgt_keypoints' and optionally
            'tgt_bbox'
        feature_extractor: Passed through to ``matcher.match_keypoints``
        matcher: Object providing ``match_keypoints``
        evaluator: Object providing ``evaluate_batch``
        max_samples: Evaluate at most this many pairs (all when None/0)
        save_visualizations: Save a figure for pairs with index below 5
        output_dir: Where figures and the JSON summary are written;
            defaults to the notebook-level ``OUTPUT_DIR``

    Returns:
        The dict returned by ``evaluator.evaluate_batch`` with an extra
        'confidences' entry (one array of match scores per evaluated pair).
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print(f"EVALUATING SAM ON {dataset.__class__.__name__}")
    print("="*60)

    num_samples = min(max_samples, len(dataset)) if max_samples else len(dataset)
    print(f"Total samples: {len(dataset)}")
    print(f"Evaluating: {num_samples} samples\n")

    predictions = []
    ground_truths = []
    image_sizes = []
    bboxes = []
    confidences = []

    for i in tqdm(range(num_samples), desc="Processing"):
        sample = dataset[i]

        src_img = sample['src_image']
        tgt_img = sample['tgt_image']
        src_kps = sample['src_keypoints']
        tgt_kps = sample['tgt_keypoints']

        # Pairs without annotated keypoints cannot be scored
        if len(src_kps) == 0 or len(tgt_kps) == 0:
            continue

        pred_kps, conf = matcher.match_keypoints(
            src_img, tgt_img, src_kps, feature_extractor
        )

        predictions.append(pred_kps)
        ground_truths.append(tgt_kps)
        confidences.append(conf)
        image_sizes.append(tgt_img.size)

        # Keep bboxes aligned with predictions; None means "no box available"
        if 'tgt_bbox' in sample and len(sample['tgt_bbox']) > 0:
            bboxes.append(sample['tgt_bbox'])
        else:
            bboxes.append(None)

        if save_visualizations and i < 5:
            vis_path = os.path.join(output_dir, f'sample_{i}.png')
            fig = visualize_correspondences(src_img, tgt_img, src_kps, pred_kps, 
                                            tgt_kps, save_path=vis_path)
            # Close this specific figure rather than whichever one is current
            plt.close(fig)

    results = evaluator.evaluate_batch(predictions, ground_truths, image_sizes, bboxes)
    # Expose the match scores instead of discarding them
    results['confidences'] = confidences

    print("\n" + "="*60)
    print("RESULTS")
    print("="*60)
    print(f"Samples evaluated: {results['num_samples']}")
    print("\nPCK Scores:")
    for metric, value in sorted(results['mean'].items()):
        print(f"  {metric}: {value*100:.2f}%")
    print("="*60)

    # Only the PCK summary goes into the JSON file (confidences are arrays)
    results_file = os.path.join(output_dir, 'evaluation_results.json')
    with open(results_file, 'w') as f:
        json.dump({
            'backbone': 'SAM ViT-B',
            'dataset': dataset.__class__.__name__,
            'num_samples': results['num_samples'],
            'mean_pck': results['mean'],
            'per_sample_pck': results['per_sample']
        }, f, indent=2)
    print(f"\n✓ Results saved to {results_file}")

    return results

print("✓ Evaluation pipeline ready")

## Run Evaluation

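Run the cell below to evaluate SAM on the first few pairs of the SPair-71k test split (the dataset must already be extracted under the data directory).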

In [None]:
# Load the SPair-71k test split, then evaluate a handful of pairs
spair_root = os.path.join(DATA_ROOT, 'SPair-71k')
spair_test = SPairDataset(root_dir=spair_root, split='test')

results = evaluate_on_dataset(
    spair_test,
    feature_extractor,
    matcher,
    evaluator,
    max_samples=4,
    save_visualizations=True,
)

## Summary

### SAM Implementation Complete ✓

**Implementation:**
1. ✓ Cross-platform environment setup
2. ✓ SAM ViT-B model and checkpoint download
3. ✓ Dense feature extraction (64×64×256)
4. ✓ Correspondence matching
5. ✓ PCK evaluation
6. ✓ Dataset loaders
7. ✓ Visualization tools
8. ✓ Complete pipeline

**SAM Advantages:**
- **Higher spatial resolution**: 64×64 features vs 16×16 (DINO)
- **Task-specific training**: Trained for dense prediction tasks
- **Strong boundaries**: Excellent at detecting object boundaries
- **Large-scale data**: 11M images, 1.1B masks

**Trade-offs:**
- Lower feature dimension (256 vs 768)
- Larger input size (1024 vs 224) → slower
- More memory intensive

**Expected Performance:**
- Potentially better localization accuracy (higher resolution)
- Strong for objects with clear boundaries
- May excel on geometric transformations