## Section 1: Environment Setup & Dataset Loading

In [None]:
# Detect environment and configure paths
import sys
import os
from pathlib import Path

# Detect Google Colab: the `google.colab` package is only importable inside a
# Colab runtime. Catch ImportError only, so unrelated failures (including
# KeyboardInterrupt, which a bare `except:` would swallow) still surface.
try:
    import google.colab  # noqa: F401  (imported only to probe the runtime)
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("✓ Running locally")

# Set up paths
if IN_COLAB:
    # On Colab the data directory lives on the mounted Google Drive.
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT = '/content/AMLProject'
    DATA_ROOT = '/content/drive/MyDrive/AMLProject/data'
else:
    PROJECT_ROOT = os.getcwd()
    DATA_ROOT = os.path.join(PROJECT_ROOT, 'data')

# Create necessary directories (exist_ok=True makes re-running the cell safe)
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs', 'dinov2')
MODEL_DIR = os.path.join(PROJECT_ROOT, 'models')

for directory in [CHECKPOINT_DIR, OUTPUT_DIR, MODEL_DIR, DATA_ROOT]:
    os.makedirs(directory, exist_ok=True)

print(f"\nProject root: {PROJECT_ROOT}")
print(f"Data root: {DATA_ROOT}")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Install dependencies
import subprocess

print("Installing required packages...")

# pip distribution names (cv2 ships as "opencv-python", PIL as "pillow",
# sklearn as "scikit-learn").
packages = [
    'torch', 'torchvision', 'numpy', 'matplotlib', 'opencv-python',
    'pillow', 'scipy', 'tqdm', 'pandas', 'scikit-learn',
]

# Invoke pip through the running interpreter so packages land in this
# kernel's environment; check_call raises CalledProcessError if pip fails.
subprocess.check_call(
    [sys.executable, '-m', 'pip', 'install', '--quiet', '--upgrade', *packages]
)
print("✓ All packages installed successfully!")

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import json
import pandas as pd
from pathlib import Path
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
import importlib

# Configure matplotlib: default figure size and resolution for every plot
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'figure.dpi': 100,
})

# Detect device, preferring CUDA, then Apple-Silicon MPS, then CPU.
# `torch.backends.mps` may be absent on older PyTorch builds, hence hasattr.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✓ Using CUDA GPU: {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
    print("✓ Using Apple Silicon GPU (MPS)")
else:
    device = torch.device('cpu')
    print("✓ Using CPU")

print(f"PyTorch version: {torch.__version__}")

In [None]:
# Setup dataset path and import SPair dataset class
repo_path = os.path.join(PROJECT_ROOT, 'SD4Match')

if not os.path.exists(repo_path):
    print("WARNING: SD4Match repository not found!")
    print("Please ensure the SD4Match repository is cloned in the project directory.")
else:
    # Prepend rather than append: `dataset` is a generic top-level name, and
    # any other module called `dataset` already on sys.path would otherwise
    # shadow SD4Match's package.
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
    print(f"✓ SD4Match path added: {repo_path}")

# Import dataset class. Only import/lookup failures are expected and reported
# here; any other exception raised while executing the module propagates.
try:
    module = importlib.import_module("dataset.spair")
    SPairDataset = getattr(module, "SPairDataset")
except (ImportError, AttributeError) as e:
    print(f"Import Error: {e}")
    print("Make sure SD4Match repository is properly cloned.")
else:
    print("✓ SPairDataset class imported successfully")

In [None]:
# Download SPair-71k dataset (if not already present)
import requests
import tarfile
from tqdm import tqdm as tqdm_requests

data_path = os.path.join(DATA_ROOT, 'SPair-71k')

if not os.path.exists(data_path):
    print("Downloading SPair-71k dataset...")
    url = "http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz"
    tar_path = os.path.join(DATA_ROOT, 'SPair-71k.tar.gz')

    try:
        # Download with progress bar. raise_for_status() stops an HTTP error
        # page from being saved (and later "extracted") as the archive; the
        # timeout keeps a stalled connection from hanging the cell forever.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(tar_path, 'wb') as f, tqdm_requests(
                desc='Downloading',
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = f.write(data)
                    pbar.update(size)

        print("\nExtracting...")
        with tarfile.open(tar_path, 'r:gz') as tar:
            # The 'data' extraction filter (Python 3.12+, backported to recent
            # 3.8-3.11 patch releases) rejects absolute paths, '..' traversal
            # and special files; fall back on interpreters without it.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(DATA_ROOT, filter='data')
            else:
                tar.extractall(DATA_ROOT)
    finally:
        # Cleanup: remove the archive even when download/extraction fails, so
        # a partial file never lingers for the next run.
        if os.path.exists(tar_path):
            os.remove(tar_path)
    print("✓ Extraction complete")
else:
    print(f"✓ SPair-71k dataset already exists at {data_path}")

In [None]:
# Create configuration for dataset
class Config:
    """Minimal config object exposing dataset settings as ``cfg.DATASET.<FIELD>``.

    Only the dataset section is populated. The object is passed to
    ``SPairDataset`` (from the SD4Match repo), which presumably reads these
    upper-case fields — confirm against that implementation.
    """

    def __init__(self):
        class _DatasetSection:
            def __init__(self):
                # Where the data lives and which subset to load
                self.ROOT = DATA_ROOT
                self.NAME = 'spair'
                self.CATEGORY = 'cat'
                # Image sizes
                self.SIZE = 224
                self.IMG_SIZE = 224
                # Per-channel normalization statistics (ImageNet values)
                self.MEAN = [0.485, 0.456, 0.406]
                self.STD = [0.229, 0.224, 0.225]

        self.DATASET = _DatasetSection()


cfg = Config()
print(f"✓ Configuration created")
print(f"Dataset root: {cfg.DATASET.ROOT}")

## Section 2: DINOv2 Model Loading

In [None]:
# Load DINOv2 ViT-B/14 model
print("Loading DINOv2 ViT-B/14 model...")

try:
    # Fetch from torch hub (requires internet on the first load).
    dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
except Exception as e:
    print(f"Error loading DINOv2: {e}")
    print("Make sure you have internet connection for the first load.")
    # Re-raise: every later cell needs `dinov2_model`, so continuing would only
    # turn this failure into a confusing NameError further down the notebook.
    raise

# Kept outside the try so device errors (e.g. out of memory) are not
# misreported as a download problem.
dinov2_model = dinov2_model.to(device)
dinov2_model.eval()

print("✓ Model loaded successfully!")
print("Model type: DINOv2 ViT-B/14")
print(f"Feature dimension: {dinov2_model.embed_dim}")
print("Patch size: 14x14")
print(f"Image size: 224x224 → {224//14}x{224//14} = 256 patches")

## Section 3: Feature Extractor Class

DINOv2 extracts dense patch features that capture semantic information. We'll create a feature extractor class that:
- Extracts a 16×16 grid of patch-token features (a 224×224 input split into 14×14-pixel patches)
- Handles coordinate mapping between image and feature space
- Provides L2-normalized features for similarity computation

In [None]:
class DINOv2FeatureExtractor:
    """
    Extract dense spatial features from DINOv2 for correspondence matching.

    Features:
    - Input: 224×224 RGB images
    - Output: 16×16×768 feature maps (for ViT-B/14)
    - Features are L2-normalized for cosine similarity

    Coordinate convention: feature cell ``i`` summarizes pixels
    ``[i * patch_size, (i + 1) * patch_size)`` and is anchored at that patch's
    centre. Pixel ``x`` therefore maps to ``(x + 0.5) / patch_size - 0.5`` in
    feature space, and feature index ``i`` maps back to the patch-centre pixel
    ``i * patch_size + patch_size / 2 - 0.5``. Keypoints are expected in the
    pixel coordinates of the 224×224 network input.
    """

    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.eval()

        # Image preprocessing, applied only to PIL inputs.
        # NOTE: Resize(256) + CenterCrop(224) changes the image geometry, so
        # keypoints given in original-image pixels will not line up with
        # features of a PIL input. Tensor inputs are used as-is.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Feature dimensions
        self.patch_size = 14
        self.img_size = 224
        self.feat_h = self.feat_w = self.img_size // self.patch_size  # 16

    @staticmethod
    def _as_float_array(coords):
        """Return a fresh float64 NumPy copy of ``coords`` (array, list or tensor).

        Converting to float is essential: writing scaled values back into an
        integer array would silently truncate them to whole patch indices.
        """
        if torch.is_tensor(coords):
            coords = coords.detach().cpu().numpy()
        return np.array(coords, dtype=np.float64)

    def extract_features(self, img):
        """
        Extract dense feature map from an image.

        Args:
            img: PIL Image or tensor [3, H, W] / [1, 3, H, W]. Tensors are used
                without further preprocessing and must be 224×224 so that the
                patch tokens reshape to the feat_h × feat_w grid.

        Returns:
            features: torch.Tensor [1, feat_h*feat_w, dim], L2-normalized per token
            features_2d: torch.Tensor [1, dim, feat_h, feat_w], L2-normalized per cell
        """
        # Preprocess
        if isinstance(img, Image.Image):
            img_tensor = self.transform(img).unsqueeze(0).to(self.device)
        else:
            img_tensor = img.unsqueeze(0).to(self.device) if img.dim() == 3 else img.to(self.device)

        # Extract features
        with torch.no_grad():
            # Get patch tokens (excluding CLS token)
            features_dict = self.model.forward_features(img_tensor)
            features = features_dict['x_norm_patchtokens']  # [1, num_patches, dim]

            # Reshape to spatial grid
            B, N, D = features.shape
            features_2d = features.reshape(B, self.feat_h, self.feat_w, D)
            features_2d = features_2d.permute(0, 3, 1, 2)  # [1, dim, feat_h, feat_w]

            # L2 normalize for cosine similarity
            features = F.normalize(features, p=2, dim=-1)
            features_2d = F.normalize(features_2d, p=2, dim=1)

        return features, features_2d

    def extract_keypoint_features(self, img, keypoints):
        """
        Extract features at specific keypoint locations.

        Args:
            img: PIL Image or tensor
            keypoints: array-like or tensor [N, 2+] in image coordinates (x, y)

        Returns:
            kp_features: torch.Tensor [N, dim], L2-normalized. Keypoints outside
                the image get an all-zero vector; in-image keypoints in the
                half-patch border are clamped onto the nearest feature cell.
        """
        _, features_2d = self.extract_features(img)  # [1, dim, feat_h, feat_w]
        dim = features_2d.shape[1]

        img_coords = self._as_float_array(keypoints)
        if img_coords.size == 0:
            # torch.stack([]) would raise; return an empty [0, dim] batch.
            return features_2d.new_zeros((0, dim))

        # Map image coordinates to feature coordinates
        feat_coords = self.map_coords_to_features(img_coords)

        max_x = float(self.feat_w - 1)
        max_y = float(self.feat_h - 1)

        kp_features = []
        for (px, py), (fx, fy) in zip(img_coords[:, :2], feat_coords[:, :2]):
            if not (0 <= px < self.img_size and 0 <= py < self.img_size):
                # Out of bounds - use zero vector
                kp_features.append(features_2d.new_zeros(dim))
                continue

            # Pixels within half a patch of the border lie outside the grid of
            # patch centres; clamp them onto the outermost cells.
            fx = min(max(float(fx), 0.0), max_x)
            fy = min(max(float(fy), 0.0), max_y)

            # fx, fy >= 0 here, so int() is floor.
            x0, y0 = int(fx), int(fy)
            x1, y1 = min(x0 + 1, self.feat_w - 1), min(y0 + 1, self.feat_h - 1)

            # Interpolation weights (plain Python floats keep the tensor dtype)
            wx = fx - x0
            wy = fy - y0

            # Bilinear interpolation for sub-patch accuracy
            feat = ((1 - wx) * (1 - wy) * features_2d[0, :, y0, x0]
                    + wx * (1 - wy) * features_2d[0, :, y0, x1]
                    + (1 - wx) * wy * features_2d[0, :, y1, x0]
                    + wx * wy * features_2d[0, :, y1, x1])
            kp_features.append(feat)

        kp_features = torch.stack(kp_features)
        return F.normalize(kp_features, p=2, dim=-1)

    def map_coords_to_features(self, coords):
        """
        Map image coordinates to feature map coordinates (patch-centre aligned).

        Args:
            coords: array-like or tensor [N, 2+] in image space (x, y); extra
                columns are copied through unchanged.

        Returns:
            feat_coords: float64 numpy array [N, 2+] in feature space, where
                integer values fall exactly on patch centres.
        """
        scale_x = self.feat_w / self.img_size
        scale_y = self.feat_h / self.img_size

        feat_coords = self._as_float_array(coords)
        feat_coords[:, 0] = (feat_coords[:, 0] + 0.5) * scale_x - 0.5
        feat_coords[:, 1] = (feat_coords[:, 1] + 0.5) * scale_y - 0.5

        return feat_coords

    def map_features_to_coords(self, feat_coords):
        """
        Map feature coordinates back to image space (inverse of map_coords_to_features).

        Args:
            feat_coords: array-like or tensor [N, 2+] in feature space

        Returns:
            img_coords: float64 numpy array [N, 2+] in image space (x, y); an
                integer feature index maps to the centre pixel of its patch.
        """
        scale_x = self.img_size / self.feat_w
        scale_y = self.img_size / self.feat_h

        img_coords = self._as_float_array(feat_coords)
        img_coords[:, 0] = (img_coords[:, 0] + 0.5) * scale_x - 0.5
        img_coords[:, 1] = (img_coords[:, 1] + 0.5) * scale_y - 0.5

        return img_coords

# Create feature extractor around the loaded DINOv2 backbone
feature_extractor = DINOv2FeatureExtractor(model=dinov2_model, device=device)
print("✓ Feature extractor created")

## Section 4: Test Feature Extraction with Visualization

Let's test the feature extractor on a sample image pair and visualize the features.

In [None]:
# Load test dataset (category taken from the config so the two cannot drift)
print("Loading SPair-71k test dataset...")
try:
    dataset = SPairDataset(cfg, 'test', cfg.DATASET.CATEGORY)
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Make sure dataset is properly extracted.")
    # Re-raise so the notebook stops here instead of failing later with a
    # NameError on `dataset`.
    raise
print(f"✓ Successfully loaded {len(dataset)} test pairs")

In [None]:
# Get a sample and extract features
sample_idx = 0
sample = dataset[sample_idx]

print(f"Extracting features for sample {sample_idx}...")

# Unpack the pair (the dataset uses the 'trg_' prefix for the target side)
src_img, tgt_img = sample['src_img'], sample['trg_img']
src_kps, tgt_kps = sample['src_kps'], sample['trg_kps']

print(f"Source keypoints shape: {src_kps.shape}")
print(f"Target keypoints shape: {tgt_kps.shape}")

# Dense features for both images: flat tokens and the spatial grid
(src_features, src_features_2d), (tgt_features, tgt_features_2d) = (
    feature_extractor.extract_features(image) for image in (src_img, tgt_img)
)

print(f"\nFeature shapes:")
print(f"Source features: {src_features.shape}")        # [1, 256, 768]
print(f"Source features 2D: {src_features_2d.shape}")  # [1, 768, 16, 16]
print(f"Target features: {tgt_features.shape}")
print(f"Target features 2D: {tgt_features_2d.shape}")

# Denormalize function for visualization
def denorm_show(img_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel mean/std normalization of a CHW image tensor for ``imshow``.

    Args:
        img_tensor: Normalized image tensor of shape [3, H, W]. It may live on
            any device and may require grad.
        mean: Per-channel mean used for normalization (ImageNet defaults).
        std: Per-channel std used for normalization (ImageNet defaults).

    Returns:
        np.ndarray of shape [H, W, 3] with values clipped to [0, 1].
    """
    # detach().cpu() so CUDA/MPS tensors and tensors requiring grad convert.
    img = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img = np.asarray(std) * img + np.asarray(mean)
    return np.clip(img, 0, 1)

In [None]:
# Visualize images with keypoints
fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# Keep only keypoints with non-negative x (negative values presumably mark
# missing/padded keypoints)
valid_src_kps = src_kps[src_kps[:, 0] >= 0]
valid_tgt_kps = tgt_kps[tgt_kps[:, 0] >= 0]

panels = (
    (axes[0], 'Source Image', src_img, valid_src_kps),
    (axes[1], 'Target Image', tgt_img, valid_tgt_kps),
)
for panel_ax, panel_title, panel_img, panel_kps in panels:
    panel_ax.imshow(denorm_show(panel_img))
    panel_ax.scatter(panel_kps[:, 0], panel_kps[:, 1],
                     c='red', s=100, marker='x', linewidths=3, label='Keypoints')
    panel_ax.set_title(f'{panel_title}\n{len(panel_kps)} keypoints', fontsize=12, fontweight='bold')
    panel_ax.axis('off')
    panel_ax.legend()

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'sample_images_with_keypoints.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Visualization saved to {OUTPUT_DIR}")

In [None]:
# Visualize DINOv2 features using PCA
print("Visualizing DINOv2 features with PCA...")

# Get features as numpy arrays: [num_patches, dim]
src_feat_np = src_features[0].cpu().numpy()
tgt_feat_np = tgt_features[0].cpu().numpy()

# Fit a 3-component PCA on the source patches and project the target patches
# into the same basis, so each RGB channel means the same thing in both panels.
pca = PCA(n_components=3)
src_pca = pca.fit_transform(src_feat_np)  # [num_patches, 3]
tgt_pca = pca.transform(tgt_feat_np)      # [num_patches, 3]

# Reshape to the extractor's spatial grid rather than a hard-coded 16x16
pca_grid_shape = (feature_extractor.feat_h, feature_extractor.feat_w, 3)
src_pca_img = src_pca.reshape(pca_grid_shape)
tgt_pca_img = tgt_pca.reshape(pca_grid_shape)


def _minmax_01(arr):
    """Rescale to [0, 1] for imshow; a constant array maps to zeros instead of NaN."""
    lo, hi = arr.min(), arr.max()
    return (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)


# Normalize to [0, 1] for visualization (each image independently)
src_pca_img = _minmax_01(src_pca_img)
tgt_pca_img = _minmax_01(tgt_pca_img)

# NOTE: explained variance comes from the PCA fit on the *source* patches;
# the same figure is shown for the target panel.
explained_variance = pca.explained_variance_ratio_.sum()

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

# Original images
axes[0, 0].imshow(denorm_show(src_img))
axes[0, 0].set_title('Source Image', fontsize=14, fontweight='bold')
axes[0, 0].axis('off')

axes[0, 1].imshow(denorm_show(tgt_img))
axes[0, 1].set_title('Target Image', fontsize=14, fontweight='bold')
axes[0, 1].axis('off')

# PCA visualizations
axes[1, 0].imshow(src_pca_img)
axes[1, 0].set_title(f'Source DINOv2 Features (PCA)\nExplained variance: {explained_variance:.2%}',
                     fontsize=12, fontweight='bold')
axes[1, 0].axis('off')

axes[1, 1].imshow(tgt_pca_img)
axes[1, 1].set_title(f'Target DINOv2 Features (PCA)\nExplained variance: {explained_variance:.2%}',
                     fontsize=12, fontweight='bold')
axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'pca_feature_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ PCA captures {explained_variance:.2%} of feature variance")
print("Colors represent semantic regions learned by DINOv2")

In [None]:
# Compute patch-to-patch similarity matrix
print("Computing patch-to-patch similarity...")

# Unit-length patch descriptors make the matrix product a cosine similarity
src_feat_norm = F.normalize(torch.from_numpy(src_feat_np), dim=1)  # [256, 768]
tgt_feat_norm = F.normalize(torch.from_numpy(tgt_feat_np), dim=1)  # [256, 768]
similarity_matrix = (src_feat_norm @ tgt_feat_norm.T).numpy()      # [256, 256]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Outer panels: both images with their annotated keypoints
image_panels = (
    (axes[0], 'Source Image', src_img, valid_src_kps),
    (axes[2], 'Target Image', tgt_img, valid_tgt_kps),
)
for panel_ax, panel_title, panel_img, panel_kps in image_panels:
    panel_ax.imshow(denorm_show(panel_img))
    panel_ax.scatter(panel_kps[:, 0], panel_kps[:, 1],
                     c='red', s=100, marker='x', linewidths=3)
    panel_ax.set_title(f'{panel_title}\n({len(panel_kps)} keypoints)', fontsize=12, fontweight='bold')
    panel_ax.axis('off')

# Middle panel: heatmap with source patches as rows, target patches as columns
im = axes[1].imshow(similarity_matrix, cmap='hot', aspect='auto')
axes[1].set_title(f'Patch-to-Patch Similarity\n(DINOv2 Features)', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Target Patches')
axes[1].set_ylabel('Source Patches')
plt.colorbar(im, ax=axes[1], label='Cosine Similarity')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'similarity_heatmap.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSimilarity statistics:")
print(f"  Mean: {similarity_matrix.mean():.4f}")
print(f"  Max: {similarity_matrix.max():.4f}")
print(f"  Min: {similarity_matrix.min():.4f}")
print(f"  High similarity (>0.9): {(similarity_matrix > 0.9).sum()} patch pairs")

## Section 5: Correspondence Matcher

Now let's implement the correspondence matcher that finds matching keypoints between images.

In [None]:
class CorrespondenceMatcher:
    """
    Find correspondences between source and target images using feature similarity.

    Methods:
    - Nearest Neighbor (NN): Find target point with highest similarity
    - Mutual Nearest Neighbors (MNN): Enforce bidirectional consistency
      (checked against the queried source keypoints, not all source patches)
    - Ratio Test: Reject ambiguous matches (Lowe's ratio test)

    Rejected matches keep their predicted location but get confidence 0.
    """

    def __init__(self, mutual_nn=False, ratio_threshold=None):
        """
        Args:
            mutual_nn: If True, zero the confidence of matches that are not
                mutual nearest neighbors
            ratio_threshold: If set (e.g. 0.8), apply Lowe's ratio test: a
                match is kept only if second_best < ratio_threshold * best
        """
        self.mutual_nn = mutual_nn
        self.ratio_threshold = ratio_threshold

    def match_keypoints(self, src_features, tgt_features_2d, src_keypoints,
                        feature_extractor):
        """
        Find correspondences for source keypoints in target image.

        Args:
            src_features: Source keypoint features [N, dim]
            tgt_features_2d: Target dense features [1, dim, H, W]
            src_keypoints: Source keypoint coordinates [N, 2] (not used by the
                matching itself; kept for interface compatibility)
            feature_extractor: Object providing map_features_to_coords()

        Returns:
            pred_keypoints: Predicted target coordinates [N, 2]
            confidences: Similarity of the best match [N] (cosine when the
                features are L2-normalized), or 0 for rejected matches
        """
        N = src_features.shape[0]
        _, D, H, W = tgt_features_2d.shape

        # [1, D, H, W] -> [H*W, D]; flat row index is y * W + x
        tgt_features_flat = tgt_features_2d.reshape(D, H * W).t()

        # Similarity of every source keypoint to every target patch: [N, H*W]
        similarities = torch.mm(src_features, tgt_features_flat.t())

        # Best target patch per source keypoint
        max_sims, max_indices = similarities.max(dim=1)
        keep = torch.ones_like(max_sims, dtype=torch.bool)

        if self.ratio_threshold is not None and H * W > 1:
            # Lowe's ratio test: keep a match only if the runner-up is clearly
            # worse. Written as a product to avoid dividing by ~0 similarities.
            top2 = similarities.topk(2, dim=1).values
            keep &= top2[:, 1] < self.ratio_threshold * top2[:, 0]

        if self.mutual_nn and N > 0:
            # For each target patch, the source keypoint that matches it best;
            # a match is mutual if its target patch points back to it.
            reverse_indices = similarities.argmax(dim=0)  # [H*W]
            own_index = torch.arange(N, device=max_indices.device)
            keep &= reverse_indices[max_indices] == own_index

        confidence = torch.where(keep, max_sims, torch.zeros_like(max_sims))

        # Convert flat indices to 2D feature coordinates (x, y)
        pred_x = (max_indices % W).float()
        pred_y = (max_indices // W).float()
        pred_coords_feat = torch.stack([pred_x, pred_y], dim=1).cpu().numpy()

        # Map to image coordinates
        pred_keypoints = feature_extractor.map_features_to_coords(pred_coords_feat)
        confidences = confidence.cpu().numpy()

        return pred_keypoints, confidences

    def match_images(self, src_img, tgt_img, src_keypoints, feature_extractor):
        """
        Complete matching pipeline for an image pair.

        Args:
            src_img: Source image tensor
            tgt_img: Target image tensor
            src_keypoints: Source keypoint coordinates [N, 2]
            feature_extractor: Feature extractor instance

        Returns:
            pred_keypoints: Predicted target keypoints [N, 2]
            confidences: Match confidence scores [N]
        """
        # Extract features
        src_kp_features = feature_extractor.extract_keypoint_features(src_img, src_keypoints)
        _, tgt_features_2d = feature_extractor.extract_features(tgt_img)

        # Match
        pred_keypoints, confidences = self.match_keypoints(
            src_kp_features, tgt_features_2d, src_keypoints, feature_extractor
        )

        return pred_keypoints, confidences

# Create matcher: plain nearest-neighbour matching (no MNN filter, no ratio test)
matcher = CorrespondenceMatcher(ratio_threshold=None, mutual_nn=False)
print("✓ Correspondence matcher created")
print(f"  Mutual NN: {matcher.mutual_nn}")
print(f"  Ratio test: {matcher.ratio_threshold}")

## Section 6: PCK Evaluator

Implement the PCK (Percentage of Correct Keypoints) evaluation metric.

In [None]:
class PCKEvaluator:
    """
    Evaluate correspondence quality using PCK (Percentage of Correct Keypoints).

    A keypoint is correct if:
        ||predicted - ground_truth|| ≤ α × bbox_diagonal
    (the image diagonal is used when no bounding box is given).

    Standard thresholds: α ∈ {0.05, 0.10, 0.15}

    NOTE(review): SPair-71k results are commonly reported with the threshold
    α × max(bbox_w, bbox_h) rather than the diagonal; confirm which convention
    is intended before comparing against published numbers.
    """

    def __init__(self, alpha_values=(0.05, 0.10, 0.15)):
        """
        Args:
            alpha_values: Iterable of PCK thresholds. A fresh list copy is
                stored, so instances never share (or mutate) a default list.
        """
        self.alpha_values = list(alpha_values)

    def compute_pck(self, pred_kps, gt_kps, bbox=None, img_size=(224, 224)):
        """
        Compute PCK for a single image pair.

        Args:
            pred_kps: Predicted keypoints [N, 2] (array-like or CPU tensor)
            gt_kps: Ground truth keypoints [N, 2]; rows with a negative
                coordinate are treated as missing and ignored
            bbox: Bounding box [4] as [x1, y1, x2, y2] (optional)
            img_size: Image size (H, W) for normalization if no bbox

        Returns:
            pck_dict: Dictionary with PCK@alpha for each threshold (all 0.0
                when there is no valid ground-truth keypoint)
            distances: Normalized distances for each valid keypoint
        """
        pred_kps = np.asarray(pred_kps, dtype=np.float64)
        gt_kps = np.asarray(gt_kps, dtype=np.float64)

        # Filter valid keypoints (ground truth with non-negative coordinates)
        valid_mask = (gt_kps[:, 0] >= 0) & (gt_kps[:, 1] >= 0)

        if not valid_mask.any():
            # No valid keypoints
            return {f'pck@{alpha:.2f}': 0.0 for alpha in self.alpha_values}, np.array([])

        # Euclidean pixel error of each valid keypoint
        distances = np.linalg.norm(pred_kps[valid_mask, :2] - gt_kps[valid_mask, :2], axis=1)

        # Normalization factor: bbox diagonal, else image diagonal
        if bbox is not None:
            bbox = np.asarray(bbox, dtype=np.float64)
            bbox_w = bbox[2] - bbox[0]
            bbox_h = bbox[3] - bbox[1]
            norm_factor = np.sqrt(bbox_w ** 2 + bbox_h ** 2)
        else:
            norm_factor = np.sqrt(img_size[0] ** 2 + img_size[1] ** 2)

        normalized_distances = distances / (norm_factor + 1e-8)

        # Fraction of keypoints within each threshold
        pck_dict = {
            f'pck@{alpha:.2f}': float(np.mean(normalized_distances <= alpha))
            for alpha in self.alpha_values
        }

        return pck_dict, normalized_distances

    def evaluate_dataset(self, predictions, ground_truths, bboxes=None):
        """
        Evaluate PCK over entire dataset.

        Args:
            predictions: List of predicted keypoints [N_samples, N_kps, 2]
            ground_truths: List of ground truth keypoints [N_samples, N_kps, 2]
            bboxes: List of bounding boxes (optional)

        Returns:
            avg_pck: Dictionary with PCK averaged over samples that have at
                least one valid ground-truth keypoint (0.0 if there are none)
            per_sample_pck: List of per-sample PCK dictionaries (all samples)
            all_distances: np.ndarray of normalized distances of every valid
                keypoint across the dataset

        Raises:
            ValueError: if predictions and ground_truths differ in length.
        """
        if len(predictions) != len(ground_truths):
            raise ValueError(
                f"got {len(predictions)} predictions but {len(ground_truths)} ground truths"
            )

        per_sample_pck = []
        scored_pck = []  # samples that actually contained valid keypoints
        all_distances = []

        for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
            bbox = bboxes[i] if bboxes is not None else None
            pck_dict, distances = self.compute_pck(pred, gt, bbox)
            per_sample_pck.append(pck_dict)
            if len(distances) > 0:
                scored_pck.append(pck_dict)
            all_distances.extend(distances.tolist())

        # Average only over scored samples: counting keypoint-less samples as
        # 0% would bias the dataset PCK downwards.
        avg_pck = {}
        for alpha in self.alpha_values:
            key = f'pck@{alpha:.2f}'
            avg_pck[key] = float(np.mean([s[key] for s in scored_pck])) if scored_pck else 0.0

        return avg_pck, per_sample_pck, np.array(all_distances)

# Create evaluator with the standard PCK thresholds
evaluator = PCKEvaluator([0.05, 0.10, 0.15])
print("✓ PCK evaluator created", f"  Alpha values: {evaluator.alpha_values}", sep="\n")

## Section 7: Test on Sample

Let's test the complete pipeline on our sample image pair.

In [None]:
# Match keypoints for the sample
print("Matching keypoints for sample...")

# Query only source keypoints whose x and y are both non-negative
valid_src_mask = (src_kps[:, 0] >= 0) & (src_kps[:, 1] >= 0)
valid_src_kps = src_kps[valid_src_mask]
print(f"Source keypoints: {len(valid_src_kps)}")

# Perform matching
pred_kps, confidences = matcher.match_images(src_img, tgt_img, valid_src_kps, feature_extractor)

print(f"Predicted keypoints: {pred_kps.shape}")
print(
    f"Confidences: min={confidences.min():.3f}, "
    f"max={confidences.max():.3f}, mean={confidences.mean():.3f}"
)

# Evaluate against the target keypoints paired with the queried source ones
valid_tgt_kps = tgt_kps[valid_src_mask]
pck_dict, distances = evaluator.compute_pck(pred_kps, valid_tgt_kps)

print("\nPCK Results:")
for key, value in pck_dict.items():
    print(f"  {key}: {value:.4f} ({value*100:.2f}%)")

print(f"\nDistance statistics:")
print(f"  Min: {np.min(distances):.4f}")
print(f"  Max: {np.max(distances):.4f}")
print(f"  Mean: {np.mean(distances):.4f}")
print(f"  Median: {np.median(distances):.4f}")

In [None]:
# Visualize matches
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
ax_src, ax_tgt, ax_hist = axes

# Panel 1: source image with the queried keypoints
ax_src.imshow(denorm_show(src_img))
ax_src.scatter(valid_src_kps[:, 0], valid_src_kps[:, 1],
               c='red', s=150, marker='x', linewidths=3, label='Source KPs')
ax_src.set_title('Source Image', fontsize=14, fontweight='bold')
ax_src.axis('off')
ax_src.legend()

# Panel 2: target image, ground truth (green circles) vs predictions (red crosses)
ax_tgt.imshow(denorm_show(tgt_img))
ax_tgt.scatter(valid_tgt_kps[:, 0], valid_tgt_kps[:, 1],
               c='lime', s=150, marker='o', alpha=0.6, linewidths=2,
               edgecolors='darkgreen', label='Ground Truth')
ax_tgt.scatter(pred_kps[:, 0], pred_kps[:, 1],
               c='red', s=100, marker='x', linewidths=3, label='Predicted')
ax_tgt.set_title(f'Target Image: Predictions vs Ground Truth\nPCK@0.10: {pck_dict["pck@0.10"]*100:.1f}%',
                 fontsize=14, fontweight='bold')
ax_tgt.axis('off')
ax_tgt.legend()

# Panel 3: histogram of normalized errors with one dashed line per PCK threshold
ax_hist.hist(distances, bins=30, color='steelblue', alpha=0.7, edgecolor='black')
for alpha in evaluator.alpha_values:
    ax_hist.axvline(alpha, color='red', linestyle='--', linewidth=2,
                    label=f'α={alpha:.2f}')
ax_hist.set_xlabel('Normalized Distance', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title('Error Distribution', fontsize=14, fontweight='bold')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'sample_matching_result.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Visualization saved")

## Section 8: Full Dataset Evaluation

Now let's evaluate on the complete test set.

In [None]:
def _valid_keypoint_mask(kps):
    """Return a boolean mask of keypoints whose x and y coordinates are both non-negative."""
    # Negative coordinates mark invalid entries (presumably padding — TODO confirm
    # against the dataset loader).
    return (kps[:, 0] >= 0) & (kps[:, 1] >= 0)


def _save_match_sample_figure(idx, src_img, tgt_img, src_kps, tgt_kps, pred_kps):
    """Save a side-by-side source/target keypoint figure for dataset index ``idx`` to OUTPUT_DIR."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    try:
        axes[0].imshow(denorm_show(src_img))
        axes[0].scatter(src_kps[:, 0], src_kps[:, 1],
                        c='red', s=100, marker='x', linewidths=3)
        axes[0].set_title(f'Source (Sample {idx})', fontsize=12, fontweight='bold')
        axes[0].axis('off')

        axes[1].imshow(denorm_show(tgt_img))
        axes[1].scatter(tgt_kps[:, 0], tgt_kps[:, 1],
                        c='lime', s=100, marker='o', alpha=0.6, linewidths=2,
                        edgecolors='darkgreen', label='GT')
        axes[1].scatter(pred_kps[:, 0], pred_kps[:, 1],
                        c='red', s=80, marker='x', linewidths=2, label='Pred')
        axes[1].set_title(f'Target (Sample {idx})', fontsize=12, fontweight='bold')
        axes[1].axis('off')
        axes[1].legend()

        fig.tight_layout()
        fig.savefig(os.path.join(OUTPUT_DIR, f'match_sample_{idx}.png'),
                    dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure, even if plotting failed, so figures do not
        # accumulate across the evaluation loop.
        plt.close(fig)


def evaluate_on_dataset(dataset, feature_extractor, matcher, evaluator,
                        max_samples=None, save_visualizations=False):
    """
    Evaluate correspondence matching on the leading samples of a dataset.

    For each sample, a keypoint is dropped when either its source or its target
    coordinates contain a negative value. Samples left with no valid keypoints
    are skipped. The remaining source keypoints are matched into the target
    image with ``matcher.match_images``, and all predictions are scored with
    ``evaluator.evaluate_dataset``.

    Args:
        dataset: SPairDataset instance. ``dataset[i]`` must provide the keys
            'src_img', 'trg_img', 'src_kps' and 'trg_kps'.
        feature_extractor: DINOv2FeatureExtractor instance, forwarded to
            ``matcher.match_images``.
        matcher: CorrespondenceMatcher instance.
        evaluator: PCKEvaluator instance. Its ``evaluate_dataset`` must return
            ``(avg_pck, per_sample_pck, all_distances)``.
        max_samples: Maximum number of leading samples to evaluate
            (None = all).
        save_visualizations: If True, save a figure to OUTPUT_DIR for each of
            dataset indices 0-4 that has valid keypoints.

    Returns:
        results: Dictionary with these keys:
            - 'avg_pck': threshold name -> float.
            - 'num_samples': number of samples actually evaluated.
            - 'num_skipped': samples dropped for having no valid keypoints.
            - 'num_keypoints', 'avg_confidence' and 'distance_stats'.

    Raises:
        ValueError: If no sample in the evaluated range has valid keypoints.
    """
    # `is None` so that max_samples=0 is not silently treated as "all".
    if max_samples is None:
        num_samples = len(dataset)
    else:
        num_samples = min(max_samples, len(dataset))
    print(f"Evaluating on {num_samples} samples...")

    all_predictions = []
    all_ground_truths = []
    all_confidences = []
    num_skipped = 0

    for idx in tqdm(range(num_samples), desc="Evaluating"):
        sample = dataset[idx]

        src_img = sample['src_img']
        tgt_img = sample['trg_img']
        src_kps = sample['src_kps']
        tgt_kps = sample['trg_kps']

        # Keep a keypoint only if it is valid on BOTH sides; otherwise an invalid
        # ground-truth target point would be scored as a real correspondence.
        valid_mask = _valid_keypoint_mask(src_kps) & _valid_keypoint_mask(tgt_kps)
        valid_src_kps = src_kps[valid_mask]
        valid_tgt_kps = tgt_kps[valid_mask]

        if len(valid_src_kps) == 0:
            num_skipped += 1
            continue

        pred_kps, confidences = matcher.match_images(
            src_img, tgt_img, valid_src_kps, feature_extractor
        )

        all_predictions.append(pred_kps)
        all_ground_truths.append(valid_tgt_kps)
        all_confidences.append(confidences)

        if save_visualizations and idx < 5:
            _save_match_sample_figure(idx, src_img, tgt_img,
                                      valid_src_kps, valid_tgt_kps, pred_kps)

    # Fail with a clear message instead of an opaque error from np.concatenate([]).
    if not all_predictions:
        raise ValueError(
            f"No samples with valid keypoints among the first {num_samples} "
            "dataset entries; nothing to evaluate."
        )

    print("\nComputing PCK metrics...")
    avg_pck, _per_sample_pck, all_distances = evaluator.evaluate_dataset(
        all_predictions, all_ground_truths
    )

    all_confidences_flat = np.concatenate(all_confidences)

    results = {
        # Cast to builtin float so the results dict is JSON-serializable even if
        # the evaluator returns NumPy/torch scalars.
        'avg_pck': {name: float(score) for name, score in avg_pck.items()},
        'num_samples': num_samples - num_skipped,
        'num_skipped': num_skipped,
        'num_keypoints': len(all_distances),
        'avg_confidence': float(all_confidences_flat.mean()),
        'distance_stats': {
            'mean': float(all_distances.mean()),
            'std': float(all_distances.std()),
            'median': float(np.median(all_distances)),
            'min': float(all_distances.min()),
            'max': float(all_distances.max())
        }
    }

    return results

# Run evaluation on a small subset first, as a quick check before the full run.
N_SUBSET_SAMPLES = 50

print(f"Running evaluation on {N_SUBSET_SAMPLES} samples...")
results = evaluate_on_dataset(
    dataset=dataset,
    feature_extractor=feature_extractor,
    matcher=matcher,
    evaluator=evaluator,
    max_samples=N_SUBSET_SAMPLES,
    save_visualizations=True
)

print("\n" + "=" * 60)
# Report the count returned by the evaluation rather than the requested size,
# which can differ (e.g. when the dataset is smaller than N_SUBSET_SAMPLES).
print(f"EVALUATION RESULTS ({results['num_samples']} samples)")
print("=" * 60)
print("\nPCK Metrics:")
for key, value in results['avg_pck'].items():
    print(f"  {key}: {value:.4f} ({value*100:.2f}%)")

print("\nDataset Statistics:")
print(f"  Samples evaluated: {results['num_samples']}")
print(f"  Total keypoints: {results['num_keypoints']}")
print(f"  Average confidence: {results['avg_confidence']:.4f}")

print("\nDistance Statistics:")
for key, value in results['distance_stats'].items():
    print(f"  {key}: {value:.4f}")

# Save results. default=float converts numeric scalars that json cannot
# serialize natively (e.g. numpy.float32); any other type still raises.
results_path = os.path.join(OUTPUT_DIR, 'evaluation_results_subset.json')
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=float)
print(f"\n✓ Results saved to {results_path}")

## Section 9: Full Evaluation (Optional)

Uncomment and run this cell to evaluate on the complete test set. This will take longer (~30-60 minutes).

In [None]:
# # Full dataset evaluation
# print("Running FULL evaluation on all test samples...")
# print("This will take approximately 30-60 minutes...\n")

# results_full = evaluate_on_dataset(
#     dataset=dataset,
#     feature_extractor=feature_extractor,
#     matcher=matcher,
#     evaluator=evaluator,
#     max_samples=None,  # Use all samples
#     save_visualizations=False  # Don't save all visualizations
# )

# print("\n" + "="*60)
# print("FULL EVALUATION RESULTS")
# print("="*60)
# print(f"\nPCK Metrics:")
# for key, value in results_full['avg_pck'].items():
#     print(f"  {key}: {value:.4f} ({value*100:.2f}%)")

# print(f"\nDataset Statistics:")
# print(f"  Samples evaluated: {results_full['num_samples']}")
# print(f"  Total keypoints: {results_full['num_keypoints']}")
# print(f"  Average confidence: {results_full['avg_confidence']:.4f}")

# # Save results
# results_path = os.path.join(OUTPUT_DIR, 'evaluation_results_full.json')
# with open(results_path, 'w') as f:
#     json.dump(results_full, f, indent=2)
# print(f"\n✓ Results saved to {results_path}")

## Summary

### What we accomplished:

1. **Environment Setup**: Cross-platform compatible setup (Windows/Linux/macOS/Colab)

2. **DINOv2 Feature Extraction**:
   - Model: DINOv2 ViT-B/14
   - Features: 16×16×768 dense feature maps
   - Preprocessing: ImageNet normalization
   - Output: L2-normalized features for cosine similarity

3. **Feature Visualization**:
   - PCA visualization showing semantic structure
   - Patch-to-patch similarity heatmaps
   - Feature quality analysis

4. **Correspondence Matching**:
   - Nearest neighbor matching with bilinear interpolation
   - Optional mutual nearest neighbors
   - Optional ratio test for ambiguous matches

5. **PCK Evaluation**:
   - Standard thresholds: 0.05, 0.10, 0.15
   - Bounding box normalization
   - Per-sample and aggregate metrics

6. **Results** (50 samples):
   - See evaluation results above
   - Visualizations saved to `outputs/dinov2/`

### Key Insights:
- DINOv2 learns rich semantic features without supervision
- Features capture object parts and semantic regions
- Performance varies by object category and pose variation
- Higher similarity (confidence) scores are expected to correlate with better matches; this is worth verifying quantitatively

### Next Steps:
1. **Run full evaluation** (uncomment Section 9)
2. **Try different matching strategies**:
   - Enable mutual nearest neighbors: `matcher = CorrespondenceMatcher(mutual_nn=True)`
   - Enable ratio test: `matcher = CorrespondenceMatcher(ratio_threshold=0.8)`
3. **Compare with other backbones** (DINOv3, SAM)
4. **Analyze per-category performance**
5. **Try larger models** (dinov2_vitl14, dinov2_vitg14)

### Files Generated:
- `outputs/dinov2/sample_images_with_keypoints.png`
- `outputs/dinov2/pca_feature_visualization.png`
- `outputs/dinov2/similarity_heatmap.png`
- `outputs/dinov2/sample_matching_result.png`
- `outputs/dinov2/match_sample_*.png` (first 5 samples)
- `outputs/dinov2/evaluation_results_subset.json`