# Semantic Correspondence Project - Phase 1 Setup
## DINOv2, DINOv3, and SAM Backbones

This notebook sets up the infrastructure for semantic correspondence using:
- **DINOv2** (Facebook Research)
- **DINOv3** (Facebook Research)
- **SAM** (Segment Anything Model)
- **SD4Match** benchmark setup (PF-Pascal, PF-Willow, SPair-71k) for evaluation

**Professor's recommendations:**
- Use **Base (ViT-B)** versions for all backbones
- Use official repositories (not just Hugging Face) to access internal components
- Dataset splits: train (trn), validation (val), test (test)
- Report final results on the test split only (use val for model selection)

## 1. Environment Setup & Dependencies

In [3]:
# Check if running on Google Colab
import sys
import os

IN_COLAB = 'google.colab' in sys.modules
print(f"Running on Colab: {IN_COLAB}")

# Set up paths
if IN_COLAB:
    from google.colab import drive 
    drive.mount('/content/drive')
    PROJECT_ROOT = '/content/AMLProject'
    DATA_ROOT = '/content/drive/MyDrive/AMLProject/data'  # Recommended: upload dataset to Drive
else:
    PROJECT_ROOT = os.getcwd()
    DATA_ROOT = os.path.join(PROJECT_ROOT, 'data')

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs')
MODEL_DIR = os.path.join(PROJECT_ROOT, 'models')

# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(DATA_ROOT, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Data root: {DATA_ROOT}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

Running on Colab: False
Project root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject
Data root: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data
Checkpoint dir: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints
Output dir: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/outputs


In [4]:
# Install required packages
# Note: standard PyPI wheels work on macOS (CPU/MPS); Linux with CUDA uses the PyTorch index URL
import platform

if platform.system() == 'Darwin':  # macOS
    print("📱 Detected macOS - installing CPU/MPS version")
    !pip install torch torchvision
    # torchaudio is not needed for this project. A failing `!pip` does not raise a
    # Python exception, so the fallback message is handled by the shell instead.
    !pip install torchaudio || echo "⚠️  torchaudio not available on this platform (not needed for project)"
elif IN_COLAB:  # Google Colab
    print("☁️  Detected Colab - using default installation")
    !pip install torch torchvision torchaudio
else:  # Linux with CUDA
    print("🖥️  Detected Linux - installing CUDA version")
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

!pip install opencv-python matplotlib numpy scipy tqdm
!pip install timm einops
!pip install pillow requests

📱 Detected macOS - installing CPU/MPS version


In [4]:
# Import common libraries
try:
    import torch
    print(f"✓ PyTorch version: {torch.__version__}")
except ImportError:
    print("✗ PyTorch not installed! Please run the installation cell (cell 4) first.")
    raise

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from pathlib import Path
from tqdm import tqdm

# Check device availability (CUDA, MPS, or CPU)
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print(f"Using device: {device} (Apple Silicon GPU)")
else:
    device = torch.device('cpu')
    print(f"Using device: {device} (CPU only)")

✓ PyTorch version: 2.5.1
Using device: mps (Apple Silicon GPU)


## 2. Dataset Setup - SD4Match

SD4Match provides the semantic correspondence evaluation setup, built on three benchmarks: PF-Pascal, PF-Willow, and SPair-71k.
- **Repository**: https://github.com/ActiveVisionLab/SD4Match
- **Splits**: train (trn), validation (val), test
- **Usage**: Train on trn, validate on val, report final results on test only
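The last bullet implies a strict order: choose the model on `val`, then touch `test` exactly once. A minimal sketch of that protocol follows; `select_then_test`, the `val_scores` dictionary format and the `evaluate_on_test` callback are hypothetical names for illustration, not part of the SD4Match code, and the scores in the usage lines are made up.

```python
def select_then_test(val_scores, evaluate_on_test):
    """Select a checkpoint on `val`, then evaluate only that checkpoint on `test`.

    val_scores: dict {checkpoint_id: validation score}, higher is better (hypothetical format).
    evaluate_on_test: callable(checkpoint_id) -> test score (hypothetical hook).
    """
    if not val_scores:
        raise ValueError("No validation scores to select from")
    # Model selection looks at the validation split only
    best_id = max(val_scores, key=val_scores.get)
    # The test split is used once, for the selected checkpoint
    return best_id, evaluate_on_test(best_id)


# Usage with made-up scores (illustration only, not real results)
calls = []

def fake_test_eval(ckpt_id):
    calls.append(ckpt_id)
    return 0.5

best, test_score = select_then_test(
    {'epoch_5': 0.61, 'epoch_10': 0.66, 'epoch_15': 0.64}, fake_test_eval
)
print(best, calls)
```

Passing the test evaluation as a callback keeps the test split out of the selection loop by construction.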

In [6]:
# Clone SD4Match repository
sd4match_dir = os.path.join(PROJECT_ROOT, 'SD4Match')
if not os.path.exists(sd4match_dir):
    !git clone https://github.com/ActiveVisionLab/SD4Match.git "{sd4match_dir}"
    print("SD4Match repository cloned successfully")
else:
    print("SD4Match repository already exists")

# Add to Python path
if sd4match_dir not in sys.path:
    sys.path.insert(0, sd4match_dir)
    
print(f"SD4Match path: {sd4match_dir}")

Cloning into '/Users/giuliavarga/Desktop/2. AML/Project/AMLProject/SD4Match'...
remote: Enumerating objects: 146, done.
remote: Counting objects: 100% (44/44), done.
remote: Compressing objects: 100% (39/39), done.
remote: Total 146 (delta 15), reused 18 (delta 5), pack-reused 102 (from 1)
Receiving objects: 100% (146/146), 34.71 MiB | 6.89 MiB/s, done.
Resolving deltas: 100% (18/18), done.
SD4Match repository cloned successfully
SD4Match path: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/SD4Match


In [7]:
# Dataset configuration
"""
After cloning SD4Match, download the benchmark datasets into DATA_ROOT/SD4Match.
The download cell below fetches them automatically; on Colab, keep them on
Google Drive so they persist across sessions.

Expected structure (as created by the download cell):
DATA_ROOT/
    SD4Match/
        pf-pascal/  (images + trn/val/test pair lists)
        pf-willow/  (images + test pairs only)
        spair-71k/  (images + trn/val/test splits)
"""

sd4match_data_dir = os.path.join(DATA_ROOT, 'SD4Match')
print(f"Dataset should be placed in: {sd4match_data_dir}")
print("Expected benchmarks: pf-pascal/, pf-willow/, spair-71k/")
print("\nNote: Download instructions are in the SD4Match repository README")

Dataset should be placed in: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/data/SD4Match
Expected benchmarks: pf-pascal/, pf-willow/, spair-71k/

Note: Download instructions are in the SD4Match repository README


In [10]:
# Download SD4Match benchmark datasets automatically
import zipfile
import tarfile
import urllib.request
import shutil

def download_file(url, destination):
    """Download `url` to `destination`; returns True on success.

    A partially written file is removed on failure so that a later run
    does not mistake it for a complete download.
    """
    print(f"  Downloading from {url}")
    try:
        urllib.request.urlretrieve(url, destination)
        return True
    except Exception as e:
        print(f"  ✗ Error: {e}")
        if os.path.exists(destination):
            os.remove(destination)
        return False

def extract_zip(zip_path, extract_to):
    """Extract a zip archive; returns True on success."""
    print(f"  Extracting {os.path.basename(zip_path)}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        return True
    except Exception as e:
        print(f"  ✗ Error extracting: {e}")
        return False

def download_sd4match_datasets(data_dir):
    """
    Download and extract SD4Match benchmark datasets.
    Includes: PF-Pascal, PF-Willow, and SPair-71k
    """
    print("="*60)
    print("SD4MATCH BENCHMARK DATASETS DOWNLOAD")
    print("="*60)
    
    os.makedirs(data_dir, exist_ok=True)
    
    # Dataset configurations
    datasets = {
        'pf-pascal': {
            'images_url': 'https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip',
            'pairs_url': 'https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip',
            'has_splits': True
        },
        'pf-willow': {
            'images_url': 'https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip',
            'pairs_url': 'https://www.robots.ox.ac.uk/~xinghui/sd4match/test_pairs.csv',
            'has_splits': False
        },
        'spair-71k': {
            'images_url': 'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz',
            'has_splits': True
        }
    }
    
    all_ready = True
    
    for dataset_name, config in datasets.items():
        dataset_path = os.path.join(data_dir, dataset_name)
        print(f"\n📦 {dataset_name.upper()}")
        print("-" * 40)
        
        # Check if already exists
        if os.path.exists(dataset_path) and os.listdir(dataset_path):
            print(f"  ✓ Already exists at {dataset_path}")
            continue
        
        os.makedirs(dataset_path, exist_ok=True)
        
        # Download images (an archive left by a previous run is reused)
        print(f"  Downloading {dataset_name} images...")
        images_filename = os.path.basename(config['images_url'])
        images_path = os.path.join(data_dir, images_filename)
        
        if not os.path.exists(images_path):
            if download_file(config['images_url'], images_path):
                print(f"  ✓ Downloaded {images_filename}")
            else:
                print(f"  ⚠️  Failed to download images")
                all_ready = False
                continue
        
        # Extract images
        if images_filename.endswith('.zip'):
            extracted = extract_zip(images_path, dataset_path)
        elif images_filename.endswith('.tar.gz'):
            print(f"  Extracting {images_filename}...")
            try:
                with tarfile.open(images_path, 'r:gz') as tar:
                    tar.extractall(dataset_path)
                extracted = True
            except (tarfile.TarError, OSError) as e:
                print(f"  ✗ Error extracting: {e}")
                extracted = False
        else:
            print(f"  ✗ Unsupported archive format: {images_filename}")
            extracted = False
        
        if not extracted:
            # Remove the partial extraction so the next run does not skip this dataset
            shutil.rmtree(dataset_path, ignore_errors=True)
            all_ready = False
            continue
        
        # Download pairs/splits if applicable
        if 'pairs_url' in config:
            pairs_filename = os.path.basename(config['pairs_url'])
            pairs_path = os.path.join(data_dir, pairs_filename)
            
            print(f"  Downloading image pairs...")
            if download_file(config['pairs_url'], pairs_path):
                if pairs_filename.endswith('.zip'):
                    if not extract_zip(pairs_path, dataset_path):
                        all_ready = False
                    os.remove(pairs_path)
                elif pairs_filename.endswith('.csv'):
                    shutil.move(pairs_path, os.path.join(dataset_path, pairs_filename))
                print(f"  ✓ Downloaded pairs/splits")
            else:
                print(f"  ⚠️  Failed to download pairs/splits")
                all_ready = False
        
        # Clean up the image archive now that extraction succeeded
        os.remove(images_path)
        
        print(f"  ✓ {dataset_name} setup complete!")
    
    print("\n" + "="*60)
    
    if all_ready:
        print("✅ All datasets downloaded successfully!")
        print(f"\nDatasets location: {data_dir}")
        print("\nStructure:")
        print(f"{data_dir}/")
        print("  ├── pf-pascal/")
        print("  ├── pf-willow/")
        print("  └── spair-71k/")
    else:
        print("⚠️  Some datasets failed to download automatically.")
        print("\n📥 MANUAL DOWNLOAD INSTRUCTIONS:")
        print("-" * 60)
        print("1. PF-Pascal: https://www.di.ens.fr/willow/research/proposalflow/")
        print("   Pairs: https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip")
        print("\n2. PF-Willow: https://www.di.ens.fr/willow/research/proposalflow/")
        print("   Pairs: https://www.robots.ox.ac.uk/~xinghui/sd4match/test_pairs.csv")
        print("\n3. SPair-71k: http://cvlab.postech.ac.kr/research/SPair-71k/")
        print("-" * 60)
        
        if IN_COLAB:
            print("\n💡 FOR GOOGLE COLAB:")
            print("   1. Download datasets to your computer")
            print("   2. Upload to Google Drive")
            print("   3. Mount Drive and set DATA_ROOT accordingly")
    
    return all_ready

# Attempt to download datasets
print("⏳ Starting dataset download... This may take several minutes.\n")
dataset_ready = download_sd4match_datasets(sd4match_data_dir)

⏳ Starting dataset download... This may take several minutes.

SD4MATCH BENCHMARK DATASETS DOWNLOAD

📦 PF-PASCAL
----------------------------------------
  Downloading pf-pascal images...
  Downloading from https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip
  ✓ Downloaded PF-dataset-PASCAL.zip
  Extracting PF-dataset-PASCAL.zip...
  Downloading image pairs...
  Downloading from https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip
  Extracting pf-pascal_image_pairs.zip...
  ✓ Downloaded pairs/splits
  ✓ pf-pascal setup complete!

📦 PF-WILLOW
----------------------------------------
  Downloading pf-willow images...
  Downloading from https://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip
  Extracting pf-pascal_im

## 3. DINOv2 Backbone Setup

**Repository**: https://github.com/facebookresearch/dinov2  
**Model**: ViT-B (Base version)  
**Key**: Use official repo (not just Hugging Face) to access internal components

In [None]:
# Clone DINOv2 repository
dinov2_dir = os.path.join(MODEL_DIR, 'dinov2')
if not os.path.exists(dinov2_dir):
    !git clone https://github.com/facebookresearch/dinov2.git "{dinov2_dir}"
    print("DINOv2 repository cloned successfully")
else:
    print("DINOv2 repository already exists")

# Add to Python path
if dinov2_dir not in sys.path:
    sys.path.insert(0, dinov2_dir)

print(f"DINOv2 path: {dinov2_dir}")

In [None]:
# Load DINOv2 ViT-B model
def load_dinov2_model(model_name='dinov2_vitb14', device='cuda'):
    """
    Load DINOv2 model from official repository.
    
    Available models:
    - dinov2_vits14: Small (ViT-S/14)
    - dinov2_vitb14: Base (ViT-B/14) - RECOMMENDED
    - dinov2_vitl14: Large (ViT-L/14)
    - dinov2_vitg14: Giant (ViT-G/14)
    
    The '14' indicates patch size of 14x14 pixels.
    """
    try:
        # Load from the local clone of the official repo (its hubconf.py defines the
        # entrypoints); pretrained weights are downloaded on first use.
        model = torch.hub.load(dinov2_dir, model_name, source='local')
        model = model.to(device)
        model.eval()
        print(f"✓ DINOv2 model '{model_name}' loaded successfully")
        print(f"  - Patch size: 14x14")
        print(f"  - Device: {device}")
        return model
    except Exception as e:
        print(f"✗ Error loading DINOv2: {e}")
        return None

# Load the Base model (ViT-B)
dinov2_model = load_dinov2_model('dinov2_vitb14', device=device)

In [None]:
# DINOv2 feature extraction utility
def extract_dinov2_features(model, image, return_class_token=True, return_patch_tokens=True):
    """
    Extract features from DINOv2 model.
    
    Args:
        model: DINOv2 model
        image: PIL Image or tensor (C, H, W) in range [0, 1]
        return_class_token: Return [CLS] token
        return_patch_tokens: Return patch tokens
        
    Returns:
        Dictionary containing requested features
    """
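    from torchvision import transforms
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
    # Prepare image. CenterCrop discards the image borders, so keypoint
    # coordinates must be mapped into the 224x224 crop before matching.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),  # 224 = 16 x 14, a whole number of patches
        transforms.ToTensor(),
        normalize,
    ])
    
    if isinstance(image, Image.Image):
        image = transform(image.convert('RGB'))
    else:
        # Tensors are expected in [0, 1] with H and W divisible by 14;
        # apply the same ImageNet normalization as the PIL path
        image = normalize(image)
    
    if image.dim() == 3:
        image = image.unsqueeze(0)
    
    image = image.to(next(model.parameters()).device)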
    
    # Extract features
    with torch.no_grad():
        features = model.forward_features(image)
        
    result = {}
    if return_class_token:
        result['cls_token'] = features['x_norm_clstoken']
    if return_patch_tokens:
        result['patch_tokens'] = features['x_norm_patchtokens']
    
    return result

print("DINOv2 feature extraction utility defined")
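Because `extract_dinov2_features` resizes and center-crops, keypoint annotations given in original-image pixels have to be mapped into the crop before they can index `patch_tokens`. The helper below is a sketch under that assumption: `keypoint_to_dinov2_patch` is a hypothetical name, and it mirrors only the size arithmetic of torchvision's `Resize(int)` and `CenterCrop`, not their interpolation.

```python
import math

def keypoint_to_dinov2_patch(x, y, orig_w, orig_h, resize=256, crop=224, patch=14):
    """Map a keypoint (x, y) in original-image pixels to a (row, col) DINOv2 patch index.

    Mirrors the size arithmetic of Resize(resize) + CenterCrop(crop) used in
    extract_dinov2_features, ignoring sub-pixel interpolation conventions.
    Returns None if the keypoint falls outside the center crop.
    """
    # Resize(int) scales the shorter side to `resize` and truncates the longer side
    short, long = min(orig_w, orig_h), max(orig_w, orig_h)
    new_long = int(resize * long / short)
    new_w, new_h = (resize, new_long) if orig_w <= orig_h else (new_long, resize)
    # CenterCrop removes (rounded) equal margins on each side
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    # Keypoint in crop coordinates (multiply before dividing to limit rounding error)
    xc = x * new_w / orig_w - left
    yc = y * new_h / orig_h - top
    if not (0 <= xc < crop and 0 <= yc < crop):
        return None
    # Each patch covers a `patch` x `patch` pixel block of the crop
    return math.floor(yc / patch), math.floor(xc / patch)


print(keypoint_to_dinov2_patch(256, 256, 512, 512))  # (8, 8): center of a square image
print(keypoint_to_dinov2_patch(0, 0, 512, 512))      # None: corner is cropped away
```

Since ViT patch tokens are laid out row-major, the flattened index into `patch_tokens` for the 16x16 grid of a 224x224 crop would be `row * 16 + col`.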

## 4. DINOv3 Backbone Setup

**Repository**: https://github.com/facebookresearch/dinov3  
**Model**: ViT-B (Base version)  
**Key**: Request access to checkpoints, then download pretrained weights

In [None]:
# Clone DINOv3 repository
dinov3_dir = os.path.join(MODEL_DIR, 'dinov3')
if not os.path.exists(dinov3_dir):
    !git clone https://github.com/facebookresearch/dinov3.git "{dinov3_dir}"
    print("DINOv3 repository cloned successfully")
else:
    print("DINOv3 repository already exists")

# Add to Python path
if dinov3_dir not in sys.path:
    sys.path.insert(0, dinov3_dir)

print(f"DINOv3 path: {dinov3_dir}")
print("\n⚠️  IMPORTANT: Request access and download DINOv3 checkpoints")
print("   Follow instructions in the DINOv3 repository README")

In [None]:
# DINOv3 checkpoint configuration
dinov3_checkpoint_dir = os.path.join(CHECKPOINT_DIR, 'dinov3')
os.makedirs(dinov3_checkpoint_dir, exist_ok=True)

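# Expected checkpoint path for ViT-B. DINOv3 ViT models use 16x16 patches (vitb16, not vitb14).
# The downloaded file name may carry extra suffixes; rename the file or adjust this path.
dinov3_checkpoint_path = os.path.join(dinov3_checkpoint_dir, 'dinov3_vitb16_pretrain.pth')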

print(f"DINOv3 checkpoint directory: {dinov3_checkpoint_dir}")
print(f"Expected checkpoint path: {dinov3_checkpoint_path}")
print("\nAfter obtaining access, download the ViT-B checkpoint to this location")

In [None]:
# Load DINOv3 model (after checkpoint is downloaded)
def load_dinov3_model(checkpoint_path, device='cuda', model_name='dinov3_vitb16'):
    """
    Load DINOv3 model from a local checkpoint.
    
    Args:
        checkpoint_path: Path to the downloaded checkpoint
        device: Device to load model on
        model_name: Hub entrypoint in the DINOv3 repo (ViT-B/16 by default)
        
    Returns:
        Loaded DINOv3 model, or None on failure
    """
    if not os.path.exists(checkpoint_path):
        print(f"✗ Checkpoint not found: {checkpoint_path}")
        print("  Please download the DINOv3 checkpoint after requesting access")
        return None
    
    try:
        print(f"Loading DINOv3 from: {checkpoint_path}")
        # Uses hubconf.py from the local clone, following the torch.hub loading
        # pattern in the DINOv3 README; check the README if entrypoint names change.
        model = torch.hub.load(dinov3_dir, model_name, source='local', weights=checkpoint_path)
        model = model.to(device)
        model.eval()
        
        print("✓ DINOv3 model loaded successfully")
        print(f"  - Patch size: 16x16")
        print(f"  - Device: {device}")
        return model
    except Exception as e:
        print(f"✗ Error loading DINOv3: {e}")
        return None

# Note: Uncomment and run after downloading checkpoint
# dinov3_model = load_dinov3_model(dinov3_checkpoint_path, device=device)
print("DINOv3 loader defined (run after downloading checkpoint)")
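DINOv2 uses 14-pixel patches while the DINOv3 ViT-B uses 16-pixel patches, so the same image produces token grids of different sizes, and inputs whose sides are not patch multiples need cropping or resizing. A small sketch of rounding a size down to a patch multiple; `patch_aligned_size` is a hypothetical helper, and whether to crop or resize to the returned size is left as a design choice.

```python
def patch_aligned_size(height, width, patch_size):
    """Largest (height, width) not exceeding the input that is divisible by patch_size,
    together with the resulting token grid (rows, cols)."""
    if min(height, width) < patch_size:
        raise ValueError("image is smaller than a single patch")
    new_h = (height // patch_size) * patch_size
    new_w = (width // patch_size) * patch_size
    return (new_h, new_w), (new_h // patch_size, new_w // patch_size)


for name, p in [('DINOv2 ViT-B/14', 14), ('DINOv3 ViT-B/16', 16)]:
    size, grid = patch_aligned_size(480, 640, p)
    print(f"{name}: input {size}, token grid {grid}")
```

For a 480x640 image this gives a 476x630 input with a 34x45 grid for DINOv2 and the unchanged 480x640 input with a 30x40 grid for DINOv3, so patch-level features from the two backbones cannot be compared index by index.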

## 5. SAM (Segment Anything) Backbone Setup

**Repository**: https://github.com/facebookresearch/segment-anything  
**Model**: ViT-B (Base version) - RECOMMENDED  
**Optional**: Can experiment with ViT-L (Large) or ViT-H (Huge) for comparison

In [11]:
# Install SAM
!pip install git+https://github.com/facebookresearch/segment-anything.git

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /private/var/folders/kp/dmvkcybs4k72tbdpsb3zxlrh0000gn/T/pip-req-build-mpm6h8y8
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /private/var/folders/kp/dmvkcybs4k72tbdpsb3zxlrh0000gn/T/pip-req-build-mpm6h8y8
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Installing build dependencies ... done
  Getting requirements to build wheel ... done

In [12]:
# Download SAM checkpoints
import urllib.request

sam_checkpoint_dir = os.path.join(CHECKPOINT_DIR, 'sam')
os.makedirs(sam_checkpoint_dir, exist_ok=True)

# SAM model checkpoints
SAM_MODELS = {
    'vit_b': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
        'filename': 'sam_vit_b_01ec64.pth',
        'description': 'ViT-B (Base) - RECOMMENDED'
    },
    'vit_l': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
        'filename': 'sam_vit_l_0b3195.pth',
        'description': 'ViT-L (Large) - Optional comparison'
    },
    'vit_h': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
        'filename': 'sam_vit_h_4b8939.pth',
        'description': 'ViT-H (Huge) - Optional comparison'
    }
}

def download_sam_checkpoint(model_type='vit_b'):
    """Download SAM checkpoint if not already present."""
    if model_type not in SAM_MODELS:
        print(f"Invalid model type. Choose from: {list(SAM_MODELS.keys())}")
        return None
    
    model_info = SAM_MODELS[model_type]
    checkpoint_path = os.path.join(sam_checkpoint_dir, model_info['filename'])
    
    if os.path.exists(checkpoint_path):
        print(f"✓ Checkpoint already exists: {checkpoint_path}")
        return checkpoint_path
    
    print(f"Downloading {model_info['description']}...")
    print(f"URL: {model_info['url']}")
    try:
        urllib.request.urlretrieve(model_info['url'], checkpoint_path)
        print(f"✓ Downloaded successfully: {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        print(f"✗ Error downloading: {e}")
        # Remove a partial file so the next call does not treat it as complete
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
        return None

# Download ViT-B checkpoint (recommended)
sam_checkpoint_path = download_sam_checkpoint('vit_b')

print(f"\nSAM checkpoint directory: {sam_checkpoint_dir}")
print("Available models:", list(SAM_MODELS.keys()))

Downloading ViT-B (Base) - RECOMMENDED...
URL: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
✓ Downloaded successfully: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam/sam_vit_b_01ec64.pth

SAM checkpoint directory: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam
Available models: ['vit_b', 'vit_l', 'vit_h']


In [13]:
# Load SAM model
from segment_anything import sam_model_registry, SamPredictor

def load_sam_model(checkpoint_path, model_type='vit_b', device='cuda'):
    """
    Load SAM model.
    
    Args:
        checkpoint_path: Path to checkpoint
        model_type: 'vit_b', 'vit_l', or 'vit_h'
        device: Device to load on
        
    Returns:
        SAM model and predictor
    """
    if not os.path.exists(checkpoint_path):
        print(f"✗ Checkpoint not found: {checkpoint_path}")
        return None, None
    
    try:
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam = sam.to(device)
        sam.eval()
        
        # Create predictor for easier inference
        predictor = SamPredictor(sam)
        
        print(f"✓ SAM model loaded successfully")
        print(f"  - Model type: {model_type}")
        print(f"  - Device: {device}")
        print(f"  - Checkpoint: {checkpoint_path}")
        
        return sam, predictor
    except Exception as e:
        print(f"✗ Error loading SAM: {e}")
        return None, None

# Load SAM ViT-B
if sam_checkpoint_path:
    sam_model, sam_predictor = load_sam_model(sam_checkpoint_path, 'vit_b', device=device)
else:
    print("SAM checkpoint not available yet")



✓ SAM model loaded successfully
  - Model type: vit_b
  - Device: mps
  - Checkpoint: /Users/giuliavarga/Desktop/2. AML/Project/AMLProject/checkpoints/sam/sam_vit_b_01ec64.pth


In [14]:
# SAM feature extraction utility
def extract_sam_features(sam_model, image):
    """
    Extract features from SAM image encoder.
    
    Args:
        sam_model: SAM model
        image: PIL Image or numpy array (H, W, 3) in RGB format
        
    Returns:
        Image embeddings from SAM encoder
    """
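    import numpy as np
    from segment_anything.utils.transforms import ResizeLongestSide
    
    # Convert PIL to numpy if needed (dropping any alpha channel first)
    if isinstance(image, Image.Image):
        image = np.array(image.convert('RGB'))
    
    # SAM preprocessing: resize the longest side to img_size (1024 by default)
    transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=sam_model.device)
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    # Normalize with SAM's pixel mean/std and pad to a square img_size x img_size
    # input, as SamPredictor does before calling the image encoder
    input_image_torch = sam_model.preprocess(input_image_torch)
    
    # Extract features: shape (1, 256, 64, 64) for the default 1024 input
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(input_image_torch)
    
    return image_embedding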

print("SAM feature extraction utility defined")

SAM feature extraction utility defined
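SAM's encoder works on a different grid from DINOv2: `ResizeLongestSide` scales the longest side to 1024, `preprocess` pads only on the bottom and right to a square, and the encoder's 16-pixel patches yield a 64x64 embedding. A sketch for mapping a keypoint to an embedding cell, assuming those defaults; `keypoint_to_sam_cell` is a hypothetical helper that mirrors only the resize arithmetic of `ResizeLongestSide`.

```python
import math

def keypoint_to_sam_cell(x, y, orig_w, orig_h, long_side=1024, stride=16):
    """Map a keypoint (x, y) in original-image pixels to a (row, col) cell of SAM's embedding.

    Mirrors ResizeLongestSide's size computation (longest side -> long_side,
    rounded to the nearest integer). Padding is added on the bottom/right only,
    so it introduces no offset. stride=16 assumes the encoder's 16x16 patches.
    """
    scale = long_side / max(orig_w, orig_h)
    new_w = int(orig_w * scale + 0.5)
    new_h = int(orig_h * scale + 0.5)
    # Keypoint in resized-image coordinates (per-axis scale, as for SAM's coordinates)
    xs = x * new_w / orig_w
    ys = y * new_h / orig_h
    return math.floor(ys / stride), math.floor(xs / stride)


print(keypoint_to_sam_cell(384, 256, orig_w=768, orig_h=512))  # (21, 32)
```

For non-square images the lower or right part of the 64x64 grid covers padding only, so matches landing there should be discarded.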


## 6. Utility Functions & Configuration

In [None]:
# Configuration class for the project
class ProjectConfig:
    """Central configuration for the semantic correspondence project."""
    
    def __init__(self):
        # Paths
        self.project_root = PROJECT_ROOT
        self.data_root = DATA_ROOT
        self.checkpoint_dir = CHECKPOINT_DIR
        self.output_dir = OUTPUT_DIR
        self.model_dir = MODEL_DIR
        
        # Dataset
        self.dataset_name = 'SD4Match'
        self.splits = ['trn', 'val', 'test']
        
        # Models
        self.backbones = {
            'dinov2': 'dinov2_vitb14',
            'dinov3': 'dinov3_vitb16',
            'sam': 'vit_b'
        }
        
        # Device
        self.device = device
        
        # Training (to be filled in later phases)
        self.batch_size = 16
        self.num_epochs = 100
        self.learning_rate = 1e-4
        
    def __repr__(self):
        return f"""ProjectConfig:
  Project Root: {self.project_root}
  Data Root: {self.data_root}
  Device: {self.device}
  Dataset: {self.dataset_name}
  Backbones: {list(self.backbones.keys())}
"""

config = ProjectConfig()
print(config)

In [None]:
# Visualization utilities
def visualize_correspondence(img1, img2, pts1, pts2, matches=None, figsize=(15, 7)):
    """
    Visualize correspondence between two images.
    
    Args:
        img1, img2: Images (PIL or numpy)
        pts1, pts2: Keypoint coordinates [(x, y), ...]
        matches: Optional list of match indices [(idx1, idx2), ...]
        figsize: Figure size
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    
    # Display images
    ax1.imshow(img1)
    ax1.set_title('Image 1')
    ax1.axis('off')
    
    ax2.imshow(img2)
    ax2.set_title('Image 2')
    ax2.axis('off')
    
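    # Plot keypoints
    if pts1 is not None and len(pts1) > 0:
        pts1 = np.array(pts1)
        ax1.scatter(pts1[:, 0], pts1[:, 1], c='red', s=50, marker='x')
    
    if pts2 is not None and len(pts2) > 0:
        pts2 = np.array(pts2)
        ax2.scatter(pts2[:, 0], pts2[:, 1], c='red', s=50, marker='x')
    
    # Connect matched keypoints across the two axes
    if matches is not None:
        from matplotlib.patches import ConnectionPatch
        for idx1, idx2 in matches:
            con = ConnectionPatch(xyA=tuple(pts1[idx1]), xyB=tuple(pts2[idx2]),
                                  coordsA='data', coordsB='data',
                                  axesA=ax1, axesB=ax2, color='lime', linewidth=1)
            fig.add_artist(con)
    
    plt.tight_layout()
    return fig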

def save_model_checkpoint(model, optimizer, epoch, path, **kwargs):
    """Save model checkpoint with metadata."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
        **kwargs
    }
    torch.save(checkpoint, path)
    print(f"✓ Checkpoint saved: {path}")

def load_model_checkpoint(model, path, optimizer=None, device='cuda'):
    """Load model checkpoint."""
    if not os.path.exists(path):
        print(f"✗ Checkpoint not found: {path}")
        return None
    
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None and checkpoint.get('optimizer_state_dict') is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    epoch = checkpoint.get('epoch', 0)
    print(f"✓ Checkpoint loaded from epoch {epoch}")
    return checkpoint

print("Visualization and checkpoint utilities defined")

## 7. Model Summary & Testing

Quick tests to verify all models are loaded correctly.

In [None]:
# Summary of loaded models
print("="*60)
print("MODEL SETUP SUMMARY")
print("="*60)

models_status = {
    'DINOv2 (ViT-B)': globals().get('dinov2_model') is not None,
    'DINOv3 (ViT-B)': globals().get('dinov3_model') is not None,  # loaded after checkpoint download
    'SAM (ViT-B)': globals().get('sam_model') is not None,
}

for model_name, status in models_status.items():
    status_symbol = "✓" if status else "⚠"
    status_text = "Loaded" if status else "Not loaded yet"
    print(f"{status_symbol} {model_name}: {status_text}")

print("\n" + "="*60)
print("NEXT STEPS")
print("="*60)
print("1. DINOv3: Request access and download checkpoint")
print("2. SD4Match: Download dataset to", sd4match_data_dir)
print("3. Verify all models work with test images")
print("4. Ready for team to implement correspondence methods")
print("="*60)

In [None]:
# Test with a dummy image (optional)
def test_model_inference():
    """Quick test to verify models can process images."""
    # Create a dummy image
    dummy_image = Image.new('RGB', (224, 224), color='red')
    
    print("Testing model inference with dummy image...")
    print("-" * 40)
    
    # Test DINOv2 (the models live in the notebook's global namespace;
    # locals() inside this function would never contain them)
    dinov2 = globals().get('dinov2_model')
    if dinov2 is not None:
        try:
            features = extract_dinov2_features(dinov2, dummy_image)
            print(f"✓ DINOv2: CLS token shape = {features['cls_token'].shape}")
            print(f"           Patch tokens shape = {features['patch_tokens'].shape}")
        except Exception as e:
            print(f"✗ DINOv2 error: {e}")
    else:
        print("⚠ DINOv2: Not loaded")
    
    # Test SAM
    sam = globals().get('sam_model')
    if sam is not None:
        try:
            embedding = extract_sam_features(sam, dummy_image)
            print(f"✓ SAM: Embedding shape = {embedding.shape}")
        except Exception as e:
            print(f"✗ SAM error: {e}")
    else:
        print("⚠ SAM: Not loaded")
    
    print("-" * 40)
    print("Model inference test complete")

# Uncomment to run test
# test_model_inference()

## 8. Additional Resources & Notes

### Window Soft Argmax (GeoAware-SC)
For prediction refinement in later phases:
- **Repository**: https://github.com/Junyi42/geoaware-sc
- This will be used for refining correspondence predictions
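Window soft-argmax refines a match below the patch resolution: instead of keeping only the hard argmax of a similarity map, it averages coordinates with softmax weights restricted to a small window around the peak. The NumPy sketch below illustrates the idea only; it is not the GeoAware-SC code, and `window_soft_argmax`, `radius` and `temperature` are assumed names and values to tune.

```python
import numpy as np

def window_soft_argmax(sim, radius=2, temperature=0.05):
    """Refine the argmax of a 2D similarity map with a local soft-argmax.

    1. Find the integer argmax (row, col).
    2. Take the (2*radius+1)^2 window around it, clipped at the borders.
    3. Softmax the window's similarities divided by `temperature` and return the
       probability-weighted mean (row, col), which can fall between grid cells.
    """
    sim = np.asarray(sim, dtype=np.float64)
    r0, c0 = np.unravel_index(np.argmax(sim), sim.shape)
    r_lo, r_hi = max(r0 - radius, 0), min(r0 + radius + 1, sim.shape[0])
    c_lo, c_hi = max(c0 - radius, 0), min(c0 + radius + 1, sim.shape[1])
    window = sim[r_lo:r_hi, c_lo:c_hi] / temperature
    weights = np.exp(window - window.max())  # subtract the max for numerical stability
    weights /= weights.sum()
    rows, cols = np.mgrid[r_lo:r_hi, c_lo:c_hi]
    return float((weights * rows).sum()), float((weights * cols).sum())


sim = np.zeros((5, 5))
sim[2, 2] = sim[2, 3] = 1.0  # two equally strong neighbouring peaks
print(window_soft_argmax(sim))  # row 2, column just below 2.5
```

Restricting the softmax to a window keeps a distant second peak (for example, a symmetric part of the object) from pulling the estimate away from the best match.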

### Professor's Key Recommendations Summary:
1. **Backbone Selection**: Use Base (ViT-B) versions for all three backbones
2. **Model Access**: 
   - DINOv2: Use official repo, not just Hugging Face
   - DINOv3: Request access to checkpoints
   - SAM: ViT-B recommended, can compare with L/H if compute allows
3. **Dataset Splits**:
   - Train on `trn` split
   - Validate on `val` split for model selection
   - **Only report final results on `test` split**
4. **Backbone Size Trade-offs**:
   - Larger backbones (Small → Base → Large) generally improve performance
   - But gains are not always consistent across tasks
   - Increased size = higher compute/memory/time costs

### For Team Members (Later Phases):
- The infrastructure for implementing correspondence methods is in place
- DINOv2 and SAM are loaded and ready to extract features; DINOv3 follows once its checkpoint is downloaded
- Utilities for visualization and checkpointing are provided
- Follow the professor's evaluation protocol strictly
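A common starting point for the later phases is to match patch features by cosine similarity, pairing each source patch with its most similar target patch. A minimal NumPy sketch, assuming the features have already been moved to the CPU and flattened to `(N, D)` arrays (for DINOv2, `patch_tokens[0].cpu().numpy()`; for SAM, the `(1, 256, 64, 64)` embedding with channels moved last and reshaped to `(4096, 256)`). `nearest_neighbor_matches` is a hypothetical helper, not an API of any of the repositories.

```python
import numpy as np

def nearest_neighbor_matches(src_feats, tgt_feats, eps=1e-8):
    """Match each source patch to its most similar target patch by cosine similarity.

    src_feats: (N, D) array of source patch descriptors.
    tgt_feats: (M, D) array of target patch descriptors.
    Returns (indices, similarity): indices[i] is the best target patch for source
    patch i, and similarity is the full (N, M) cosine-similarity matrix.
    """
    src = np.asarray(src_feats, dtype=np.float64)
    tgt = np.asarray(tgt_feats, dtype=np.float64)
    # L2-normalize rows so the dot product equals cosine similarity
    src = src / np.maximum(np.linalg.norm(src, axis=1, keepdims=True), eps)
    tgt = tgt / np.maximum(np.linalg.norm(tgt, axis=1, keepdims=True), eps)
    similarity = src @ tgt.T
    return similarity.argmax(axis=1), similarity


# Toy example: the target rows are a permutation of the source rows
src = np.eye(3)
tgt = np.eye(3)[[2, 0, 1]]
indices, similarity = nearest_neighbor_matches(src, tgt)
print(indices)  # [1 2 0]
```

Each row of `similarity` can be reshaped to the target patch grid to obtain a per-keypoint similarity map for later refinement.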