# Forensic Similarity Search Notebook
This notebook demonstrates a face similarity search using DINOv2 embeddings, OpenCV Haar-cascade face detection, and cosine similarity, organized into modular cells.

## Usage Instructions
Below is the original usage and description extracted from the script:
  
#!/usr/bin/env python3  
"""
forensic_similarity_search.py

Usage:
  python forensic_similarity_search.py \
      --ref_dir  /evidence/suspects/john_doe_faces \
      --gallery  /evidence/phone_dump/DCIM \
      --out_dir  ./matches \
      --threshold 0.25 \
      --top_k 10
"""

In [None]:
# Cell 1: Configuration and Setup
import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics.pairwise import cosine_similarity
import cv2
from PIL import Image
import torch
import torch.nn.functional as F

# Notebook-wide settings; later cells read (and Cell 2 may update) this dict.
CONFIG = dict(
    # Model Selection
    MODEL_NAME='dinov2_vitb14',  # Options: dinov2_vits14, dinov2_vitb14, dinov2_vitl14
    DEVICE='auto',  # 'auto', 'cpu', 'cuda'

    # Paths
    SUSPECTS_GALLERY_PATH='../../datasets/images/face/gallery',  # Input folder with suspect images
    QUERY_IMAGES_PATH='../../datasets/images/face/reference_images',  # Query faces to search for
    RESULTS_OUTPUT_PATH='../../datasets/images/face/matched_images/face_recognition_results',  # Output folder

    # Recognition parameters
    SIMILARITY_THRESHOLD=0.7,    # Cosine similarity threshold for matches
    TOP_K_MATCHES=5,             # Number of top matches to return
    FIGURE_SIZE=(15, 10),        # Size of result visualization
    FACE_SIZE=(224, 224),        # Size for DINOv2 input
)

# Catalogue of selectable DINOv2 variants, keyed by their torch-hub entry name.
# Each row is (hub key, display name, description, embedding dimension).
AVAILABLE_MODELS = {
    hub_key: {'name': display_name, 'description': summary, 'embed_dim': width}
    for hub_key, display_name, summary, width in (
        ('dinov2_vits14', 'DINOv2-ViT-S/14', 'Small model - fastest processing', 384),
        ('dinov2_vitb14', 'DINOv2-ViT-B/14', 'Base model - good balance (recommended)', 768),
        ('dinov2_vitl14', 'DINOv2-ViT-L/14', 'Large model - highest accuracy', 1024),
    )
}

# Echo the effective configuration so the notebook output records it.
print("✅ Configuration loaded successfully")
print(f"📁 Suspects gallery: {CONFIG['SUSPECTS_GALLERY_PATH']}")
print(f"📁 Query images: {CONFIG['QUERY_IMAGES_PATH']}")
print(f"📁 Results output: {CONFIG['RESULTS_OUTPUT_PATH']}")
print(f"🤖 Selected model: {CONFIG['MODEL_NAME']}")

# Cell 2: Install Dependencies and Load Models
# Install required packages
# Ensure PyTorch and torchvision are importable. `transforms` is first bound
# here and is used later by preprocess_face_for_dinov2.
try:
    import torch
    import torchvision.transforms as transforms
    print("✅ PyTorch already installed")
except ImportError:
    # Fallback for a fresh environment. The `!pip` lines are IPython shell
    # escapes, so this branch is only valid inside a notebook/IPython session.
    # NOTE(review): Cell 1 already imports torch, cv2 and sklearn without a
    # guard, so a missing package would fail there before reaching this branch.
    print("⚠️ Installing PyTorch...")
    !pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
    !pip install opencv-python
    !pip install scikit-learn
    # Retry the imports now that the packages have been installed.
    import torch
    import torchvision.transforms as transforms

# Pick the compute device: CUDA when torch reports it as available, else CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
if _cuda_available:
    print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("🖥️ Using CPU")

# Only the 'auto' placeholder is replaced; an explicit 'cpu'/'cuda' setting in
# CONFIG is left exactly as configured.
if CONFIG['DEVICE'] == 'auto':
    CONFIG['DEVICE'] = device

def load_dinov2_model(model_name=None):
    """Fetch a DINOv2 backbone from torch hub and prepare it for inference.

    Args:
        model_name: Key of AVAILABLE_MODELS / torch-hub entry point. When
            None, CONFIG['MODEL_NAME'] is used.

    Returns:
        (model, embed_dim) on success, where the model has been moved to
        CONFIG['DEVICE'] and put in eval mode; (None, None) if anything
        fails (the error is printed, not raised).
    """
    try:
        chosen = CONFIG['MODEL_NAME'] if model_name is None else model_name

        print(f"📥 Loading DINOv2 model: {chosen}")

        # Download/resolve via torch hub, then move to the configured device
        net = torch.hub.load('facebookresearch/dinov2', chosen)
        net = net.to(CONFIG['DEVICE'])
        net.eval()

        # Embedding width comes from the local catalogue, not from the model
        model_info = AVAILABLE_MODELS[chosen]
        embed_dim = model_info['embed_dim']

        print(f"✅ Model loaded successfully!")
        print(f"   📊 Model: {model_info['name']}")
        print(f"   🖥️ Device: {CONFIG['DEVICE']}")
        print(f"   📐 Embedding dimension: {embed_dim}")

        return net, embed_dim

    except Exception as err:
        print(f"❌ Error loading model: {err}")
        return None, None

# Load the default model once when this cell runs. load_dinov2_model returns
# (None, None) on failure, so both names may be None from here on.
model, embed_dim = load_dinov2_model()

# Cell 3: Face Detection and Preprocessing
def setup_directories():
    """Ensure the gallery, query and results folders from CONFIG exist.

    Each path is created with os.makedirs(exist_ok=True) and reported on
    stdout, whether or not it already existed.
    """
    path_keys = ('SUSPECTS_GALLERY_PATH', 'QUERY_IMAGES_PATH', 'RESULTS_OUTPUT_PATH')
    # Resolve every path up front so a missing key fails before any mkdir
    targets = [CONFIG[key] for key in path_keys]
    for target in targets:
        os.makedirs(target, exist_ok=True)
        print(f"📁 Created directory: {target}")

# Haar cascade shared by all detect_faces_opencv calls; built on first use so
# the XML model is parsed once instead of once per image.
_FACE_CASCADE = None


def detect_faces_opencv(image_path):
    """Detect faces in an image file using OpenCV's frontal-face Haar cascade.

    Args:
        image_path: Path (str or Path) of the image to scan.

    Returns:
        A list with one dict per detected face:
          - 'crop': BGR pixel array of the face region plus 20px padding,
            clipped to the image borders.
          - 'bbox': (x1, y1, x2, y2) of the padded crop, as Python ints.
          - 'original_bbox': (x, y, w, h) as reported by the detector.
        An empty list is returned when the image cannot be read, no face is
        found, or any error occurs (errors are printed, not raised).
    """
    global _FACE_CASCADE
    try:
        if _FACE_CASCADE is None:
            cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )
            # The constructor does not raise on a bad path; check explicitly
            if cascade.empty():
                raise RuntimeError("Haar cascade file could not be loaded")
            _FACE_CASCADE = cascade

        # Read image
        img = cv2.imread(str(image_path))
        if img is None:
            return []

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_height, img_width = img.shape[:2]

        # Detect faces (scaleFactor=1.1, minNeighbors=4)
        faces = _FACE_CASCADE.detectMultiScale(gray, 1.1, 4)

        padding = 20  # pixels of context kept around each detection
        face_crops = []
        for (x, y, w, h) in faces:
            # Cast away numpy scalar types so bboxes are plain ints
            x, y, w, h = int(x), int(y), int(w), int(h)
            x1 = max(0, x - padding)
            y1 = max(0, y - padding)
            x2 = min(img_width, x + w + padding)
            y2 = min(img_height, y + h + padding)

            face_crops.append({
                'crop': img[y1:y2, x1:x2],
                'bbox': (x1, y1, x2, y2),
                'original_bbox': (x, y, w, h)
            })

        return face_crops

    except Exception as e:
        print(f"❌ Error detecting faces in {image_path}: {e}")
        return []

def preprocess_face_for_dinov2(face_image):
    """Turn a face crop into a normalized, batched tensor for DINOv2.

    Args:
        face_image: Pixel array of the face; three-channel arrays are
            passed through a BGR-to-RGB conversion.

    Returns:
        The transformed image with a leading batch axis added, or None if
        any step fails (the error is printed).
    """
    try:
        dims = face_image.shape
        # Three-channel input is treated as OpenCV BGR and flipped to RGB
        if len(dims) == 3 and dims[2] == 3:
            face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

        as_pil = Image.fromarray(face_image)

        # Resize to the configured input size, then scale and normalize
        steps = [
            transforms.Resize(CONFIG['FACE_SIZE']),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        pipeline = transforms.Compose(steps)

        # unsqueeze(0) adds the batch axis
        return pipeline(as_pil).unsqueeze(0)

    except Exception as err:
        print(f"❌ Error preprocessing face: {err}")
        return None

# Cell 4: Feature Extraction and Matching
def extract_face_features(face_tensor, model, device):
    """Run the model on a face tensor and return an L2-normalized embedding.

    Args:
        face_tensor: Input tensor; moved to `device` before inference.
        model: Callable applied to the tensor.
        device: Target device for the tensor.

    Returns:
        A numpy array holding the normalized features, or None if any step
        fails (the error is printed).
    """
    try:
        batch = face_tensor.to(device)

        # Inference only: no gradients needed
        with torch.no_grad():
            embedding = model(batch)

            # A 3-D output is [batch, tokens, dim]; keep the first (CLS) token
            if len(embedding.shape) == 3:
                embedding = embedding[:, 0, :]

            # Unit-length rows
            embedding = F.normalize(embedding, p=2, dim=1)

        return embedding.cpu().numpy()

    except Exception as err:
        print(f"❌ Error extracting features: {err}")
        return None

def process_gallery_images(gallery_path, model, device):
    """Build a face-embedding index for every image file in a folder.

    Args:
        gallery_path: Folder to scan (non-recursive).
        model: Model forwarded to extract_face_features.
        device: Device forwarded to extract_face_features.

    Returns:
        Dict mapping each image Path to a list of per-face dicts with keys
        'features', 'bbox' and 'face_id'. Images with no usable face are
        omitted. An empty dict is returned when the folder is missing or
        holds no supported image files.
    """
    root = Path(gallery_path)

    if not root.exists():
        print(f"❌ Gallery path {root} does not exist")
        return {}

    # Supported image extensions (compared case-insensitively)
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    candidates = []
    for entry in root.iterdir():
        if entry.suffix.lower() in allowed_suffixes:
            candidates.append(entry)

    if not candidates:
        print(f"⚠️ No images found in {root}")
        return {}

    total = len(candidates)
    print(f"🔍 Processing {total} gallery images...")

    features_by_image = {}

    for position, image_file in enumerate(candidates, start=1):
        # '\r' keeps the progress on a single console line
        print(f"Processing {position}/{total}: {image_file.name}", end='\r')

        detections = detect_faces_opencv(image_file)
        if not detections:
            continue

        records = []
        for face_index, detection in enumerate(detections):
            tensor = preprocess_face_for_dinov2(detection['crop'])
            if tensor is None:
                continue

            embedding = extract_face_features(tensor, model, device)
            if embedding is None:
                continue

            records.append({
                'features': embedding,
                'bbox': detection['bbox'],
                'face_id': face_index
            })

        # Only keep images that produced at least one embedding
        if records:
            features_by_image[image_file] = records

    print(f"\n✅ Processed gallery: {len(features_by_image)} images with faces")
    return features_by_image

def find_face_matches(query_features, gallery_features, threshold=0.7, top_k=5):
    """Rank gallery faces by cosine similarity to one query embedding.

    Args:
        query_features: Embedding of the query face.
        gallery_features: Dict of image path -> list of face dicts, each
            with 'features', 'bbox' and 'face_id'.
        threshold: Minimum similarity (inclusive) for a face to count.
        top_k: Maximum number of matches returned.

    Returns:
        Up to `top_k` dicts with 'image_path', 'similarity', 'bbox' and
        'face_id', ordered from most to least similar.
    """
    hits = []

    for image_path, faces in gallery_features.items():
        for face in faces:
            # cosine_similarity returns a matrix; take its single entry
            score = cosine_similarity(query_features, face['features'])[0][0]

            if score >= threshold:
                hits.append({
                    'image_path': image_path,
                    'similarity': score,
                    'bbox': face['bbox'],
                    'face_id': face['face_id']
                })

    # Best first, then truncate
    hits.sort(key=lambda hit: hit['similarity'], reverse=True)
    return hits[:top_k]

# Cell 5: Main Search Functions
def search_face_in_gallery(query_image_path, gallery_features, model, device,
                          similarity_threshold=0.7, top_k=5):
    """Match every face found in one query image against the gallery index.

    Args:
        query_image_path: Image containing the face(s) to look for.
        gallery_features: Index produced by process_gallery_images.
        model: Model used to embed the query faces.
        device: Device used for inference.
        similarity_threshold: Minimum similarity for a match.
        top_k: Maximum matches kept per query face and overall.

    Returns:
        Up to `top_k` match dicts (each tagged with 'query_face_id' and
        'query_image'), best first. An empty list is returned when no face
        is detected or an error occurs (errors are printed).
    """
    try:
        print(f"🔍 Searching for face in: {query_image_path}")

        detected = detect_faces_opencv(query_image_path)
        if not detected:
            print("❌ No faces detected in query image")
            return []

        face_total = len(detected)
        collected = []

        for face_idx, face in enumerate(detected):
            print(f"   Processing face {face_idx+1}/{face_total}")

            tensor = preprocess_face_for_dinov2(face['crop'])
            if tensor is None:
                continue

            embedding = extract_face_features(tensor, model, device)
            if embedding is None:
                continue

            # Tag each hit with the query face it came from
            for hit in find_face_matches(embedding, gallery_features,
                                         similarity_threshold, top_k):
                hit['query_face_id'] = face_idx
                hit['query_image'] = query_image_path
                collected.append(hit)

        # Merge the per-face rankings into one ordering
        collected.sort(key=lambda hit: hit['similarity'], reverse=True)

        print(f"✅ Found {len(collected)} matches above threshold {similarity_threshold}")
        return collected[:top_k]

    except Exception as err:
        print(f"❌ Error searching face: {err}")
        return []

def display_face_matches(matches, query_image_path=None):
    """Plot matched gallery images in a grid with the matched face outlined.

    Args:
        matches: List of match dicts with 'image_path', 'similarity' and
            'bbox' (x1, y1, x2, y2); 'query_key' is shown when present.
        query_image_path: Accepted for API compatibility; not used by the
            plot.

    Returns:
        None. Prints a message and returns early when `matches` is empty.
        Images that can no longer be read are reported and left blank
        instead of aborting the whole figure.
    """
    if not matches:
        print("No matches to display")
        return

    num_matches = len(matches)
    cols = min(3, num_matches)
    rows = (num_matches + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so one ravel() covers
    # the 1x1, 1xN and MxN cases uniformly.
    fig, axes = plt.subplots(rows, cols, figsize=CONFIG['FIGURE_SIZE'], squeeze=False)
    axes = axes.ravel()

    for ax, match in zip(axes, matches):
        ax.set_title(f"{match['image_path'].name}", fontsize=10)

        # cv2.imread returns None (no exception) for missing/corrupt files
        img = cv2.imread(str(match['image_path']))
        if img is None:
            print(f"⚠️ Could not read image: {match['image_path']}")
            ax.axis('off')
            continue

        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Draw bounding box around detected face
        x1, y1, x2, y2 = match['bbox']
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1,
                               linewidth=3, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

        # Query label below the box, similarity score above it
        ax.text(x1, y2+5, f'Query: {match.get("query_key", "Unknown")}',
               color='blue', fontweight='bold', fontsize=8,
               bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.8))
        ax.text(x1, y1-10, f'Similarity: {match["similarity"]:.3f}',
               color='red', fontweight='bold', fontsize=12,
               bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        ax.axis('off')

    # Hide the unused cells of the last row
    for ax in axes[num_matches:]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f'Face Recognition Results - Top {num_matches} Matches',
                 fontsize=14, y=1.02)
    plt.show()

# Cell 6: Interactive GUI Interface
import ipywidgets as widgets
from IPython.display import display, clear_output
import shutil
from datetime import datetime

# Create the gallery, query and results folders before the GUI is built.
setup_directories()

def create_face_recognition_interface():
    """Create interactive GUI interface for facial recognition.

    Builds an ipywidgets panel (model dropdown, folder inputs, threshold and
    top-k sliders, action buttons, status line, progress bar, output area)
    and wires the buttons to closures that share cached state:

      - current_model / current_embed_dim: start as the module-level
        `model` / `embed_dim`, replaced by "Switch Model".
      - gallery_features / query_features: embedding indexes produced by
        process_gallery_images, reused across searches until cleared.
      - search_results: flat, similarity-sorted list from the last search.

    Returns:
        widgets.VBox holding the controls and the output area.
    """

    def make_text(value, placeholder, description):
        """Build a 500px text box whose description is not truncated."""
        return widgets.Text(
            value=value,
            placeholder=placeholder,
            description=description,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='500px')
        )

    def make_button(description, button_style, width='150px', disabled=False):
        """Build an action button of the given style and width."""
        return widgets.Button(
            description=description,
            button_style=button_style,
            layout=widgets.Layout(width=width),
            disabled=disabled
        )

    # Model selection dropdown
    model_options = [(f"{info['name']} - {info['description']}", key)
                    for key, info in AVAILABLE_MODELS.items()]

    # Preselect the configured model so the dropdown reflects what was loaded;
    # fall back to the base model if the configured name is not in the list.
    default_model = CONFIG['MODEL_NAME']
    if default_model not in AVAILABLE_MODELS:
        default_model = 'dinov2_vitb14'

    model_selector = widgets.Dropdown(
        options=model_options,
        value=default_model,
        description='Model:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='500px')
    )

    # Gallery paths
    gallery_path_input = make_text(
        CONFIG['SUSPECTS_GALLERY_PATH'],
        'Path to suspects gallery folder',
        'Gallery Path:'
    )
    query_path_input = make_text(
        CONFIG['QUERY_IMAGES_PATH'],
        'Path to query images folder',
        'Query Folder:'
    )

    # Similarity threshold slider
    similarity_slider = widgets.FloatSlider(
        value=CONFIG['SIMILARITY_THRESHOLD'],
        min=0.3,
        max=0.95,
        step=0.05,
        description='Similarity Threshold:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )

    # Top K matches slider
    top_k_slider = widgets.IntSlider(
        value=CONFIG['TOP_K_MATCHES'],
        min=1,
        max=20,
        step=1,
        description='Max Results:',
        style={'description_width': 'initial'}
    )

    # Buttons
    search_button = make_button('🔍 Search All Faces', 'primary', width='170px')
    process_gallery_button = make_button('📁 Process Gallery', 'info')
    process_queries_button = make_button('🎯 Process Queries', 'info')
    # Copy stays disabled until a search has produced results
    copy_results_button = make_button('📋 Copy Results', 'success', disabled=True)
    clear_button = make_button('🗑️ Clear Results', 'warning')
    switch_model_button = make_button('🔄 Switch Model', 'info')

    # Status and progress
    status_label = widgets.HTML(value="<b>Status:</b> Ready")
    progress_bar = widgets.IntProgress(
        value=0,
        min=0,
        max=100,
        description='Progress:',
        bar_style='',
        style={'bar_color': '#1f77b4'},
        layout=widgets.Layout(width='400px')
    )

    # Output area
    output_area = widgets.Output()

    # Store current state (shared by the callbacks below via nonlocal)
    current_model = model
    current_embed_dim = embed_dim
    gallery_features = {}
    query_features = {}
    search_results = []

    def update_status(message, progress=None):
        """Set the status line and, when given, the progress bar value."""
        status_label.value = f"<b>Status:</b> {message}"
        if progress is not None:
            progress_bar.value = progress

    def on_switch_model_clicked(b):
        """Load the model chosen in the dropdown and drop cached embeddings."""
        # query_features must be nonlocal as well: without it the assignment
        # below would only create a local, leaving embeddings from the old
        # model cached and later compared against the new model's gallery.
        nonlocal current_model, current_embed_dim, gallery_features, query_features
        with output_area:
            clear_output(wait=True)
            selected_model = model_selector.value
            update_status(f"Switching to {AVAILABLE_MODELS[selected_model]['name']}...", 10)

            try:
                new_model, new_embed_dim = load_dinov2_model(selected_model)
                if new_model is not None:
                    current_model = new_model
                    current_embed_dim = new_embed_dim
                    gallery_features = {}  # Clear cached features
                    query_features = {}    # Clear cached query features
                    update_status("Model switched successfully!", 100)
                    print("✅ Model switched successfully!")
                    print("💡 Gallery will be reprocessed on next search")
                else:
                    update_status("Failed to switch model", 0)
                    print("❌ Failed to switch model")
            except Exception as e:
                update_status("Model switch error", 0)
                print(f"❌ Error switching model: {e}")

    def on_process_gallery_clicked(b):
        """Embed every face in the gallery folder and cache the result."""
        nonlocal gallery_features
        with output_area:
            clear_output(wait=True)

            if current_model is None:
                print("❌ No model loaded. Please switch to a valid model first.")
                return

            update_status("Processing gallery images...", 20)
            print("📁 Processing gallery images...")

            try:
                gallery_features = process_gallery_images(
                    gallery_path_input.value,
                    current_model,
                    CONFIG['DEVICE']
                )

                if gallery_features:
                    total_faces = sum(len(faces) for faces in gallery_features.values())
                    update_status(f"Gallery processed: {len(gallery_features)} images, {total_faces} faces", 100)
                    print(f"✅ Gallery processed successfully!")
                    print(f"📊 Found {len(gallery_features)} images with {total_faces} total faces")
                else:
                    update_status("No faces found in gallery", 100)
                    print("⚠️ No faces found in gallery images")

            except Exception as e:
                update_status("Gallery processing error", 0)
                print(f"❌ Error processing gallery: {e}")

    def on_process_queries_clicked(b):
        """Embed every face in the query folder and cache the result."""
        nonlocal query_features
        with output_area:
            clear_output(wait=True)

            if current_model is None:
                print("❌ No model loaded. Please switch to a valid model first.")
                return

            update_status("Processing query images...", 20)
            print("🎯 Processing query images...")

            try:
                # Query folders are indexed with the same routine as the gallery
                query_features = process_gallery_images(
                    query_path_input.value,
                    current_model,
                    CONFIG['DEVICE']
                )

                if query_features:
                    total_faces = sum(len(faces) for faces in query_features.values())
                    update_status(f"Queries processed: {len(query_features)} images, {total_faces} faces", 100)
                    print(f"✅ Query images processed successfully!")
                    print(f"📊 Found {len(query_features)} images with {total_faces} total faces")
                else:
                    update_status("No faces found in query images", 100)
                    print("⚠️ No faces found in query images")

            except Exception as e:
                update_status("Query processing error", 0)
                print(f"❌ Error processing queries: {e}")

    def search_all_query_faces_in_gallery(query_features, gallery_features,
                                         similarity_threshold, top_k):
        """Search all query faces against gallery.

        Returns a dict keyed by "<image name>_face_<index>" holding the
        query image, the face index and its matches; query faces without
        any match are left out.
        """
        all_results = {}

        for query_img_path, query_faces in query_features.items():
            print(f"🔍 Processing query image: {query_img_path.name}")

            for i, query_face in enumerate(query_faces):
                query_face_features = query_face['features']

                # Find matches for this query face
                matches = find_face_matches(query_face_features, gallery_features,
                                          similarity_threshold, top_k)

                if matches:
                    query_key = f"{query_img_path.name}_face_{i}"
                    all_results[query_key] = {
                        'query_image': query_img_path,
                        'query_face_id': i,
                        'matches': matches
                    }
                    print(f"  Face {i+1}: Found {len(matches)} matches")
                else:
                    print(f"  Face {i+1}: No matches found")

        return all_results

    def on_search_clicked(b):
        """Run the batch search, processing gallery/queries first if needed."""
        nonlocal search_results, gallery_features, query_features
        with output_area:
            clear_output(wait=True)

            if current_model is None:
                print("❌ No model loaded. Please switch to a valid model first.")
                return

            # Update parameters
            similarity_threshold = similarity_slider.value
            top_k = top_k_slider.value

            update_status("Starting face recognition search...", 10)
            print(f"🚀 Starting Batch Face Recognition Search")
            print(f"📁 Gallery: {gallery_path_input.value}")
            print(f"🎯 Query folder: {query_path_input.value}")
            print(f"🎯 Similarity threshold: {similarity_threshold:.2f}")
            print(f"📊 Max results per query: {top_k}")
            print("="*60)

            # Process gallery if not already done.
            # NOTE(review): the cache is reused whenever it is non-empty, even
            # if the path text was edited since; "Clear Results" resets it.
            if not gallery_features:
                update_status("Processing gallery images...", 30)
                print("📁 Processing gallery images...")
                gallery_features = process_gallery_images(
                    gallery_path_input.value,
                    current_model,
                    CONFIG['DEVICE']
                )

                if not gallery_features:
                    update_status("No faces found in gallery", 100)
                    print("❌ No faces found in gallery images")
                    return

            # Process query images if not already done
            if not query_features:
                update_status("Processing query images...", 50)
                print("🎯 Processing query images...")
                query_features = process_gallery_images(
                    query_path_input.value,
                    current_model,
                    CONFIG['DEVICE']
                )

                if not query_features:
                    update_status("No faces found in query images", 100)
                    print("❌ No faces found in query images")
                    return

            # Search for matches
            update_status("Searching for face matches...", 70)
            print(f"\n🔍 Searching for matches...")

            try:
                all_results = search_all_query_faces_in_gallery(
                    query_features, gallery_features, similarity_threshold, top_k
                )

                if all_results:
                    # Flatten results for display
                    search_results = []
                    total_matches = 0

                    print(f"\n📊 Search Results Summary:")
                    print("-" * 50)

                    for query_key, result_data in all_results.items():
                        matches = result_data['matches']
                        total_matches += len(matches)

                        print(f"🔍 {query_key}: {len(matches)} matches")

                        # Add query info to each match for display
                        for match in matches:
                            match['query_key'] = query_key
                            match['query_image'] = result_data['query_image']
                            search_results.append(match)

                    # Sort all results by similarity
                    search_results = sorted(search_results,
                                          key=lambda x: x['similarity'], reverse=True)

                    update_status(f"Found {total_matches} total matches!", 100)
                    copy_results_button.disabled = False

                    print(f"\n🎯 Total matches found: {total_matches}")
                    print(f"📊 Showing top {min(len(search_results), 20)} matches:")

                    for i, match in enumerate(search_results[:20], 1):
                        print(f"  {i}. {match['query_key']} → {match['image_path'].name} "
                              f"(similarity: {match['similarity']:.3f})")

                    # Display top results visually
                    print(f"\n🖼️ Displaying top {min(len(search_results), 10)} results...")
                    display_face_matches(search_results[:10])

                else:
                    # Drop results of any earlier search so they cannot be
                    # mistaken for (or copied as) the outcome of this one.
                    search_results = []
                    update_status("No matches found", 100)
                    copy_results_button.disabled = True
                    print("❌ No matches found above the similarity threshold")
                    print("💡 Try lowering the similarity threshold")

            except Exception as e:
                update_status("Search error", 0)
                print(f"❌ Error during search: {e}")

    def on_copy_results_clicked(b):
        """Copy every matched gallery image into a timestamped results folder."""
        with output_area:
            if not search_results:
                print("⚠️ No results to copy")
                return

            try:
                # Create timestamped results folder
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                results_folder = Path(CONFIG['RESULTS_OUTPUT_PATH']) / f"face_search_{timestamp}"
                results_folder.mkdir(parents=True, exist_ok=True)

                update_status("Copying results...", 50)
                print(f"📋 Copying {len(search_results)} matching images...")

                copied_files = []
                for i, result in enumerate(search_results):
                    source_path = result['image_path']
                    similarity = result['similarity']
                    # Rank prefix keeps names unique and sorted by similarity
                    filename = f"{i+1:03d}_{source_path.stem}_sim{similarity:.3f}{source_path.suffix}"
                    dest_path = results_folder / filename

                    shutil.copy2(source_path, dest_path)
                    copied_files.append(dest_path)

                update_status(f"Results copied to {results_folder.name}", 100)
                print(f"✅ Copied {len(copied_files)} files to: {results_folder}")
                print(f"📁 Results folder: {results_folder}")

            except Exception as e:
                update_status("Copy error", 0)
                print(f"❌ Error copying results: {e}")

    def on_clear_clicked(b):
        """Forget all cached embeddings and results and reset the panel."""
        nonlocal search_results, gallery_features, query_features
        search_results = []
        gallery_features = {}
        query_features = {}
        copy_results_button.disabled = True
        with output_area:
            clear_output()
            update_status("All data cleared", 0)
            print("🗑️ All results and cached data cleared")

    # Connect button events
    switch_model_button.on_click(on_switch_model_clicked)
    process_gallery_button.on_click(on_process_gallery_clicked)
    process_queries_button.on_click(on_process_queries_clicked)
    search_button.on_click(on_search_clicked)
    copy_results_button.on_click(on_copy_results_clicked)
    clear_button.on_click(on_clear_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>🔍 DINOv2 Facial Recognition Interface</h3>"),
        model_selector,
        gallery_path_input,
        query_path_input,
        widgets.HBox([similarity_slider, top_k_slider]),
        widgets.HBox([search_button, process_gallery_button, process_queries_button]),
        widgets.HBox([copy_results_button, clear_button, switch_model_button]),
        status_label,
        progress_bar,
        widgets.HTML("<hr>")
    ])

    return widgets.VBox([controls, output_area])

def run_face_recognition_search(query_image_path, gallery_path=None,
                               similarity_threshold=None, top_k=None):
    """Search the gallery for the face(s) in one query image, without the GUI.

    Args:
        query_image_path: Image containing the face to look for.
        gallery_path: Gallery folder; defaults to
            CONFIG['SUSPECTS_GALLERY_PATH'].
        similarity_threshold: Minimum similarity; defaults to
            CONFIG['SIMILARITY_THRESHOLD'].
        top_k: Maximum number of matches; defaults to
            CONFIG['TOP_K_MATCHES'].

    Returns:
        The list of matches (also plotted) when any are found, an empty list
        when the search finds none, and None when the model is not loaded or
        the gallery yields no faces.
    """
    # Fill in unspecified arguments from the notebook configuration
    gallery_path = CONFIG['SUSPECTS_GALLERY_PATH'] if gallery_path is None else gallery_path
    similarity_threshold = (CONFIG['SIMILARITY_THRESHOLD']
                            if similarity_threshold is None else similarity_threshold)
    top_k = CONFIG['TOP_K_MATCHES'] if top_k is None else top_k

    print("🚀 Starting Face Recognition Search")
    print("="*50)

    # The module-level model is None when loading failed earlier
    if model is None:
        print("❌ Model not loaded. Please run the model loading cell first.")
        return None

    # Index the gallery (this might take a while for large galleries)
    print("📁 Processing gallery images...")
    gallery_index = process_gallery_images(gallery_path, model, CONFIG['DEVICE'])
    if not gallery_index:
        print("❌ No faces found in gallery images")
        return None

    print(f"\n🔍 Searching for matches...")
    found = search_face_in_gallery(query_image_path, gallery_index, model,
                                   CONFIG['DEVICE'], similarity_threshold, top_k)

    if not found:
        print("❌ No matches found above the similarity threshold")
        return []

    # Report the ranking, then show it
    print(f"\n📊 Found {len(found)} matches:")
    for rank, hit in enumerate(found, 1):
        print(f"  {rank}. {hit['image_path'].name} - Similarity: {hit['similarity']:.3f}")

    display_face_matches(found, query_image_path)
    return found

# Build the widget panel and render it in the notebook output. The panel is
# kept in the module-level name `interface`.
print("🎨 Creating GUI Interface...")
interface = create_face_recognition_interface()
display(interface)

# Cell 7: Usage Instructions and Command-line Functions

# Usage Instructions
# Usage banner. Folder names are taken from CONFIG and control names match the
# widgets built by create_face_recognition_interface, so the text cannot drift
# from the actual configuration.
print(f"""
🎯 DINOv2 FACIAL RECOGNITION SYSTEM WITH GUI
===========================================

📋 SETUP:
1. Place suspect images in '{CONFIG['SUSPECTS_GALLERY_PATH']}'
2. Place query face images in '{CONFIG['QUERY_IMAGES_PATH']}'
3. Run all cells in order
4. Use the GUI interface above for interactive search

🖱️ GUI INTERFACE FEATURES:
• Model Selection: Choose between DINOv2 variants
• Query Folder: Path to the folder of query face images
• Similarity Threshold: Adjust matching strictness (0.3-0.95)
• Max Results: Number of top matches to display
• Process Gallery: Pre-process gallery for faster searches
• Process Queries: Pre-process query images for faster searches
• Search All Faces: Match every query face against the gallery
• Copy Results: Save matching images to results folder
• Clear Results: Discard results and cached gallery/query features
• Switch Model: Load the model selected in the dropdown
• Real-time Status: Progress updates and status messages

🔍 COMMAND-LINE USAGE (Alternative):

# Basic search
matches = run_face_recognition_search('./query_images/suspect_face.jpg')

# Custom parameters
matches = run_face_recognition_search(
    './query_images/suspect_face.jpg',
    similarity_threshold=0.8,  # Higher = more strict
    top_k=10                   # Return top 10 matches
)

# Batch processing
batch_results = batch_face_recognition('./query_images/')

⚙️ PARAMETERS:
• similarity_threshold: 0.3-0.95 (0.7 recommended)
  - Lower values = more matches, less strict
  - Higher values = fewer matches, more strict
• top_k: Number of top matches to return (1-20)
• Models: dinov2_vits14 (fast), dinov2_vitb14 (balanced), dinov2_vitl14 (accurate)

🎯 WORKFLOW:
1. Load suspect images into gallery folder
2. Select DINOv2 model variant
3. Optionally pre-process gallery (recommended for large datasets)
4. Enter query folder path
5. Adjust similarity threshold as needed
6. Click "Search All Faces" to find matches
7. Review results with similarity scores
8. Copy results to organized folder if needed

💡 TIPS:
• Pre-process gallery once, then run multiple searches quickly
• Use clear, front-facing face images for best results
• Start with similarity threshold 0.7, adjust based on results
• Multiple faces in images are automatically detected
• Results show bounding boxes around detected faces
• Higher resolution images generally give better results

🔧 TROUBLESHOOTING:
• "No faces found": Check image quality and face visibility
• "Low similarity scores": Try different DINOv2 model or lower threshold
• "Out of memory": Use smaller model (vits14) or process fewer images
• "Gallery not processed": Click "Process Gallery", or just search - the gallery is processed automatically when nothing is cached
• "Stale results after changing a folder path": Click "Clear Results", then search again
""")

# Example usage (uncomment to run):
# matches = run_face_recognition_search('./query_images/suspect1.jpg')

# Advanced batch processing function
def batch_face_recognition(query_folder, gallery_folder=None):
    """Process multiple query images at once.

    The gallery is indexed a single time and every image file in
    `query_folder` is then searched against it using the threshold and
    top-k from CONFIG.

    Args:
        query_folder: Folder holding the query images (non-recursive).
        gallery_folder: Gallery folder; defaults to
            CONFIG['SUSPECTS_GALLERY_PATH'].

    Returns:
        Dict mapping each query image Path to its list of matches (possibly
        empty), in file-name order. An empty dict is returned when the query
        folder is missing, the model is not loaded, or the gallery yields no
        faces.
    """
    query_path = Path(query_folder)

    if not query_path.exists():
        print(f"❌ Query folder {query_folder} does not exist")
        return {}

    # Same guard as run_face_recognition_search: without a model every
    # embedding would fail and be misreported as "No faces found in gallery".
    if model is None:
        print("❌ Model not loaded. Please run the model loading cell first.")
        return {}

    if gallery_folder is None:
        gallery_folder = CONFIG['SUSPECTS_GALLERY_PATH']

    # Process gallery once
    print("📁 Processing gallery images...")
    gallery_features = process_gallery_images(gallery_folder, model, CONFIG['DEVICE'])

    if not gallery_features:
        print("❌ No faces found in gallery")
        return {}

    # Collect query images; sorted so runs are reproducible regardless of
    # the order the filesystem happens to list entries in.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    query_files = sorted(f for f in query_path.iterdir()
                         if f.suffix.lower() in image_extensions)

    batch_results = {}

    for query_file in query_files:
        print(f"\n🔍 Processing query: {query_file.name}")
        matches = search_face_in_gallery(query_file, gallery_features, model,
                                       CONFIG['DEVICE'], CONFIG['SIMILARITY_THRESHOLD'],
                                       CONFIG['TOP_K_MATCHES'])
        batch_results[query_file] = matches

        if matches:
            print(f"   ✅ Found {len(matches)} matches")
        else:
            print(f"   ⚪ No matches found")

    return batch_results