# Embedding Extraction with Vision Transformers

This notebook demonstrates how to extract image embeddings with several Vision Transformer (ViT) models. It relies on helper functions from the `src/vit_embeddings` and `src/data` packages.

## 1. Environment Setup

Configure paths and parameters needed for embedding extraction.

In [None]:
# --- Paths -------------------------------------------------------------
DATASETS_DIR = "datasets"   # root folder that holds every dataset
OUTPUT_DIR = "embeddings"   # where the extracted embedding files are written

# --- Processing --------------------------------------------------------
BATCH_SIZE = 64  # passed as `batch_size` to extract_embeddings_batch

# --- Models to run, in order --------------------------------------------
MODELS = [
    'FRANCA',
    'DINOv2',
    'CLIP',
    'SigLIPv2',
]

# --- Dataset filter -----------------------------------------------------
# None (or an empty list) means every configured dataset is processed.
SELECTED_DATASETS = None  # e.g. ['CUB-200-2011', 'Stanford Cars']

## 2. Module and Dependency Imports

Import the required functions from the `vit_embeddings` and `data` packages, along with other dependencies.

In [None]:
import os
import sys
from pathlib import Path

# Make the repository root (the parent of the notebook's working directory)
# importable so the `src.*` imports below resolve; do nothing if it is
# already on sys.path.
_notebook_dir = Path.cwd()
project_root = str(_notebook_dir.parent)
already_on_path = project_root in sys.path
if not already_on_path:
    sys.path.insert(0, project_root)

# Import functions from vit_embeddings package
from src.vit_embeddings import (
    get_device,
    load_vit_model,
    get_embedding_output,
    check_model_requirements,
    get_model_configs,
    get_model_transforms,
    extract_embeddings_batch,
    save_embeddings
)

# Import dataset loading functions
from src.data import (
    get_image_paths_and_labels,
    load_datasets,
    get_datasets_config
)

## 3. Model and Dataset Preparation

Load and filter the dataset configurations, check each model's requirements, and select the compute device.

In [None]:
# Get the dataset configuration and optionally restrict it to SELECTED_DATASETS
datasets_config = get_datasets_config()
if SELECTED_DATASETS:
    selected = set(SELECTED_DATASETS)
    # Report requested names that match no configured dataset (e.g. a typo)
    # instead of silently dropping them from the run.
    unknown = sorted(selected - {cfg['name'] for cfg in datasets_config})
    if unknown:
        print(f"Warning: SELECTED_DATASETS entries not found in config: {unknown}")
    datasets_config = [cfg for cfg in datasets_config if cfg['name'] in selected]

print("Datasets to process:")
for cfg in datasets_config:
    print(f"- {cfg['name']}")

# Load per-dataset image information for the selected configurations
print("\nLoading dataset information...")
all_datasets_info = load_datasets(DATASETS_DIR, datasets_config)

# Report which models have their requirements satisfied; this is informational
# only, the extraction cell re-checks and skips models that fail.
print("\nChecking model requirements:")
for model_name in MODELS:
    requirements_met, message = check_model_requirements(model_name)
    print(f"{model_name}: {'✓' if requirements_met else '✗'} - {message}")

# Configure device
device = get_device()
print(f"\nDevice to use: {device}")

## 4. Embedding Extraction

For each model, extract and save embeddings for every subset of every dataset.

In [None]:
import gc

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

for model_name in MODELS:
    print(f"\n=== Processing model: {model_name} ===")

    # Skip models whose requirements are not met rather than attempting to load them
    requirements_met, message = check_model_requirements(model_name)
    if not requirements_met:
        print(f"Skipping {model_name}: {message}")
        continue

    # Drop the previously loaded model (if any) before loading the next one.
    # Plain rebinding would keep the old model referenced until the new one
    # has been fully loaded, so both would be resident at the same time.
    model = transform = None
    gc.collect()

    # Load model and transforms.
    # NOTE(review): `model`, `transform` and `device` are never passed to
    # extract_embeddings_batch below, which only receives `model_name`.
    # Either that helper loads the model itself (making this load redundant)
    # or they should be passed through -- confirm against its signature.
    model = load_vit_model(model_name, device)
    transform = get_model_transforms(model_name)

    # Process each dataset
    for dataset_name, subsets in all_datasets_info.items():
        print(f"\nProcessing {dataset_name}...")

        # Process each subset (e.g. train/test/validation)
        for subset_name, images_info in subsets.items():
            if not images_info:
                print(f"Skipping empty subset: {subset_name}")
                continue

            print(f"Extracting embeddings for {subset_name} "
                  f"({len(images_info)} images)")

            # Extract embeddings
            embeddings = extract_embeddings_batch(
                images_info,
                model_name=model_name,
                batch_size=BATCH_SIZE
            )

            # NOTE(review): this truthiness test assumes a list/dict-like
            # result; a NumPy array or DataFrame would raise here -- confirm
            # the return type of extract_embeddings_batch.
            if embeddings:
                # Save embeddings
                save_embeddings(
                    embeddings,
                    OUTPUT_DIR,
                    model_name,
                    dataset_name,
                    subset_name
                )
            else:
                print(f"No embeddings extracted for {subset_name}")

## 5. Results Verification

Check the structure and contents of the extracted embeddings.

In [None]:
import pandas as pd

def print_embeddings_info(output_dir, *, n_meta_cols=2):
    """Print a per-model, per-dataset summary of the embedding files in *output_dir*.

    Walks the layout ``<output_dir>/<model>/<dataset>/*.parquet``. Non-directory
    entries at the model and dataset levels, and non-``.parquet`` files at the
    file level, are ignored. Listings are sorted so the report order is stable
    across runs and platforms (``os.listdir`` order is arbitrary).

    Args:
        output_dir: Root directory the embeddings were written to. If it does
            not exist, a message is printed instead of raising.
        n_meta_cols: Number of non-embedding columns in each parquet file. It is
            subtracted from the column count to report the embedding
            dimensionality. The default of 2 matches the previously hard-coded
            value; presumably an image identifier plus a label -- confirm
            against save_embeddings.
    """
    if not os.path.isdir(output_dir):
        print(f"Output directory not found: {output_dir}")
        return

    total_files = 0

    for model in sorted(os.listdir(output_dir)):
        model_path = os.path.join(output_dir, model)
        if not os.path.isdir(model_path):
            continue
        print(f"\nModel: {model}")

        for dataset in sorted(os.listdir(model_path)):
            dataset_path = os.path.join(model_path, dataset)
            if not os.path.isdir(dataset_path):
                continue
            print(f"\n  Dataset: {dataset}")

            for file in sorted(os.listdir(dataset_path)):
                if not file.endswith('.parquet'):
                    continue
                df = pd.read_parquet(os.path.join(dataset_path, file))
                print(f"    - {file}: {df.shape[0]} images, "
                      f"{df.shape[1] - n_meta_cols} dimensions")
                total_files += 1

    print(f"\nTotal embedding files: {total_files}")

# Summarize the embedding files written under OUTPUT_DIR
print_embeddings_info(output_dir=OUTPUT_DIR)