# Compute Average Embeddings for CLIP and DINOv3

This notebook computes category average embeddings from individual CLIP and DINOv3 embeddings.

## Overview

This step:
1. Loads individual CLIP embeddings from files
2. Groups embeddings by category
3. Computes category average embeddings for CLIP
4. Loads individual DINOv3 embeddings from files
5. Groups embeddings by category
6. Computes category average embeddings for DINOv3
7. Saves the average embeddings for later RDM computation

## Prerequisites

Make sure you have the following packages installed:
- `numpy`
- `pandas`
- `tqdm`

## Setup and Imports

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
import warnings
warnings.filterwarnings('ignore')

print("All imports successful!")

## Configuration

**Please update the paths below according to your setup:**

In [None]:
# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS FOR YOUR SETUP
# ============================================================================
# These module-level constants are read by the CLIP and DINOv3 cells below.

# Paths for CLIP embeddings
CLIP_EMBEDDING_LIST = "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_new/clip_image_embeddings_doc_normalized_filtered-by-clip-0.26.txt"  # Path to text file with CLIP embedding paths (one per line), or None to scan directory
CLIP_EMBEDDINGS_DIR = "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_new"  # Base directory for CLIP embeddings (relative list entries are resolved against it)
CLIP_OUTPUT_DIR = "./clip_rdm_results_26"  # Directory where CLIP results will be saved (created if missing)

# Paths for DINOv3 embeddings
# NOTE(review): DINOV3_EMBEDDING_LIST currently points at the same file as
# CLIP_EMBEDDING_LIST. It is only read when MATCH_FROM_CLIP_LIST is False (or
# CLIP_EMBEDDING_LIST_REF is empty); if that list holds absolute CLIP paths,
# the DINOv3 step would load CLIP files - confirm before disabling matching.
DINOV3_EMBEDDING_LIST = "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_new/clip_image_embeddings_doc_normalized_filtered-by-clip-0.26.txt"  # Path to text file with DINOv3 embedding paths, or None to scan directory
DINOV3_EMBEDDINGS_DIR = "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m"  # Base directory for DINOv3 embeddings
DINOV3_OUTPUT_DIR = "./dinov3_rdm_results_26"  # Directory where DINOv3 results will be saved (created if missing)

# Option to match from CLIP list (ensures same images for DINOv3)
CLIP_EMBEDDING_LIST_REF = "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_new/clip_image_embeddings_doc_normalized_filtered-by-clip-0.26.txt"  # Reference CLIP list for matching
MATCH_FROM_CLIP_LIST = True  # If True, keep only DINOv3 files whose (category, filename) appears in the reference CLIP list; takes priority over DINOV3_EMBEDDING_LIST

# Processing options
NUM_WORKERS = 24  # Number of parallel worker threads (None = auto-detect as min(16, CPU count); an explicit value is not capped)
USE_PARALLEL = True  # Enable parallel loading (only takes effect when more than 100 files are to be loaded)

print("Configuration loaded. Please review and update paths as needed.")

## Helper Functions

In [None]:
def load_embedding_paths(txt_path):
    """Read embedding file paths from a text file, one path per line.

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are skipped.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        List of path strings in file order.
    """
    print(f"Loading embedding paths from {txt_path}...")

    entries = []
    with open(txt_path, 'r') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)

    print(f"Found {len(entries)} embedding paths")
    return entries

def scan_embedding_directory(embeddings_dir):
    """Recursively collect every ``.npy`` file below a directory.

    Args:
        embeddings_dir: Directory to search.

    Returns:
        Sorted list of file paths (as strings) relative to ``embeddings_dir``.

    Raises:
        ValueError: If no ``.npy`` file is found.
    """
    root = Path(embeddings_dir)
    print(f"Scanning {root} for .npy files...")

    found = list(root.rglob("*.npy"))
    if not found:
        raise ValueError(f"No .npy files found in {root}")

    # Report paths relative to the root so they can be re-joined with it later.
    relative = sorted(str(item.relative_to(root)) for item in found)

    print(f"Found {len(relative)} embedding files")
    return relative

def match_embedding_paths_from_list(reference_list_path, target_embeddings_dir):
    """Select the files of a target directory that also appear in a reference list.

    Files are matched on the pair ``(category, filename)``. For a reference
    entry the category is the name of the file's parent directory; for a file
    in the target directory it is the first path component below
    ``target_embeddings_dir``. Reference entries without a parent directory
    and target files that sit directly in the target root are ignored.

    NOTE(review): for target trees nested more than one level deep, the first
    component used here can differ from the immediate parent directory that
    the loader uses as category - confirm the target layout is
    ``<category>/<file>.npy``.

    Args:
        reference_list_path: Text file with one reference embedding path per
            line; blank lines are skipped.
        target_embeddings_dir: Directory that is searched recursively for
            ``.npy`` files.

    Returns:
        Sorted list of matched target paths (strings) relative to
        ``target_embeddings_dir``. Empty when nothing matches, when the
        reference list has no usable entries, or when the target directory
        does not exist.
    """
    target_embeddings_dir = Path(target_embeddings_dir)

    print(f"Loading reference embedding list from {reference_list_path}...")
    with open(reference_list_path, 'r') as f:
        reference_paths = [line.strip() for line in f if line.strip()]

    print(f"Found {len(reference_paths)} reference paths")

    # Unique (category, filename) keys requested by the reference list.
    reference_keys = set()
    for ref_path in reference_paths:
        ref_path_obj = Path(ref_path)
        if len(ref_path_obj.parts) >= 2:
            reference_keys.add((ref_path_obj.parts[-2], ref_path_obj.name))

    # Index of the target directory under the same kind of key. If several
    # target files share a key, the one visited last wins.
    target_files = {}
    if target_embeddings_dir.exists():
        for npy_file in target_embeddings_dir.rglob("*.npy"):
            rel_path = npy_file.relative_to(target_embeddings_dir)
            if len(rel_path.parts) >= 2:
                target_files[(rel_path.parts[0], rel_path.name)] = str(rel_path)
    else:
        print(f"Warning: target directory {target_embeddings_dir} does not exist")

    matched_paths = sorted(
        target_files[key] for key in reference_keys if key in target_files
    )

    # Guard the percentage: an empty or unusable reference list would
    # otherwise divide by zero.
    if reference_keys:
        matched_pct = len(matched_paths) / len(reference_keys) * 100
        print(f"Matched {len(matched_paths)} files ({matched_pct:.1f}%)")
    else:
        print("Matched 0 files (reference list contains no usable paths)")

    return matched_paths

def load_single_embedding(args):
    """Load one embedding file; worker function for parallel loading.

    Args:
        args: Tuple ``(path, embeddings_dir, is_absolute)``. When
            ``is_absolute`` is true ``path`` is used as-is, otherwise it is
            joined onto ``embeddings_dir``.

    Returns:
        Tuple ``(category, embedding)`` where ``category`` is the name of the
        file's parent directory and ``embedding`` is the loaded array,
        flattened to one dimension if it had more than one. Returns
        ``(None, None)`` if the file does not exist, has no parent component,
        or any error occurs while loading.
    """
    path, embeddings_dir, is_absolute = args
    not_loaded = (None, None)

    try:
        location = Path(path) if is_absolute else Path(embeddings_dir) / path
        if not location.exists():
            return not_loaded

        segments = location.parts
        if len(segments) < 2:
            return not_loaded

        vector = np.load(location)
        if vector.ndim > 1:
            vector = vector.flatten()

        # The directory that directly contains the file names the category.
        return segments[-2], vector
    except Exception:
        # Best-effort worker: any failure is reported as "not loaded".
        return not_loaded

def _store_loaded(results, embeddings_by_category):
    """Append loaded embeddings to their category and return (loaded, failed) counts."""
    loaded = 0
    failed = 0
    for category, embedding in results:
        if category is not None and embedding is not None:
            embeddings_by_category[category].append(embedding)
            loaded += 1
        else:
            failed += 1
    return loaded, failed


def load_embeddings_by_category(embedding_paths, embeddings_dir, num_workers=None, use_parallel=True):
    """Load embedding files and group the loaded vectors by category.

    Every path is passed to ``load_single_embedding``, which returns the
    category and the loaded vector, or ``(None, None)`` when the file could not
    be loaded. Files that fail are skipped; their number is counted and
    reported in both the parallel and the sequential mode.

    Args:
        embedding_paths: Embedding file paths. Absolute paths are used as-is,
            relative paths are resolved against ``embeddings_dir``.
        embeddings_dir: Base directory for relative paths.
        num_workers: Number of worker threads for parallel loading. ``None``
            selects ``min(16, CPU count)``.
        use_parallel: If True and more than 100 paths are given, load with a
            thread pool; otherwise load the files one at a time with a
            progress bar.

    Returns:
        A ``defaultdict(list)`` mapping each category name to the list of
        embeddings loaded for it.
    """
    print("Loading embeddings by category...")
    embeddings_by_category = defaultdict(list)

    base_dir = str(Path(embeddings_dir))

    if num_workers is None:
        num_workers = min(16, mp.cpu_count())

    path_args = [
        (path, base_dir, Path(path).is_absolute())
        for path in embedding_paths
    ]

    successful = 0
    failed = 0

    if use_parallel and len(path_args) > 100:
        print(f"Using {num_workers} parallel workers...")
        print(f"Processing {len(path_args):,} embedding files...")

        # Work is submitted chunk by chunk so progress can be reported between chunks.
        chunk_size = max(5000, num_workers * 200)
        total_chunks = (len(path_args) + chunk_size - 1) // chunk_size

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            chunk_starts = range(0, len(path_args), chunk_size)
            for chunk_num, start in enumerate(chunk_starts, start=1):
                chunk = path_args[start:start + chunk_size]

                loaded, not_loaded = _store_loaded(
                    executor.map(load_single_embedding, chunk),
                    embeddings_by_category,
                )
                successful += loaded
                failed += not_loaded

                if chunk_num % 50 == 0 or chunk_num == total_chunks:
                    progress = (chunk_num / total_chunks) * 100
                    print(f"Progress: {progress:.1f}% ({chunk_num}/{total_chunks} chunks, {successful:,} loaded)")
    else:
        print("Using sequential loading...")
        successful, failed = _store_loaded(
            (load_single_embedding(args) for args in tqdm(path_args, desc="Loading embeddings")),
            embeddings_by_category,
        )

    # Report failures for both modes so missing or unreadable files never go unnoticed.
    if failed > 0:
        print(f"\nWarning: Failed to load {failed:,} embeddings")
    print(f"Successfully loaded {successful:,} embeddings")

    print(f"\nLoaded embeddings for {len(embeddings_by_category)} categories:")
    for category, emb_list in sorted(embeddings_by_category.items()):
        print(f"  {category}: {len(emb_list)} embeddings")

    return embeddings_by_category

def compute_category_averages(embeddings_by_category):
    """Average the embeddings of each category.

    Categories are processed in sorted name order; categories with no
    embeddings are skipped.

    Args:
        embeddings_by_category: Mapping from category name to a list of
            embeddings.

    Returns:
        Tuple ``(category_averages, categories)``: a dict mapping category name
        to the mean taken over that category's embeddings (axis 0), and the
        list of category names in the same sorted order.
    """
    print("\nComputing category average embeddings...")
    category_averages = {}

    for name in sorted(embeddings_by_category):
        members = embeddings_by_category[name]
        count = len(members)
        if count == 0:
            continue

        mean_vector = np.array(members).mean(axis=0)
        category_averages[name] = mean_vector

        print(f"  {name}: {count} embeddings -> shape {mean_vector.shape}")

    # Dict insertion order equals the sorted processing order.
    categories = list(category_averages)
    return category_averages, categories

print("Helper functions loaded!")

## Compute CLIP Average Embeddings

In [None]:
print("="*60)
print("COMPUTING CLIP AVERAGE EMBEDDINGS")
print("="*60)

# Decide which embedding files to load: an explicit list file takes priority,
# otherwise every .npy file below the CLIP directory is used.
if CLIP_EMBEDDING_LIST:
    embedding_paths = load_embedding_paths(CLIP_EMBEDDING_LIST)
else:
    embedding_paths = scan_embedding_directory(CLIP_EMBEDDINGS_DIR)

# Load embeddings by category
embeddings_by_category = load_embeddings_by_category(
    embedding_paths,
    CLIP_EMBEDDINGS_DIR,
    num_workers=NUM_WORKERS,
    use_parallel=USE_PARALLEL
)

# Compute category averages
category_averages, categories = compute_category_averages(embeddings_by_category)

# Stop before writing anything when no embedding could be loaded; otherwise
# empty result files would be written and the summary below would fail with
# an IndexError on embeddings.shape[1].
if not categories:
    raise ValueError(
        f"No CLIP embeddings could be loaded from {CLIP_EMBEDDINGS_DIR}; "
        "check CLIP_EMBEDDING_LIST and CLIP_EMBEDDINGS_DIR."
    )

# Save category averages
output_dir = Path(CLIP_OUTPUT_DIR)
output_dir.mkdir(exist_ok=True, parents=True)

print("\nSaving category average embeddings...")
# Row i of `embeddings` is the average for categories[i].
embeddings = np.array([category_averages[cat] for cat in categories])
npz_path = output_dir / 'category_average_embeddings.npz'
np.savez(npz_path,
         embeddings=embeddings,
         categories=np.array(categories))
print(f"  Saved to {npz_path}")

# Also save as CSV for easier inspection (one row per category, indexed by name)
csv_path = output_dir / 'category_average_embeddings.csv'
embeddings_df = pd.DataFrame(embeddings, index=categories)
embeddings_df.to_csv(csv_path)
print(f"  Saved to {csv_path}")

# Save category info: totals plus the number of embeddings averaged per category
info_path = output_dir / 'category_average_info.txt'
with open(info_path, 'w') as f:
    f.write(f"Total categories: {len(categories)}\n")
    f.write(f"Embedding dimension: {embeddings.shape[1]}\n")
    f.write("\nCategories:\n")
    for cat in categories:
        f.write(f"  {cat}: {len(embeddings_by_category[cat])} embeddings\n")
print(f"  Saved to {info_path}")

# Save category names, one per line, in the same order as the embedding rows
names_path = output_dir / 'category_names.txt'
with open(names_path, 'w') as f:
    for cat in categories:
        f.write(f"{cat}\n")
print(f"  Saved to {names_path}")

print(f"\nCLIP average embeddings computation complete! Results saved to {output_dir}")
print(f"Total categories: {len(categories)}")
print(f"Embedding shape: {embeddings.shape}")

## Compute DINOv3 Average Embeddings

In [None]:
print("="*60)
print("COMPUTING DINOv3 AVERAGE EMBEDDINGS")
print("="*60)

# Decide which embedding files to load. Matching against the reference CLIP
# list has priority (same images for both models), then an explicit DINOv3
# list file, and finally a scan of the whole DINOv3 directory.
if MATCH_FROM_CLIP_LIST and CLIP_EMBEDDING_LIST_REF:
    embedding_paths = match_embedding_paths_from_list(CLIP_EMBEDDING_LIST_REF, DINOV3_EMBEDDINGS_DIR)
elif DINOV3_EMBEDDING_LIST:
    embedding_paths = load_embedding_paths(DINOV3_EMBEDDING_LIST)
else:
    embedding_paths = scan_embedding_directory(DINOV3_EMBEDDINGS_DIR)

# Load embeddings by category
embeddings_by_category = load_embeddings_by_category(
    embedding_paths,
    DINOV3_EMBEDDINGS_DIR,
    num_workers=NUM_WORKERS,
    use_parallel=USE_PARALLEL
)

# Compute category averages
category_averages, categories = compute_category_averages(embeddings_by_category)

# Stop before writing anything when no embedding could be loaded; otherwise
# empty result files would be written and the summary below would fail with
# an IndexError on embeddings.shape[1].
if not categories:
    raise ValueError(
        f"No DINOv3 embeddings could be loaded from {DINOV3_EMBEDDINGS_DIR}; "
        "check DINOV3_EMBEDDINGS_DIR, DINOV3_EMBEDDING_LIST and the "
        "MATCH_FROM_CLIP_LIST / CLIP_EMBEDDING_LIST_REF settings."
    )

# Save category averages
output_dir = Path(DINOV3_OUTPUT_DIR)
output_dir.mkdir(exist_ok=True, parents=True)

print("\nSaving category average embeddings...")
# Row i of `embeddings` is the average for categories[i].
embeddings = np.array([category_averages[cat] for cat in categories])
npz_path = output_dir / 'category_average_embeddings.npz'
np.savez(npz_path,
         embeddings=embeddings,
         categories=np.array(categories))
print(f"  Saved to {npz_path}")

# Also save as CSV for easier inspection (one row per category, indexed by name)
csv_path = output_dir / 'category_average_embeddings.csv'
embeddings_df = pd.DataFrame(embeddings, index=categories)
embeddings_df.to_csv(csv_path)
print(f"  Saved to {csv_path}")

# Save category info: totals plus the number of embeddings averaged per category
info_path = output_dir / 'category_average_info.txt'
with open(info_path, 'w') as f:
    f.write(f"Total categories: {len(categories)}\n")
    f.write(f"Embedding dimension: {embeddings.shape[1]}\n")
    f.write("\nCategories:\n")
    for cat in categories:
        f.write(f"  {cat}: {len(embeddings_by_category[cat])} embeddings\n")
print(f"  Saved to {info_path}")

# Save category names, one per line, in the same order as the embedding rows
names_path = output_dir / 'category_names.txt'
with open(names_path, 'w') as f:
    for cat in categories:
        f.write(f"{cat}\n")
print(f"  Saved to {names_path}")

print(f"\nDINOv3 average embeddings computation complete! Results saved to {output_dir}")
print(f"Total categories: {len(categories)}")
print(f"Embedding shape: {embeddings.shape}")