# Semantic Correspondence Project - Phase 1 Setup
## DINOv2, DINOv3, and SAM Backbones

This notebook sets up the infrastructure for semantic correspondence using:
- **DINOv2** (Facebook Research)
- **DINOv3** (Facebook Research)
- **SAM** (Segment Anything Model)
- **SD4Match** dataset for evaluation

**Professor's recommendations:**
- Use **Base (ViT-B)** versions for all backbones
- Use official repositories (not just Hugging Face) to access internal components
- Dataset splits: train (trn), validation (val), test (test)
- Always evaluate on test split only

## 1. Environment Setup & Dependencies

In [2]:
# Check if running on Google Colab
import sys
import os

IN_COLAB = 'google.colab' in sys.modules
print(f"Running on Colab: {IN_COLAB}")

# Set up paths
if IN_COLAB:
    from google.colab import drive 
    drive.mount('/content/drive')
    PROJECT_ROOT = '/content/AMLProject'
    DATA_ROOT = '/content/drive/MyDrive/AMLProject/data'  # Recommended: upload dataset to Drive
else:
    PROJECT_ROOT = os.getcwd()
    DATA_ROOT = os.path.join(PROJECT_ROOT, 'data')

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs')
MODEL_DIR = os.path.join(PROJECT_ROOT, 'models')

# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(DATA_ROOT, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Data root: {DATA_ROOT}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

Running on Colab: False
Project root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject
Data root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data
Checkpoint dir: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints
Output dir: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/outputs


In [3]:
# Install required packages
# Note: Use standard PyPI for Mac (no CUDA), use --index-url for Linux with CUDA
import platform
import sys

if platform.system() == 'Darwin':  # macOS
    print("üì± Detected macOS - installing CPU/MPS version")
    !pip install torch torchvision
    # torchaudio not needed for this project, skip if unavailable
    try:
        !pip install torchaudio
    except:
        print("‚ö†Ô∏è  torchaudio not available on this platform (not needed for project)")
elif 'google.colab' in sys.modules:  # Google Colab
    print("‚òÅÔ∏è  Detected Colab - using default installation")
    !pip install torch torchvision torchaudio
else:  # Linux with CUDA
    print("üñ•Ô∏è  Detected Linux - installing CUDA version")
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

!pip install opencv-python matplotlib numpy scipy tqdm
!pip install timm einops
!pip install pillow requests

üì± Detected macOS - installing CPU/MPS version
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Using cached sympy-1.13.1-py3-none-any.whl (6.2 MB)
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Using cached sympy-1.13.1-py3-none-any.whl (6.2 MB)
Installing collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.14.0
Installing collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.14.0
    Uninstalling sympy-1.14.0:
    Uninstalling sympy-1.14.0:
      Successfully uninstalled sympy-1.14.0
      Successfully uninstalled sympy-1.14.0
Successfully installed sympy-1.13.1
Successfully installed sympy-1.13.1
Collecting torch==2.9.1 (from torchaudio)
Collecting torch==2.9.1 (from torchaudio)
  Using cached torch-2.9.1-cp311-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting sympy>=1.13.3 (from torch==

In [4]:
# Import common libraries
try:
    import torch
    print(f"‚úì PyTorch version: {torch.__version__}")
except ImportError:
    print("‚úó PyTorch not installed! Please run the installation cell (cell 4) first.")
    raise

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from pathlib import Path
from tqdm import tqdm

# Check device availability (CUDA, MPS, or CPU)
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print(f"Using device: {device} (Apple Silicon GPU)")
else:
    device = torch.device('cpu')
    print(f"Using device: {device} (CPU only)")

‚úì PyTorch version: 2.5.1
Using device: mps (Apple Silicon GPU)
Using device: mps (Apple Silicon GPU)


## 2. Dataset Setup - SD4Match

SD4Match is the dataset for semantic correspondence evaluation.
- **Repository**: https://github.com/ActiveVisionLab/SD4Match
- **Splits**: train (trn), validation (val), test
- **Usage**: Train on trn, validate on val, report final results on test only

In [5]:
# Clone SD4Match repository
sd4match_dir = os.path.join(PROJECT_ROOT, 'SD4Match')
if not os.path.exists(sd4match_dir):
    !git clone https://github.com/ActiveVisionLab/SD4Match.git "{sd4match_dir}"
    print("SD4Match repository cloned successfully")
else:
    print("SD4Match repository already exists")

# Add to Python path
if sd4match_dir not in sys.path:
    sys.path.insert(0, sd4match_dir)
    
print(f"SD4Match path: {sd4match_dir}")

SD4Match repository already exists
SD4Match path: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/SD4Match


In [6]:
# Dataset configuration
"""
After cloning SD4Match, download the dataset and place it in the DATA_ROOT directory.
If on Colab, upload to Google Drive for faster access across sessions.

Expected structure:
DATA_ROOT/
    SD4Match/
        trn/  (training split)
        val/  (validation split)
        test/ (test split)
"""

sd4match_data_dir = os.path.join(DATA_ROOT, 'SD4Match')
print(f"Dataset should be placed in: {sd4match_data_dir}")
print(f"Expected splits: trn/, val/, test/")
print("\nNote: Download instructions are in the SD4Match repository README")

Dataset should be placed in: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match
Expected splits: trn/, val/, test/

Note: Download instructions are in the SD4Match repository README


In [7]:
# Download SD4Match benchmark datasets automatically
import requests
from pathlib import Path
import zipfile
import tarfile
import urllib.request
import shutil

def download_file(url, destination):
    """Download file with progress indication."""
    print(f"  Downloading from {url}")
    try:
        urllib.request.urlretrieve(url, destination)
        return True
    except Exception as e:
        print(f"  ‚úó Error: {e}")
        return False

def extract_zip(zip_path, extract_to):
    """Extract zip file."""
    print(f"  Extracting {os.path.basename(zip_path)}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        return True
    except Exception as e:
        print(f"  ‚úó Error extracting: {e}")
        return False

def download_sd4match_datasets(data_dir):
    """
    Download and extract SD4Match benchmark datasets.
    Includes: PF-Pascal, PF-Willow, and SPair-71k
    """
    print("="*60)
    print("SD4MATCH BENCHMARK DATASETS DOWNLOAD")
    print("="*60)
    
    os.makedirs(data_dir, exist_ok=True)
    
    # Dataset configurations
    datasets = {
        'pf-pascal': {
            'images_url': 'https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip',
            'pairs_url': 'https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip',
            'has_splits': True
        },
        'pf-willow': {
            'images_url': 'https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip',
            'pairs_url': 'https://www.robots.ox.ac.uk/~xinghui/sd4match/test_pairs.csv',
            'has_splits': False
        },
        'spair-71k': {
            'images_url': 'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz',
            'has_splits': True
        }
    }
    
    all_ready = True
    
    for dataset_name, config in datasets.items():
        dataset_path = os.path.join(data_dir, dataset_name)
        print(f"\nüì¶ {dataset_name.upper()}")
        print("-" * 40)
        
        # Check if already exists
        if os.path.exists(dataset_path) and os.listdir(dataset_path):
            print(f"  ‚úì Already exists at {dataset_path}")
            continue
        
        os.makedirs(dataset_path, exist_ok=True)
        
        # Download images
        print(f"  Downloading {dataset_name} images...")
        images_filename = os.path.basename(config['images_url'])
        images_path = os.path.join(data_dir, images_filename)
        
        if not os.path.exists(images_path):
            if download_file(config['images_url'], images_path):
                print(f"  ‚úì Downloaded {images_filename}")
            else:
                print(f"  ‚ö†Ô∏è  Failed to download images")
                all_ready = False
                continue
        
        # Extract images
        if images_filename.endswith('.zip'):
            extract_zip(images_path, dataset_path)
        elif images_filename.endswith('.tar.gz'):
            print(f"  Extracting {images_filename}...")
            with tarfile.open(images_path, 'r:gz') as tar:
                tar.extractall(dataset_path)
        
        # Download pairs/splits if applicable
        if 'pairs_url' in config:
            pairs_filename = os.path.basename(config['pairs_url'])
            pairs_path = os.path.join(data_dir, pairs_filename)
            
            print(f"  Downloading image pairs...")
            if download_file(config['pairs_url'], pairs_path):
                if pairs_filename.endswith('.zip'):
                    extract_zip(pairs_path, dataset_path)
                elif pairs_filename.endswith('.csv'):
                    shutil.move(pairs_path, os.path.join(dataset_path, pairs_filename))
                print(f"  ‚úì Downloaded pairs/splits")
        
        # Clean up zip files
        if os.path.exists(images_path):
            os.remove(images_path)
        
        print(f"  ‚úì {dataset_name} setup complete!")
    
    print("\n" + "="*60)
    
    if all_ready:
        print("‚úÖ All datasets downloaded successfully!")
        print(f"\nDatasets location: {data_dir}")
        print("\nStructure:")
        print(f"{data_dir}/")
        print("  ‚îú‚îÄ‚îÄ pf-pascal/")
        print("  ‚îú‚îÄ‚îÄ pf-willow/")
        print("  ‚îî‚îÄ‚îÄ spair-71k/")
    else:
        print("‚ö†Ô∏è  Some datasets failed to download automatically.")
        print("\nüì• MANUAL DOWNLOAD INSTRUCTIONS:")
        print("-" * 60)
        print("1. PF-Pascal: https://www.di.ens.fr/willow/research/proposalflow/")
        print("   Pairs: https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip")
        print("\n2. PF-Willow: https://www.di.ens.fr/willow/research/proposalflow/")
        print("   Pairs: https://www.robots.ox.ac.uk/~xinghui/sd4match/test_pairs.csv")
        print("\n3. SPair-71k: http://cvlab.postech.ac.kr/research/SPair-71k/")
        print("-" * 60)
        
        if IN_COLAB:
            print("\nüí° FOR GOOGLE COLAB:")
            print("   1. Download datasets to your computer")
            print("   2. Upload to Google Drive")
            print("   3. Mount Drive and set DATA_ROOT accordingly")
    
    return all_ready

# Attempt to download datasets
print("‚è≥ Starting dataset download... This may take several minutes.\n")
dataset_ready = download_sd4match_datasets(sd4match_data_dir)

‚è≥ Starting dataset download... This may take several minutes.

SD4MATCH BENCHMARK DATASETS DOWNLOAD

üì¶ PF-PASCAL
----------------------------------------
  ‚úì Already exists at /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match/pf-pascal

üì¶ PF-WILLOW
----------------------------------------
  ‚úì Already exists at /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match/pf-willow

üì¶ SPAIR-71K
----------------------------------------
  ‚úì Already exists at /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match/spair-71k

‚úÖ All datasets downloaded successfully!

Datasets location: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match

Structure:
/Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match/
  ‚îú‚îÄ‚îÄ pf-pascal/
  ‚îú‚îÄ‚îÄ pf-willow/
  ‚îî‚îÄ‚îÄ spair-71k/


## 3. DINOv2 Backbone Setup

**Repository**: https://github.com/facebookresearch/dinov2  
**Model**: ViT-B (Base version)  
**Key**: Use official repo (not just Hugging Face) to access internal components

In [8]:
# Clone DINOv2 repository
dinov2_dir = os.path.join(MODEL_DIR, 'dinov2')
if not os.path.exists(dinov2_dir):
    !git clone https://github.com/facebookresearch/dinov2.git "{dinov2_dir}"
    print("DINOv2 repository cloned successfully")
else:
    print("DINOv2 repository already exists")

# Add to Python path
if dinov2_dir not in sys.path:
    sys.path.insert(0, dinov2_dir)

print(f"DINOv2 path: {dinov2_dir}")

DINOv2 repository already exists
DINOv2 path: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/models/dinov2


In [9]:
# Load DINOv2 ViT-B model
def load_dinov2_model(model_name='dinov2_vitb14', device='cuda'):
    """
    Load DINOv2 model from official repository.
    
    Available models:
    - dinov2_vits14: Small (ViT-S/14)
    - dinov2_vitb14: Base (ViT-B/14) - RECOMMENDED
    - dinov2_vitl14: Large (ViT-L/14)
    - dinov2_vitg14: Giant (ViT-G/14)
    
    The '14' indicates patch size of 14x14 pixels.
    """
    try:
        model = torch.hub.load('facebookresearch/dinov2', model_name)
        model = model.to(device)
        model.eval()
        print(f"‚úì DINOv2 model '{model_name}' loaded successfully")
        print(f"  - Patch size: 14x14")
        print(f"  - Device: {device}")
        return model
    except Exception as e:
        print(f"‚úó Error loading DINOv2: {e}")
        return None

# Load the Base model (ViT-B)
dinov2_model = load_dinov2_model('dinov2_vitb14', device=device)

Using cache found in /Users/giuliavarga/.cache/torch/hub/facebookresearch_dinov2_main


‚úì DINOv2 model 'dinov2_vitb14' loaded successfully
  - Patch size: 14x14
  - Device: mps


In [10]:
# DINOv2 feature extraction utility
def extract_dinov2_features(model, image, return_class_token=True, return_patch_tokens=True):
    """
    Extract features from DINOv2 model.
    
    Args:
        model: DINOv2 model
        image: PIL Image or tensor (C, H, W) in range [0, 1]
        return_class_token: Return [CLS] token
        return_patch_tokens: Return patch tokens
        
    Returns:
        Dictionary containing requested features
    """
    from torchvision import transforms
    
    # Prepare image
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    if isinstance(image, Image.Image):
        image = transform(image).unsqueeze(0)
    elif image.dim() == 3:
        image = image.unsqueeze(0)
    
    image = image.to(next(model.parameters()).device)
    
    # Extract features
    with torch.no_grad():
        features = model.forward_features(image)
        
    result = {}
    if return_class_token:
        result['cls_token'] = features['x_norm_clstoken']
    if return_patch_tokens:
        result['patch_tokens'] = features['x_norm_patchtokens']
    
    return result

print("DINOv2 feature extraction utility defined")

DINOv2 feature extraction utility defined


## 4. DINOv3 Backbone Setup

**Repository**: https://github.com/facebookresearch/dinov3  
**Model**: ViT-B (Base version)  
**Key**: Request access to checkpoints, then download pretrained weights

In [10]:
# Clone DINOv3 repository
dinov3_dir = os.path.join(MODEL_DIR, 'dinov3')
if not os.path.exists(dinov3_dir):
    !git clone https://github.com/facebookresearch/dinov3.git "{dinov3_dir}"
    print("DINOv3 repository cloned successfully")
else:
    print("DINOv3 repository already exists")

# Add to Python path
if dinov3_dir not in sys.path:
    sys.path.insert(0, dinov3_dir)

print(f"DINOv3 path: {dinov3_dir}")
print("\n‚ö†Ô∏è  IMPORTANT: Request access and download DINOv3 checkpoints")
print("   Follow instructions in the DINOv3 repository README")

DINOv3 repository already exists
DINOv3 path: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/models/dinov3

‚ö†Ô∏è  IMPORTANT: Request access and download DINOv3 checkpoints
   Follow instructions in the DINOv3 repository README


In [11]:
# DINOv3 checkpoint configuration
dinov3_checkpoint_dir = os.path.join(CHECKPOINT_DIR, 'dinov3')
os.makedirs(dinov3_checkpoint_dir, exist_ok=True)

# Expected checkpoint path for ViT-B
dinov3_checkpoint_path = os.path.join(dinov3_checkpoint_dir, 'dinov3_vitb14_pretrain.pth')

print(f"DINOv3 checkpoint directory: {dinov3_checkpoint_dir}")
print(f"Expected checkpoint path: {dinov3_checkpoint_path}")
print("\nAfter obtaining access, download the ViT-B checkpoint to this location")

DINOv3 checkpoint directory: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/dinov3
Expected checkpoint path: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/dinov3/dinov3_vitb14_pretrain.pth

After obtaining access, download the ViT-B checkpoint to this location


In [12]:
# Load DINOv3 model (after checkpoint is downloaded)
def load_dinov3_model(checkpoint_path, device='cuda'):
    """
    Load DINOv3 model from checkpoint.
    
    Args:
        checkpoint_path: Path to the downloaded checkpoint
        device: Device to load model on
        
    Returns:
        Loaded DINOv3 model
    """
    if not os.path.exists(checkpoint_path):
        print(f"‚úó Checkpoint not found: {checkpoint_path}")
        print("  Please download the DINOv3 checkpoint after requesting access")
        return None
    
    try:
        # This will be updated once checkpoint structure is known
        # Placeholder for actual loading code
        print(f"‚úì Loading DINOv3 from: {checkpoint_path}")
        
        # Import DINOv3 modules (adjust based on actual repo structure)
        # from dinov3.models import build_model
        # model = build_model(checkpoint_path)
        # model = model.to(device)
        # model.eval()
        
        print("‚úì DINOv3 model loaded successfully")
        print(f"  - Device: {device}")
        return None  # Will return actual model after implementation
    except Exception as e:
        print(f"‚úó Error loading DINOv3: {e}")
        return None

# Note: Uncomment and run after downloading checkpoint
# dinov3_model = load_dinov3_model(dinov3_checkpoint_path, device=device)
print("DINOv3 loader defined (run after downloading checkpoint)")

DINOv3 loader defined (run after downloading checkpoint)


## 5. SAM (Segment Anything) Backbone Setup

**Repository**: https://github.com/facebookresearch/segment-anything  
**Model**: ViT-B (Base version) - RECOMMENDED  
**Optional**: Can experiment with ViT-L (Large) or ViT-H (Huge) for comparison

In [11]:
# Install SAM
!pip install git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /private/var/folders/kp/dmvkcybs4k72tbdpsb3zxlrh0000gn/T/pip-req-build-gt5zw5tq
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /private/var/folders/kp/dmvkcybs4k72tbdpsb3zxlrh0000gn/T/pip-req-build-gt5zw5tq
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Installing build dependencies ... [?25l  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Installing build dependencies ... [?25l-done
[?25h  Getting requirements to build wheel ... [?25done
[?25h  Getting requirements to build wheel ... [?25l-done
[?25h  Preparing metadata (pyproject.toml) ... [?25done
[?25h  Preparing metadata (pyproject.toml) ... [?25l-done
[?25done
[?25h

### Troubleshooting: torch/torchvision Version Mismatch

If you encounter `RuntimeError: operator torchvision::nms does not exist` when importing SAM, this means your `torch` and `torchvision` versions are mismatched. The compiled C++ operators in torchvision don't match your PyTorch installation.

**Steps to fix:**
1. Run the diagnostic cell below to check versions
2. If mismatch detected, run the fix cell to reinstall compatible versions
3. Restart the kernel
4. Re-run the diagnostic to verify

The fix uses conda to ensure binary compatibility between torch and torchvision.

In [1]:
# Diagnostic: Check torch/torchvision versions
import torch
import torchvision

print("="*60)
print("TORCH/TORCHVISION VERSION CHECK")
print("="*60)
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

print("\n" + "-"*60)
print("Checking torchvision.ops.nms availability...")
try:
    import torchvision.ops as ops
    print(f"‚úì torchvision.ops imported successfully")
    print(f"  has nms attribute: {hasattr(ops, 'nms')}")
    if hasattr(ops, 'nms'):
        print(f"  ‚úì nms operator is available")
    else:
        print(f"  ‚úó nms operator NOT found")
except Exception as e:
    print(f"‚úó Error: {e}")

print("="*60)

  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/anaconda3/envs/aml_project/lib/python3.11/site-packages/torchvision/image.so
  warn(


TORCH/TORCHVISION VERSION CHECK
torch version: 2.5.1
torchvision version: 0.20.1
CUDA available: False

------------------------------------------------------------
Checking torchvision.ops.nms availability...
‚úì torchvision.ops imported successfully
  has nms attribute: True
  ‚úì nms operator is available


In [12]:
# FIX: Reinstall matching torch/torchvision versions
# This uses conda to ensure binary compatibility between torch and torchvision

import platform
import sys

print("="*60)
print("FIXING TORCH/TORCHVISION MISMATCH")
print("="*60)

if platform.system() == 'Darwin':  # macOS
    print("üì± Detected macOS - Reinstalling compatible versions via conda")
    print("\nExecuting: conda install pytorch torchvision -c pytorch -y")
    print("-"*60)
    !conda install pytorch torchvision -c pytorch -y
    
elif 'google.colab' in sys.modules:  # Google Colab
    print("‚òÅÔ∏è Detected Colab - Reinstalling via pip")
    !pip uninstall -y torch torchvision
    !pip install torch torchvision --no-cache-dir
    
else:  # Linux (possibly with CUDA)
    print("üñ•Ô∏è Detected Linux - Reinstalling compatible versions via conda")
    print("\nIf you have CUDA, this will install the CUDA-enabled version.")
    print("Executing: conda install pytorch torchvision pytorch-cuda -c pytorch -c nvidia -y")
    print("-"*60)
    !conda install pytorch torchvision pytorch-cuda -c pytorch -c nvidia -y

print("\n" + "="*60)
print("‚úì Reinstallation complete!")
print("="*60)
print("\n‚ö†Ô∏è IMPORTANT: Restart the kernel to use the new installation!")
print("   In Jupyter/VSCode: Kernel ‚Üí Restart Kernel")
print("\nThen re-run the diagnostic cell above to verify the fix.")

FIXING TORCH/TORCHVISION MISMATCH
üì± Detected macOS - Reinstalling compatible versions via conda

Executing: conda install pytorch torchvision -c pytorch -y
------------------------------------------------------------
[1;32m2[0m[1;32m channel Terms of Service accepted[0m
[1;32m2[0m[1;32m channel Terms of Service accepted[0m
Channels:
 - pytorch
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): - Channels:
 - pytorch
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json)\ done
Solving environment: done
Solving environment\ done
done

## Package Plan ##

  environment location: /opt/anaconda3/envs/aml_project

  added / updated specs:
    - pytorch
    - torchvision


The following NEW packages will be INSTALLED:

  pytorch            pytorch/osx-arm64::pytorch-2.5.1-py3.11_0 



## Package Plan ##

  environment location: /opt/anaconda3/envs/aml_project

  added / updated specs:
    - pytorch
    - torchvision


The following NEW

**‚ö†Ô∏è After running the fix above:**
- **Restart the kernel** (Kernel ‚Üí Restart Kernel)
- Re-run the diagnostic cell to verify the fix worked
- Then proceed to download SAM checkpoints below

In [12]:
# Download SAM checkpoints
import urllib.request
import os

sam_checkpoint_dir = os.path.join(CHECKPOINT_DIR, 'sam')
os.makedirs(sam_checkpoint_dir, exist_ok=True)

# SAM model checkpoints
SAM_MODELS = {
    'vit_b': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
        'filename': 'sam_vit_b_01ec64.pth',
        'description': 'ViT-B (Base) - RECOMMENDED'
    },
    'vit_l': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
        'filename': 'sam_vit_l_0b3195.pth',
        'description': 'ViT-L (Large) - Optional comparison'
    },
    'vit_h': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
        'filename': 'sam_vit_h_4b8939.pth',
        'description': 'ViT-H (Huge) - Optional comparison'
    }
}

def download_sam_checkpoint(model_type='vit_b'):
    """Download SAM checkpoint if not already present."""
    if model_type not in SAM_MODELS:
        print(f"Invalid model type. Choose from: {list(SAM_MODELS.keys())}")
        return None
    
    model_info = SAM_MODELS[model_type]
    checkpoint_path = os.path.join(sam_checkpoint_dir, model_info['filename'])
    
    if os.path.exists(checkpoint_path):
        print(f"‚úì Checkpoint already exists: {checkpoint_path}")
        return checkpoint_path
    
    print(f"Downloading {model_info['description']}...")
    print(f"URL: {model_info['url']}")
    try:
        urllib.request.urlretrieve(model_info['url'], checkpoint_path)
        print(f"‚úì Downloaded successfully: {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        print(f"‚úó Error downloading: {e}")
        return None

# Download ViT-B checkpoint (recommended)
sam_checkpoint_path = download_sam_checkpoint('vit_b')

print(f"\nSAM checkpoint directory: {sam_checkpoint_dir}")
print("Available models:", list(SAM_MODELS.keys()))

‚úì Checkpoint already exists: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam/sam_vit_b_01ec64.pth

SAM checkpoint directory: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam
Available models: ['vit_b', 'vit_l', 'vit_h']


In [13]:
# Load SAM model
try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM_IMPORT_SUCCESS = True
except RuntimeError as e:
    if 'torchvision::nms does not exist' in str(e):
        print("="*60)
        print("‚ö†Ô∏è  SAM IMPORT ERROR: torch/torchvision mismatch detected")
        print("="*60)
        print("Error: operator torchvision::nms does not exist")
        print("\nThis means your torch and torchvision versions are incompatible.")
        print("\nüìã TO FIX:")
        print("   1. Scroll up to find the diagnostic cell (after 'Install SAM')")
        print("   2. Run the fix cell to reinstall matching versions")
        print("   3. Restart the kernel (Kernel ‚Üí Restart Kernel)")
        print("   4. Re-run this cell")
        print("="*60)
    SAM_IMPORT_SUCCESS = False
    sam_model_registry = None
    SamPredictor = None

def load_sam_model(checkpoint_path, model_type='vit_b', device='cuda'):
    """
    Load SAM model.
    
    Args:
        checkpoint_path: Path to checkpoint
        model_type: 'vit_b', 'vit_l', or 'vit_h'
        device: Device to load on
        
    Returns:
        SAM model and predictor
    """
    if not SAM_IMPORT_SUCCESS:
        print("‚úó Cannot load SAM: import failed (see error above)")
        return None, None
        
    if not os.path.exists(checkpoint_path):
        print(f"‚úó Checkpoint not found: {checkpoint_path}")
        return None, None
    
    try:
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam = sam.to(device)
        sam.eval()
        
        # Create predictor for easier inference
        predictor = SamPredictor(sam)
        
        print(f"‚úì SAM model loaded successfully")
        print(f"  - Model type: {model_type}")
        print(f"  - Device: {device}")
        print(f"  - Checkpoint: {checkpoint_path}")
        
        return sam, predictor
    except Exception as e:
        print(f"‚úó Error loading SAM: {e}")
        return None, None

# Load SAM ViT-B
if SAM_IMPORT_SUCCESS and sam_checkpoint_path:
    sam_model, sam_predictor = load_sam_model(sam_checkpoint_path, 'vit_b', device=device)
elif not SAM_IMPORT_SUCCESS:
    print("‚ö†Ô∏è  Skipping SAM model loading due to import error")
    sam_model, sam_predictor = None, None
else:
    print("SAM checkpoint not available yet")
    sam_model, sam_predictor = None, None

  state_dict = torch.load(f)


‚úì SAM model loaded successfully
  - Model type: vit_b
  - Device: mps
  - Checkpoint: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam/sam_vit_b_01ec64.pth


In [14]:
# SAM feature extraction utility
def extract_sam_features(sam_model, image):
    """
    Extract features from SAM image encoder.
    
    Args:
        sam_model: SAM model
        image: PIL Image or numpy array (H, W, 3) in RGB format
        
    Returns:
        Image embeddings from SAM encoder
    """
    import numpy as np
    from segment_anything.utils.transforms import ResizeLongestSide
    
    # Convert PIL to numpy if needed
    if isinstance(image, Image.Image):
        image = np.array(image)
    
    # SAM preprocessing
    transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=sam_model.device)
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    
    # Extract features
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(input_image_torch)
    
    return image_embedding

print("SAM feature extraction utility defined")

SAM feature extraction utility defined


## 6. Utility Functions & Configuration

In [15]:
# Configuration class for the project
class ProjectConfig:
    """Central configuration for the semantic correspondence project."""
    
    def __init__(self):
        # Paths
        self.project_root = PROJECT_ROOT
        self.data_root = DATA_ROOT
        self.checkpoint_dir = CHECKPOINT_DIR
        self.output_dir = OUTPUT_DIR
        self.model_dir = MODEL_DIR
        
        # Dataset
        self.dataset_name = 'SD4Match'
        self.splits = ['trn', 'val', 'test']
        
        # Models
        self.backbones = {
            'dinov2': 'dinov2_vitb14',
            'dinov3': 'dinov3_vitb14',
            'sam': 'vit_b'
        }
        
        # Device
        self.device = device
        
        # Training (to be filled in later phases)
        self.batch_size = 16
        self.num_epochs = 100
        self.learning_rate = 1e-4
        
    def __repr__(self):
        return f"""ProjectConfig:
  Project Root: {self.project_root}
  Data Root: {self.data_root}
  Device: {self.device}
  Dataset: {self.dataset_name}
  Backbones: {list(self.backbones.keys())}
"""

config = ProjectConfig()
print(config)

ProjectConfig:
  Project Root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject
  Data Root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data
  Device: mps
  Dataset: SD4Match
  Backbones: ['dinov2', 'dinov3', 'sam']



In [16]:
# Visualization utilities
def visualize_correspondence(img1, img2, pts1, pts2, matches=None, figsize=(15, 7)):
    """
    Visualize correspondence between two images.
    
    Args:
        img1, img2: Images (PIL or numpy)
        pts1, pts2: Keypoint coordinates [(x, y), ...]
        matches: Optional list of match indices [(idx1, idx2), ...]
        figsize: Figure size
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    
    # Display images
    ax1.imshow(img1)
    ax1.set_title('Image 1')
    ax1.axis('off')
    
    ax2.imshow(img2)
    ax2.set_title('Image 2')
    ax2.axis('off')
    
    # Plot keypoints
    if pts1 is not None and len(pts1) > 0:
        pts1 = np.array(pts1)
        ax1.scatter(pts1[:, 0], pts1[:, 1], c='red', s=50, marker='x')
    
    if pts2 is not None and len(pts2) > 0:
        pts2 = np.array(pts2)
        ax2.scatter(pts2[:, 0], pts2[:, 1], c='red', s=50, marker='x')
    
    plt.tight_layout()
    return fig

def save_model_checkpoint(model, optimizer, epoch, path, **kwargs):
    """Save model checkpoint with metadata."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
        **kwargs
    }
    torch.save(checkpoint, path)
    print(f"‚úì Checkpoint saved: {path}")

def load_model_checkpoint(model, path, optimizer=None, device='cuda'):
    """Load model checkpoint."""
    if not os.path.exists(path):
        print(f"‚úó Checkpoint not found: {path}")
        return None
    
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer and checkpoint.get('optimizer_state_dict'):
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    epoch = checkpoint.get('epoch', 0)
    print(f"‚úì Checkpoint loaded from epoch {epoch}")
    return checkpoint

print("Visualization and checkpoint utilities defined")

Visualization and checkpoint utilities defined


## 7. Model Summary & Testing

Quick tests to verify all models are loaded correctly.

In [17]:
# Summary of loaded models
print("="*60)
print("MODEL SETUP SUMMARY")
print("="*60)

models_status = {
    'DINOv2 (ViT-B)': dinov2_model is not None if 'dinov2_model' in locals() else False,
    'DINOv3 (ViT-B)': False,  # To be loaded after checkpoint download
    'SAM (ViT-B)': (sam_model is not None) if 'sam_model' in locals() else False,
}

for model_name, status in models_status.items():
    status_symbol = "‚úì" if status else "‚ö†"
    status_text = "Loaded" if status else "Not loaded yet"
    print(f"{status_symbol} {model_name}: {status_text}")

print("\n" + "="*60)
print("NEXT STEPS")
print("="*60)
print("1. DINOv3: Request access and download checkpoint")
print("2. SD4Match: Download dataset to", sd4match_data_dir)
print("3. Verify all models work with test images")
print("4. Ready for team to implement correspondence methods")
print("="*60)

MODEL SETUP SUMMARY
‚úì DINOv2 (ViT-B): Loaded
‚ö† DINOv3 (ViT-B): Not loaded yet
‚úì SAM (ViT-B): Loaded

NEXT STEPS
1. DINOv3: Request access and download checkpoint
2. SD4Match: Download dataset to /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match
3. Verify all models work with test images
4. Ready for team to implement correspondence methods


In [18]:
# Test with a dummy image (optional)
def test_model_inference():
    """Quick test to verify models can process images."""
    # Create a dummy image
    dummy_image = Image.new('RGB', (224, 224), color='red')
    
    print("Testing model inference with dummy image...")
    print("-" * 40)
    
    # Test DINOv2
    if 'dinov2_model' in locals() and dinov2_model is not None:
        try:
            features = extract_dinov2_features(dinov2_model, dummy_image)
            print(f"‚úì DINOv2: CLS token shape = {features['cls_token'].shape}")
            print(f"           Patch tokens shape = {features['patch_tokens'].shape}")
        except Exception as e:
            print(f"‚úó DINOv2 error: {e}")
    else:
        print("‚ö† DINOv2: Not loaded")
    
    # Test SAM
    if 'sam_model' in locals() and sam_model is not None:
        try:
            embedding = extract_sam_features(sam_model, dummy_image)
            print(f"‚úì SAM: Embedding shape = {embedding.shape}")
        except Exception as e:
            print(f"‚úó SAM error: {e}")
    else:
        print("‚ö† SAM: Not loaded")
    
    print("-" * 40)
    print("Model inference test complete")

# Uncomment to run test
# test_model_inference()

## 8. Additional Resources & Notes

### Window Soft Argmax (GeoAware-SC)
For prediction refinement in later phases:
- **Repository**: https://github.com/Junyi42/geoaware-sc
- This will be used for refining correspondence predictions

### Professor's Key Recommendations Summary:
1. **Backbone Selection**: Use Base (ViT-B) versions for all three backbones
2. **Model Access**: 
   - DINOv2: Use official repo, not just Hugging Face
   - DINOv3: Request access to checkpoints
   - SAM: ViT-B recommended, can compare with L/H if compute allows
3. **Dataset Splits**:
   - Train on `trn` split
   - Validate on `val` split for model selection
   - **Only report final results on `test` split**
4. **Backbone Size Trade-offs**:
   - Larger backbones (Small ‚Üí Base ‚Üí Large) generally improve performance
   - But gains are not always consistent across tasks
   - Increased size = higher compute/memory/time costs

### For Team Members (Later Phases):
- All infrastructure is ready for implementing correspondence methods
- Models are loaded and ready to extract features
- Utilities for visualization and checkpointing are provided
- Follow the professor's evaluation protocol strictly

## 9. Dataset Loaders

This section defines dataset classes for loading correspondence benchmarks:
- **CorrespondenceDataset**: Base class for all datasets
- **PFPascalDataset**: PF-Pascal dataset with CSV-based annotations
- **SPairDataset**: SPair-71k dataset with JSON-based annotations

Each dataset returns:
- Source and target images
- Source and target keypoints
- Category information
- Bounding boxes (for PCK normalization)

In [None]:
# Dataset base class and utilities
import json
import pandas as pd
from torch.utils.data import Dataset

class CorrespondenceDataset(Dataset):
    """Base class for semantic correspondence datasets."""
    
    def __init__(self, root_dir, split='test', transform=None):
        """
        Args:
            root_dir: Root directory of the dataset
            split: 'trn', 'val', or 'test'
            transform: Optional transforms to apply to images
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform
        self.pairs = []
        
    def __len__(self):
        return len(self.pairs)
    
    def load_image(self, path):
        """Load and optionally transform an image."""
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img
    
    def __getitem__(self, idx):
        raise NotImplementedError("Subclasses must implement __getitem__")


class PFPascalDataset(CorrespondenceDataset):
    """PF-Pascal dataset loader."""
    
    def __init__(self, root_dir, split='test', transform=None):
        super().__init__(root_dir, split, transform)
        self.load_annotations()
    
    def load_annotations(self):
        """Load image pairs and keypoint annotations."""
        anno_file = self.root_dir / 'pf-pascal_image_pairs' / f'{self.split}_pairs.csv'
        
        if not anno_file.exists():
            print(f"‚ö†Ô∏è  Annotation file not found: {anno_file}")
            print("   Make sure you've downloaded the dataset")
            return
        
        # Load pairs
        df = pd.read_csv(anno_file)
        
        for _, row in df.iterrows():
            pair = {
                'source_img': self.root_dir / 'PF-dataset-PASCAL' / row['source_image'],
                'target_img': self.root_dir / 'PF-dataset-PASCAL' / row['target_image'],
                'source_kps': self._parse_keypoints(row['source_keypoints']),
                'target_kps': self._parse_keypoints(row['target_keypoints']),
                'category': row.get('category', 'unknown')
            }
            self.pairs.append(pair)
        
        print(f"‚úì Loaded {len(self.pairs)} pairs from PF-Pascal {self.split} split")
    
    def _parse_keypoints(self, kps_str):
        """Parse keypoint string to numpy array."""
        # Format: "x1,y1;x2,y2;..." or similar
        if pd.isna(kps_str) or kps_str == '':
            return np.array([])
        
        kps = []
        for kp in str(kps_str).split(';'):
            if kp.strip():
                coords = [float(x) for x in kp.split(',')]
                kps.append(coords)
        return np.array(kps)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        source_img = self.load_image(pair['source_img'])
        target_img = self.load_image(pair['target_img'])
        
        return {
            'source_image': source_img,
            'target_image': target_img,
            'source_keypoints': pair['source_kps'],
            'target_keypoints': pair['target_kps'],
            'category': pair['category']
        }


class SPairDataset(CorrespondenceDataset):
    """SPair-71k dataset loader."""
    
    def __init__(self, root_dir, split='test', transform=None):
        super().__init__(root_dir, split, transform)
        self.load_annotations()
    
    def load_annotations(self):
        """Load annotations from SPair-71k."""
        # SPair uses different split names
        split_map = {'trn': 'trn', 'val': 'val', 'test': 'test'}
        split_name = split_map.get(self.split, 'test')
        
        anno_dir = self.root_dir / 'SPair-71k' / 'PairAnnotation' / split_name
        
        if not anno_dir.exists():
            print(f"‚ö†Ô∏è  Annotation directory not found: {anno_dir}")
            return
        
        # Load all annotation files
        for anno_file in sorted(anno_dir.glob('*.json')):
            with open(anno_file, 'r') as f:
                data = json.load(f)
                
            pair = {
                'source_img': self.root_dir / 'SPair-71k' / 'ImageAnnotation' / data['src_imname'],
                'target_img': self.root_dir / 'SPair-71k' / 'ImageAnnotation' / data['trg_imname'],
                'source_kps': np.array(data['src_kps']).T,  # [N, 2]
                'target_kps': np.array(data['trg_kps']).T,  # [N, 2]
                'category': data.get('category', 'unknown'),
                'source_bbox': np.array(data.get('src_bndbox', [])),
                'target_bbox': np.array(data.get('trg_bndbox', []))
            }
            self.pairs.append(pair)
        
        print(f"‚úì Loaded {len(self.pairs)} pairs from SPair-71k {self.split} split")
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        source_img = self.load_image(pair['source_img'])
        target_img = self.load_image(pair['target_img'])
        
        return {
            'source_image': source_img,
            'target_image': target_img,
            'source_keypoints': pair['source_kps'],
            'target_keypoints': pair['target_kps'],
            'category': pair['category'],
            'source_bbox': pair.get('source_bbox'),
            'target_bbox': pair.get('target_bbox')
        }


print("‚úì Dataset classes defined")

ModuleNotFoundError: No module named 'pandas'

## 10. Dense Feature Extraction

The `DenseFeatureExtractor` class extracts spatial feature maps from vision backbones:
- Supports **DINOv2** (ViT-B/14: 16√ó16 patches for 224√ó224 input)
- Supports **SAM** (ViT-B: 64√ó64 features for 1024√ó1024 input)
- Handles coordinate mapping between original image space and feature space
- Extracts features at specific keypoint locations

In [None]:
class DenseFeatureExtractor:
    """Extract dense features from images for correspondence."""
    
    def __init__(self, backbone='dinov2', model=None, device='cuda'):
        """
        Args:
            backbone: 'dinov2' or 'sam'
            model: Pre-loaded model (optional)
            device: Device to run on
        """
        self.backbone = backbone
        self.device = device
        self.model = model
        
        if backbone == 'dinov2':
            self.patch_size = 14
            self.feat_dim = 768  # ViT-B feature dimension
        elif backbone == 'sam':
            self.patch_size = 16  # SAM uses 16x16 patches
            self.feat_dim = 256  # SAM image encoder output
    
    def extract_features(self, image, return_numpy=True):
        """
        Extract dense features from an image.
        
        Args:
            image: PIL Image or tensor
            return_numpy: Return numpy array instead of tensor
            
        Returns:
            features: Dense feature map [H', W', D]
            Original image size for coordinate mapping
        """
        if self.backbone == 'dinov2':
            return self._extract_dinov2(image, return_numpy)
        elif self.backbone == 'sam':
            return self._extract_sam(image, return_numpy)
        else:
            raise ValueError(f"Unknown backbone: {self.backbone}")
    
    def _extract_dinov2(self, image, return_numpy=True):
        """Extract features using DINOv2."""
        from torchvision import transforms
        
        # Get original size
        if isinstance(image, Image.Image):
            orig_size = image.size  # (W, H)
        else:
            orig_size = (image.shape[2], image.shape[1])
        
        # Prepare image (224x224 for DINOv2)
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        if isinstance(image, Image.Image):
            img_tensor = transform(image).unsqueeze(0)
        else:
            img_tensor = image.unsqueeze(0) if image.dim() == 3 else image
        
        img_tensor = img_tensor.to(self.device)
        
        # Extract features
        with torch.no_grad():
            features = self.model.forward_features(img_tensor)
            patch_tokens = features['x_norm_patchtokens']  # [1, N, D]
        
        # Reshape to spatial grid
        # DINOv2 ViT-B/14 produces 16x16 = 256 patches for 224x224 image
        h = w = int(np.sqrt(patch_tokens.shape[1]))
        feature_map = patch_tokens.reshape(1, h, w, -1)[0]  # [H, W, D]
        
        if return_numpy:
            feature_map = feature_map.cpu().numpy()
        
        return {
            'features': feature_map,
            'feature_size': (h, w),
            'original_size': orig_size,
            'processed_size': (224, 224)
        }
    
    def _extract_sam(self, image, return_numpy=True):
        """Extract features using SAM."""
        import numpy as np
        from segment_anything.utils.transforms import ResizeLongestSide
        
        # Get original size
        if isinstance(image, Image.Image):
            image_np = np.array(image)
            orig_size = image.size  # (W, H)
        else:
            image_np = image
            orig_size = (image_np.shape[1], image_np.shape[0])
        
        # SAM preprocessing
        transform = ResizeLongestSide(self.model.image_encoder.img_size)
        input_image = transform.apply_image(image_np)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
        
        # Extract features
        with torch.no_grad():
            image_embedding = self.model.image_encoder(input_image_torch)
        
        # SAM outputs [1, 256, 64, 64] for 1024x1024 input
        feature_map = image_embedding[0].permute(1, 2, 0)  # [H, W, D]
        
        if return_numpy:
            feature_map = feature_map.cpu().numpy()
        
        h, w = image_embedding.shape[2], image_embedding.shape[3]
        
        return {
            'features': feature_map,
            'feature_size': (h, w),
            'original_size': orig_size,
            'processed_size': input_image.shape[:2]
        }
    
    def extract_features_at_keypoints(self, image, keypoints):
        """
        Extract features at specific keypoint locations.
        
        Args:
            image: PIL Image
            keypoints: Keypoint coordinates [N, 2] in original image space
            
        Returns:
            features: Feature vectors at keypoints [N, D]
        """
        feat_dict = self.extract_features(image, return_numpy=False)
        features = feat_dict['features']  # [H, W, D]
        feat_h, feat_w = feat_dict['feature_size']
        orig_w, orig_h = feat_dict['original_size']
        
        # Map keypoints from original space to feature space
        scale_x = feat_w / orig_w
        scale_y = feat_h / orig_h
        
        feat_kps = keypoints.copy()
        feat_kps[:, 0] = feat_kps[:, 0] * scale_x
        feat_kps[:, 1] = feat_kps[:, 1] * scale_y
        
        # Clip to valid range
        feat_kps[:, 0] = np.clip(feat_kps[:, 0], 0, feat_w - 1)
        feat_kps[:, 1] = np.clip(feat_kps[:, 1], 0, feat_h - 1)
        
        # Round to integer indices
        feat_kps = feat_kps.astype(int)
        
        # Extract features
        if isinstance(features, torch.Tensor):
            kp_features = features[feat_kps[:, 1], feat_kps[:, 0], :]
            return kp_features.cpu().numpy()
        else:
            return features[feat_kps[:, 1], feat_kps[:, 0], :]

print("‚úì Dense feature extractor defined")

## 11. Correspondence Matching

The `CorrespondenceMatcher` class finds correspondences between feature maps:
- **Cosine similarity** with L2 normalization
- **Nearest neighbor** matching
- **Mutual nearest neighbor** constraint (optional)
- **Lowe's ratio test** (optional)

Matches source keypoints to target image locations based on feature similarity.

In [None]:
class CorrespondenceMatcher:
    """Match correspondences between two sets of features."""
    
    def __init__(self, method='nn', mutual=False, ratio_test=None):
        """
        Args:
            method: 'nn' (nearest neighbor) or 'mutual_nn'
            mutual: Use mutual nearest neighbor constraint
            ratio_test: Lowe's ratio test threshold (None to disable)
        """
        self.method = method
        self.mutual = mutual
        self.ratio_test = ratio_test
    
    def match(self, features_src, features_tgt):
        """
        Find correspondences between source and target features.
        
        Args:
            features_src: Source features [N, D] or [H, W, D]
            features_tgt: Target features [M, D] or [H', W', D]
            
        Returns:
            matches: Matched indices [(src_idx, tgt_idx), ...]
            scores: Match confidence scores
        """
        # Flatten if spatial
        if features_src.ndim == 3:
            h_src, w_src, d = features_src.shape
            features_src_flat = features_src.reshape(-1, d)
        else:
            features_src_flat = features_src
            h_src = w_src = None
        
        if features_tgt.ndim == 3:
            h_tgt, w_tgt, d = features_tgt.shape
            features_tgt_flat = features_tgt.reshape(-1, d)
        else:
            features_tgt_flat = features_tgt
            h_tgt = w_tgt = None
        
        # Compute distance matrix
        # Using cosine similarity (dot product after L2 normalization)
        features_src_norm = features_src_flat / (np.linalg.norm(features_src_flat, axis=1, keepdims=True) + 1e-8)
        features_tgt_norm = features_tgt_flat / (np.linalg.norm(features_tgt_flat, axis=1, keepdims=True) + 1e-8)
        
        similarity = features_src_norm @ features_tgt_norm.T  # [N, M]
        
        # Nearest neighbor matching
        src_to_tgt = np.argmax(similarity, axis=1)  # [N]
        src_scores = np.max(similarity, axis=1)  # [N]
        
        matches = []
        scores = []
        
        if self.mutual:
            # Mutual nearest neighbors
            tgt_to_src = np.argmax(similarity, axis=0)  # [M]
            
            for src_idx in range(len(features_src_flat)):
                tgt_idx = src_to_tgt[src_idx]
                if tgt_to_src[tgt_idx] == src_idx:  # Mutual match
                    matches.append((src_idx, tgt_idx))
                    scores.append(src_scores[src_idx])
        else:
            # All nearest neighbors
            for src_idx in range(len(features_src_flat)):
                tgt_idx = src_to_tgt[src_idx]
                
                # Optional ratio test
                if self.ratio_test is not None:
                    sorted_sim = np.sort(similarity[src_idx])[::-1]
                    if len(sorted_sim) > 1:
                        ratio = sorted_sim[0] / (sorted_sim[1] + 1e-8)
                        if ratio < self.ratio_test:
                            continue
                
                matches.append((src_idx, tgt_idx))
                scores.append(src_scores[src_idx])
        
        return np.array(matches), np.array(scores)
    
    def match_keypoints(self, src_image, tgt_image, src_kps, feature_extractor):
        """
        Match source keypoints to target image.
        
        Args:
            src_image: Source PIL Image
            tgt_image: Target PIL Image
            src_kps: Source keypoints [N, 2]
            feature_extractor: DenseFeatureExtractor instance
            
        Returns:
            predicted_kps: Predicted target keypoints [N, 2]
            confidence: Match confidence scores [N]
        """
        # Extract dense features
        src_feat_dict = feature_extractor.extract_features(src_image, return_numpy=True)
        tgt_feat_dict = feature_extractor.extract_features(tgt_image, return_numpy=True)
        
        src_features = src_feat_dict['features']  # [H, W, D]
        tgt_features = tgt_feat_dict['features']  # [H', W', D]
        
        # Get source keypoint features
        src_kp_features = feature_extractor.extract_features_at_keypoints(src_image, src_kps)
        
        # Match to target feature map
        tgt_h, tgt_w, tgt_d = tgt_features.shape
        tgt_features_flat = tgt_features.reshape(-1, tgt_d)
        
        # Normalize features
        src_kp_norm = src_kp_features / (np.linalg.norm(src_kp_features, axis=1, keepdims=True) + 1e-8)
        tgt_norm = tgt_features_flat / (np.linalg.norm(tgt_features_flat, axis=1, keepdims=True) + 1e-8)
        
        # Find nearest neighbors
        similarity = src_kp_norm @ tgt_norm.T  # [N, H'*W']
        best_matches = np.argmax(similarity, axis=1)
        confidence = np.max(similarity, axis=1)
        
        # Convert flat indices to 2D coordinates in feature space
        match_y = best_matches // tgt_w
        match_x = best_matches % tgt_w
        
        # Map back to original image coordinates
        orig_w, orig_h = tgt_feat_dict['original_size']
        scale_x = orig_w / tgt_w
        scale_y = orig_h / tgt_h
        
        predicted_kps = np.stack([match_x * scale_x, match_y * scale_y], axis=1)
        
        return predicted_kps, confidence


print("‚úì Correspondence matcher defined")

## 12. Evaluation Metrics (PCK)

The `PCKEvaluator` class computes **Percentage of Correct Keypoints (PCK)**:
- Multiple thresholds: Œ± = [0.05, 0.10, 0.15]
- Normalization by bbox diagonal or image diagonal
- Batch evaluation across entire datasets
- Per-category performance tracking

A keypoint is "correct" if predicted location is within Œ± √ó normalization_distance from ground truth.

In [None]:
class PCKEvaluator:
    """Evaluate correspondence using Percentage of Correct Keypoints (PCK)."""
    
    def __init__(self, alpha_values=[0.05, 0.10, 0.15], use_bbox=True):
        """
        Args:
            alpha_values: Threshold values for PCK@alpha
            use_bbox: Normalize by bounding box size (else use image size)
        """
        self.alpha_values = alpha_values
        self.use_bbox = use_bbox
    
    def compute_pck(self, predicted_kps, gt_kps, image_size=None, bbox=None):
        """
        Compute PCK for a single image pair.
        
        Args:
            predicted_kps: Predicted keypoints [N, 2]
            gt_kps: Ground truth keypoints [N, 2]
            image_size: (width, height) of target image
            bbox: Bounding box [x, y, w, h] for normalization
            
        Returns:
            pck_scores: Dict of PCK@alpha values
        """
        if len(predicted_kps) == 0 or len(gt_kps) == 0:
            return {f'PCK@{alpha}': 0.0 for alpha in self.alpha_values}
        
        # Compute distances
        distances = np.linalg.norm(predicted_kps - gt_kps, axis=1)
        
        # Compute normalization factor
        if self.use_bbox and bbox is not None and len(bbox) == 4:
            # Normalize by bounding box diagonal
            norm_factor = np.sqrt(bbox[2]**2 + bbox[3]**2)
        elif image_size is not None:
            # Normalize by image diagonal
            norm_factor = np.sqrt(image_size[0]**2 + image_size[1]**2)
        else:
            # No normalization
            norm_factor = 1.0
        
        # Compute PCK at different thresholds
        pck_scores = {}
        for alpha in self.alpha_values:
            threshold = alpha * norm_factor
            correct = (distances <= threshold).sum()
            pck = correct / len(distances)
            pck_scores[f'PCK@{alpha}'] = pck
        
        return pck_scores
    
    def evaluate_dataset(self, predictions, ground_truth, image_sizes=None, bboxes=None):
        """
        Evaluate PCK over entire dataset.
        
        Args:
            predictions: List of predicted keypoints arrays
            ground_truth: List of ground truth keypoints arrays
            image_sizes: List of (width, height) tuples
            bboxes: List of bounding boxes
            
        Returns:
            results: Dict with mean PCK and per-sample results
        """
        all_pck_scores = {f'PCK@{alpha}': [] for alpha in self.alpha_values}
        per_sample_results = []
        
        for i, (pred_kps, gt_kps) in enumerate(zip(predictions, ground_truth)):
            img_size = image_sizes[i] if image_sizes else None
            bbox = bboxes[i] if bboxes else None
            
            pck = self.compute_pck(pred_kps, gt_kps, img_size, bbox)
            per_sample_results.append(pck)
            
            for key, value in pck.items():
                all_pck_scores[key].append(value)
        
        # Compute mean PCK
        mean_pck = {key: np.mean(values) for key, values in all_pck_scores.items()}
        
        results = {
            'mean': mean_pck,
            'per_sample': per_sample_results,
            'num_samples': len(predictions)
        }
        
        return results
    
    def print_results(self, results):
        """Pretty print evaluation results."""
        print("="*60)
        print("PCK EVALUATION RESULTS")
        print("="*60)
        print(f"Number of samples: {results['num_samples']}")
        print("\nMean PCK scores:")
        for key, value in sorted(results['mean'].items()):
            print(f"  {key}: {value*100:.2f}%")
        print("="*60)


print("‚úì PCK evaluator defined")

## 13. End-to-End Evaluation Pipeline

The `evaluate_correspondence()` function wraps the entire pipeline:
1. Feature extraction from source and target images
2. Correspondence matching with selected algorithm
3. PCK evaluation at multiple thresholds
4. Progress tracking with tqdm

Returns predictions, ground truth, confidences, and PCK scores.

In [None]:
def evaluate_correspondence(model, dataset, backbone='dinov2', device='cuda', 
                           max_samples=None, mutual_nn=False):
    """
    End-to-end evaluation pipeline for semantic correspondence.
    
    Args:
        model: Pretrained model (DINOv2 or SAM)
        dataset: Dataset instance (PFPascalDataset or SPairDataset)
        backbone: 'dinov2' or 'sam'
        device: Device to run on
        max_samples: Limit number of samples (None for all)
        mutual_nn: Use mutual nearest neighbor matching
        
    Returns:
        results: Evaluation results including PCK scores
    """
    print("="*60)
    print(f"EVALUATING {backbone.upper()} on {dataset.__class__.__name__}")
    print("="*60)
    print(f"Total samples: {len(dataset)}")
    if max_samples:
        print(f"Evaluating on: {max_samples} samples")
    print(f"Mutual NN: {mutual_nn}")
    print("")
    
    # Initialize components
    feature_extractor = DenseFeatureExtractor(backbone=backbone, model=model, device=device)
    matcher = CorrespondenceMatcher(method='nn', mutual=mutual_nn)
    evaluator = PCKEvaluator(alpha_values=[0.05, 0.10, 0.15])
    
    # Collect predictions and ground truth
    predictions = []
    ground_truths = []
    image_sizes = []
    bboxes = []
    confidences = []
    
    num_samples = min(max_samples, len(dataset)) if max_samples else len(dataset)
    
    for i in tqdm(range(num_samples), desc="Processing pairs"):
        sample = dataset[i]
        
        src_img = sample['source_image']
        tgt_img = sample['target_image']
        src_kps = sample['source_keypoints']
        tgt_kps = sample['target_keypoints']
        
        if len(src_kps) == 0 or len(tgt_kps) == 0:
            continue
        
        # Match keypoints
        pred_kps, conf = matcher.match_keypoints(
            src_img, tgt_img, src_kps, feature_extractor
        )
        
        predictions.append(pred_kps)
        ground_truths.append(tgt_kps)
        confidences.append(conf)
        
        # Get image size
        if isinstance(tgt_img, Image.Image):
            image_sizes.append(tgt_img.size)  # (W, H)
        else:
            image_sizes.append((tgt_img.shape[2], tgt_img.shape[1]))
        
        # Get bbox if available
        if 'target_bbox' in sample and sample['target_bbox'] is not None:
            bboxes.append(sample['target_bbox'])
        else:
            bboxes.append(None)
    
    # Evaluate
    results = evaluator.evaluate_dataset(
        predictions, ground_truths, image_sizes, bboxes
    )
    
    # Print results
    evaluator.print_results(results)
    
    # Add additional info
    results['predictions'] = predictions
    results['ground_truth'] = ground_truths
    results['confidences'] = confidences
    results['backbone'] = backbone
    results['mutual_nn'] = mutual_nn
    
    return results


print("‚úì Evaluation pipeline defined")

## 14. Visualization Utilities

Advanced visualization functions for analyzing correspondence results:
- **visualize_matches()**: Shows source/target images with predicted and GT keypoints
- **visualize_feature_similarity()**: Displays feature similarity heatmaps for debugging

Helps understand model behavior and identify failure cases.

In [None]:
def visualize_matches(src_img, tgt_img, src_kps, pred_kps, gt_kps=None, 
                     max_points=20, figsize=(20, 8), save_path=None):
    """
    Visualize correspondence matches between two images.
    
    Args:
        src_img: Source image (PIL or numpy)
        tgt_img: Target image (PIL or numpy)
        src_kps: Source keypoints [N, 2]
        pred_kps: Predicted target keypoints [N, 2]
        gt_kps: Ground truth target keypoints [N, 2] (optional)
        max_points: Maximum number of points to visualize
        figsize: Figure size
        save_path: Path to save figure (optional)
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.lines import Line2D
    
    # Convert to numpy if needed
    if isinstance(src_img, Image.Image):
        src_img = np.array(src_img)
    if isinstance(tgt_img, Image.Image):
        tgt_img = np.array(tgt_img)
    
    # Limit number of points for clarity
    if len(src_kps) > max_points:
        indices = np.random.choice(len(src_kps), max_points, replace=False)
        src_kps = src_kps[indices]
        pred_kps = pred_kps[indices]
        if gt_kps is not None:
            gt_kps = gt_kps[indices]
    
    # Create figure
    if gt_kps is not None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        ax_src, ax_pred, ax_gt = axes
    else:
        fig, axes = plt.subplots(1, 2, figsize=(figsize[0]*2/3, figsize[1]))
        ax_src, ax_pred = axes
        ax_gt = None
    
    # Plot source image with keypoints
    ax_src.imshow(src_img)
    ax_src.scatter(src_kps[:, 0], src_kps[:, 1], c='red', s=100, marker='o', 
                   edgecolors='white', linewidths=2, label='Source KPs')
    ax_src.set_title('Source Image', fontsize=14, fontweight='bold')
    ax_src.axis('off')
    
    # Plot target image with predicted keypoints
    ax_pred.imshow(tgt_img)
    ax_pred.scatter(pred_kps[:, 0], pred_kps[:, 1], c='blue', s=100, marker='x', 
                    linewidths=3, label='Predicted KPs')
    ax_pred.set_title('Target Image (Predictions)', fontsize=14, fontweight='bold')
    ax_pred.axis('off')
    
    # Plot target with ground truth if available
    if gt_kps is not None and ax_gt is not None:
        ax_gt.imshow(tgt_img)
        ax_gt.scatter(gt_kps[:, 0], gt_kps[:, 1], c='green', s=100, marker='o', 
                     edgecolors='white', linewidths=2, label='Ground Truth')
        ax_gt.scatter(pred_kps[:, 0], pred_kps[:, 1], c='blue', s=50, marker='x', 
                     linewidths=2, alpha=0.7, label='Predicted')
        
        # Draw error lines
        for i in range(len(gt_kps)):
            ax_gt.plot([gt_kps[i, 0], pred_kps[i, 0]], 
                      [gt_kps[i, 1], pred_kps[i, 1]], 
                      'r--', alpha=0.3, linewidth=1)
        
        # Compute errors
        errors = np.linalg.norm(pred_kps - gt_kps, axis=1)
        mean_error = errors.mean()
        ax_gt.set_title(f'Ground Truth vs Predicted\nMean Error: {mean_error:.2f}px', 
                       fontsize=14, fontweight='bold')
        ax_gt.axis('off')
        ax_gt.legend(loc='upper right')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"‚úì Saved visualization to {save_path}")
    
    return fig


def visualize_feature_similarity(src_img, tgt_img, feature_extractor, kp_idx=0, src_kps=None):
    """
    Visualize feature similarity map for a keypoint.
    
    Args:
        src_img: Source image
        tgt_img: Target image
        feature_extractor: DenseFeatureExtractor instance
        kp_idx: Index of keypoint to visualize
        src_kps: Source keypoints [N, 2]
    """
    # Extract features
    src_feat_dict = feature_extractor.extract_features(src_img, return_numpy=True)
    tgt_feat_dict = feature_extractor.extract_features(tgt_img, return_numpy=True)
    
    src_features = src_feat_dict['features']
    tgt_features = tgt_feat_dict['features']
    
    # Get query feature
    if src_kps is not None and kp_idx < len(src_kps):
        query_feat = feature_extractor.extract_features_at_keypoints(src_img, src_kps[kp_idx:kp_idx+1])
    else:
        # Use center point
        h, w = src_features.shape[:2]
        query_feat = src_features[h//2, w//2:w//2+1, :]
    
    # Compute similarity map
    query_norm = query_feat / (np.linalg.norm(query_feat) + 1e-8)
    tgt_h, tgt_w, tgt_d = tgt_features.shape
    tgt_flat = tgt_features.reshape(-1, tgt_d)
    tgt_norm = tgt_flat / (np.linalg.norm(tgt_flat, axis=1, keepdims=True) + 1e-8)
    
    similarity = (query_norm @ tgt_norm.T).reshape(tgt_h, tgt_w)
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    axes[0].imshow(src_img)
    if src_kps is not None and kp_idx < len(src_kps):
        axes[0].scatter(src_kps[kp_idx, 0], src_kps[kp_idx, 1], 
                       c='red', s=200, marker='*', edgecolors='white', linewidths=2)
    axes[0].set_title('Source Image (Query Point)', fontsize=12, fontweight='bold')
    axes[0].axis('off')
    
    axes[1].imshow(tgt_img)
    axes[1].set_title('Target Image', fontsize=12, fontweight='bold')
    axes[1].axis('off')
    
    im = axes[2].imshow(similarity, cmap='hot', interpolation='bilinear')
    axes[2].set_title('Feature Similarity Map', fontsize=12, fontweight='bold')
    axes[2].axis('off')
    plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    return fig


print("‚úì Visualization utilities defined")

## 15. Example Usage & Experiments

The following cells demonstrate how to use the pipeline. Uncomment and run to experiment:
- **Example 1**: Load datasets
- **Example 2**: Evaluate on test split
- **Example 3**: Visualize matches
- **Example 4**: Compare different backbones

In [None]:
# Example 1: Load a dataset
# Uncomment and run after downloading datasets

# # Load PF-Pascal test split
# pf_pascal = PFPascalDataset(
#     root_dir=os.path.join(sd4match_data_dir, 'pf-pascal'),
#     split='test'
# )
# print(f"PF-Pascal test set: {len(pf_pascal)} pairs")

# # Load SPair-71k test split
# spair = SPairDataset(
#     root_dir=os.path.join(sd4match_data_dir, 'spair-71k'),
#     split='test'
# )
# print(f"SPair-71k test set: {len(spair)} pairs")

print("‚úì Dataset loading examples defined (uncomment to use)")

In [None]:
# Example 2: Evaluate DINOv2 on a dataset
# Uncomment and run after loading models and datasets

# if dinov2_model is not None:
#     # Evaluate on first 50 samples (for quick testing)
#     results_dinov2 = evaluate_correspondence(
#         model=dinov2_model,
#         dataset=pf_pascal,
#         backbone='dinov2',
#         device=device,
#         max_samples=50,
#         mutual_nn=False
#     )
#     
#     # Save results
#     import json
#     results_path = os.path.join(OUTPUT_DIR, 'dinov2_pfpascal_results.json')
#     with open(results_path, 'w') as f:
#         # Save only serializable parts
#         json.dump({
#             'mean': results_dinov2['mean'],
#             'num_samples': results_dinov2['num_samples'],
#             'backbone': results_dinov2['backbone']
#         }, f, indent=2)
#     print(f"‚úì Results saved to {results_path}")

print("‚úì Evaluation example defined (uncomment to use)")

In [None]:
# Example 3: Visualize a single correspondence
# Uncomment and run to visualize results

# if 'pf_pascal' in locals() and len(pf_pascal) > 0:
#     # Get a sample
#     sample_idx = 0
#     sample = pf_pascal[sample_idx]
#     
#     # Extract predictions
#     feature_extractor = DenseFeatureExtractor(backbone='dinov2', model=dinov2_model, device=device)
#     matcher = CorrespondenceMatcher(method='nn', mutual=False)
#     
#     pred_kps, conf = matcher.match_keypoints(
#         sample['source_image'],
#         sample['target_image'],
#         sample['source_keypoints'],
#         feature_extractor
#     )
#     
#     # Visualize
#     fig = visualize_matches(
#         sample['source_image'],
#         sample['target_image'],
#         sample['source_keypoints'],
#         pred_kps,
#         sample['target_keypoints'],
#         max_points=15,
#         save_path=os.path.join(OUTPUT_DIR, f'match_visualization_{sample_idx}.png')
#     )
#     plt.show()

print("‚úì Visualization example defined (uncomment to use)")

In [None]:
# Example 4: Compare different backbones
# Uncomment to run comparative experiments

# def compare_backbones(dataset, max_samples=100):
#     """Compare DINOv2 vs SAM on a dataset."""
#     results = {}
#     
#     # Evaluate DINOv2
#     if dinov2_model is not None:
#         print("\n" + "="*60)
#         print("EVALUATING DINOV2")
#         print("="*60)
#         results['dinov2'] = evaluate_correspondence(
#             model=dinov2_model,
#             dataset=dataset,
#             backbone='dinov2',
#             device=device,
#             max_samples=max_samples,
#             mutual_nn=False
#         )
#     
#     # Evaluate SAM
#     if sam_model is not None:
#         print("\n" + "="*60)
#         print("EVALUATING SAM")
#         print("="*60)
#         results['sam'] = evaluate_correspondence(
#             model=sam_model,
#             dataset=dataset,
#             backbone='sam',
#             device=device,
#             max_samples=max_samples,
#             mutual_nn=False
#         )
#     
#     # Print comparison
#     print("\n" + "="*60)
#     print("COMPARISON SUMMARY")
#     print("="*60)
#     for backbone, res in results.items():
#         print(f"\n{backbone.upper()}:")
#         for metric, value in res['mean'].items():
#             print(f"  {metric}: {value*100:.2f}%")
#     
#     return results
# 
# # Run comparison
# # comparison_results = compare_backbones(pf_pascal, max_samples=50)

print("‚úì Comparison example defined (uncomment to use)")

## 16. Project Summary & Next Steps

### ‚úÖ Completed Implementation

**Phase 1 - Infrastructure:**
- ‚úì DINOv2 ViT-B model loaded and ready
- ‚úì SAM ViT-B model loaded and ready  
- ‚úì Dataset download utilities for PF-Pascal, PF-Willow, SPair-71k
- ‚úì Environment configuration (paths, device detection)

**Phase 2 - Core Pipeline:**
- ‚úì Dataset loaders (`PFPascalDataset`, `SPairDataset`)
- ‚úì Dense feature extraction (`DenseFeatureExtractor`)
- ‚úì Correspondence matching (`CorrespondenceMatcher`)
  - Nearest neighbor matching
  - Mutual nearest neighbor option
  - Ratio test support
- ‚úì PCK evaluation metrics (`PCKEvaluator`)
  - PCK@0.05, PCK@0.10, PCK@0.15
  - Bbox and image size normalization
- ‚úì End-to-end evaluation pipeline
- ‚úì Visualization utilities

### üéØ How to Use

**Step 1: Ensure all setup cells are run**
```python
# Run cells 1-5 to set up environment
# Run cells for DINOv2 (section 3)
# Run cells for SAM (section 5)
```

**Step 2: Download datasets**
```python
# The dataset download cell (section 2) attempts automatic download
# Or manually download and place in DATA_ROOT/SD4Match/
```

**Step 3: Load a dataset**
```python
pf_pascal = PFPascalDataset(
    root_dir=os.path.join(sd4match_data_dir, 'pf-pascal'),
    split='test'
)
```

**Step 4: Run evaluation**
```python
results = evaluate_correspondence(
    model=dinov2_model,
    dataset=pf_pascal,
    backbone='dinov2',
    device=device,
    max_samples=50  # Start with small number
)
```

**Step 5: Visualize results**
```python
# Use visualization functions to inspect matches
```

### üìä Evaluation Protocol (Professor's Guidelines)

1. **Train on `trn` split** (if doing any training/fine-tuning)
2. **Validate on `val` split** for model selection and hyperparameter tuning
3. **Report final results ONLY on `test` split**
4. **Metrics**: PCK@0.05, PCK@0.10, PCK@0.15
5. **Backbones**: Compare DINOv2 ViT-B vs SAM ViT-B

### üî¨ Suggested Experiments

1. **Baseline Comparison**
   - DINOv2 ViT-B vs SAM ViT-B
   - With/without mutual nearest neighbor

2. **Hyperparameter Tuning** (on val split)
   - Matching thresholds
   - Feature normalization strategies
   - Ratio test thresholds

3. **Dataset Analysis**
   - Per-category performance
   - Effect of viewpoint changes
   - Effect of scale changes

4. **Advanced Methods** (optional)
   - Window soft argmax refinement (GeoAware-SC)
   - Multi-scale features
   - Feature aggregation strategies

### üìù Notes

- All code follows professor's recommendations (Base models, official repos, proper splits)
- The pipeline is modular - easy to swap backbones or add new methods
- Visualization utilities help debug and understand model behavior
- Start with small `max_samples` for quick iteration, then scale up

### üöÄ Ready to Run!

Uncomment the example cells in section 15 to start experiments.