# In [None]:
# Cell 1: Configuration and Setup
import os
import shutil
import torch
from PIL import Image
import numpy as np
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Configuration Settings: every tunable knob of the notebook, keyed by name.
CONFIG = dict(
    # --- Model selection (OpenAI CLIP) ---
    MODEL_NAME='ViT-B/32',  # Options: RN50, RN101, ViT-B/32, ViT-B/16, ViT-L/14
    DEVICE='auto',          # 'auto', 'cpu', 'cuda'

    # --- Paths ---
    SUSPECTS_GALLERY_PATH='./suspects_gallery',  # folder of images to search
    RESULTS_OUTPUT_PATH='./search_results',      # where copies of matches go

    # --- Detection parameters ---
    SIMILARITY_THRESHOLD=0.25,  # minimum CLIP similarity for a patch to count
    PATCH_SIZE=224,             # CLIP input size (side of a square patch, px)
    STRIDE=112,                 # half of the patch size -> overlapping patches
    MAX_PATCHES=16,             # cap on patches scored per image

    # --- Processing settings ---
    BATCH_SIZE=8,
    MAX_RESULTS_DISPLAY=10,     # how many matches the interface plots
    FIGURE_SIZE=(12, 8),        # matplotlib figsize for the result grid
)

# Available OpenAI CLIP models: key -> display name, short description and a
# speed/accuracy summary (the name and description label the model dropdown).
AVAILABLE_MODELS = {
    key: {'name': display_name, 'description': blurb, 'performance': tradeoff}
    for key, display_name, blurb, tradeoff in (
        ('RN50', 'ResNet-50',
         'Fastest model - good for quick analysis', 'Fast speed, good accuracy'),
        ('RN101', 'ResNet-101',
         'Balanced performance and speed', 'Balanced speed and accuracy'),
        ('ViT-B/32', 'ViT-Base/32',
         'Vision Transformer - good balance (recommended)', 'Good performance, moderate speed'),
        ('ViT-B/16', 'ViT-Base/16',
         'Higher resolution ViT - better accuracy', 'Better accuracy, slower speed'),
        ('ViT-L/14', 'ViT-Large/14',
         'Largest model - best accuracy, slowest', 'Best accuracy, slowest speed'),
    )
}

print("✅ Configuration loaded successfully")
print(f"📁 Suspects gallery: {CONFIG['SUSPECTS_GALLERY_PATH']}")
print(f"📁 Results output: {CONFIG['RESULTS_OUTPUT_PATH']}")
print(f"🔍 Selected model: {CONFIG['MODEL_NAME']}")

# Cell 2: Install and Import Dependencies
# Installs OpenAI CLIP (plus its usual companions) only when `import clip`
# fails, then imports everything and resolves CONFIG['DEVICE'].
try:
    import clip
    print("✅ OpenAI CLIP already installed")
except ImportError:
    print("⚠️ Installing OpenAI CLIP...")
    import subprocess
    import sys

    # Install required packages; each is installed separately so one failure
    # does not prevent the others from being attempted.
    packages = [
        "torch",
        "torchvision",
        "git+https://github.com/openai/CLIP.git",
        "ipywidgets",
        "matplotlib",
        "pillow",
        "numpy"
    ]

    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to install {package}: {e}")

    print("📦 Installation complete - please restart kernel and run again")

# Import required libraries
try:
    import clip
    import torch
    import torch.nn.functional as F
    from PIL import Image
    import numpy as np

    print("✅ All dependencies imported successfully")

    # Check device compatibility
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print(f"🚀 CUDA available: {torch.cuda.get_device_name(0)}")
        default_device = "cuda"
    else:
        print("🖥️ Using CPU mode")
        default_device = "cpu"

    requested_device = str(CONFIG['DEVICE'])
    if requested_device == 'auto':
        CONFIG['DEVICE'] = default_device
    elif requested_device.startswith('cuda') and not cuda_available:
        # An explicit CUDA request on a CPU-only machine would make clip.load
        # fail in Cell 3 and leave no model loaded; degrade to CPU instead.
        print("⚠️ CUDA requested but not available - falling back to CPU")
        CONFIG['DEVICE'] = 'cpu'
    print(f"📍 Device: {CONFIG['DEVICE']}")

    # Show available models
    available_models = clip.available_models()
    print(f"🔍 Available CLIP models: {available_models}")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please restart kernel and run Cell 2 again")

# Cell 3: Initialize Model and Setup
def setup_directories():
    """Make sure the gallery and results folders from CONFIG exist.

    Missing parent folders are created too; folders that already exist are
    left untouched.
    """
    for config_key in ('SUSPECTS_GALLERY_PATH', 'RESULTS_OUTPUT_PATH'):
        os.makedirs(CONFIG[config_key], exist_ok=True)
    print("📁 Directories ready")

def load_clip_model(model_name=None):
    """Load an OpenAI CLIP checkpoint onto CONFIG['DEVICE'] in eval mode.

    Args:
        model_name: CLIP model identifier (e.g. ``'ViT-B/32'``). When omitted,
            ``CONFIG['MODEL_NAME']`` is used.

    Returns:
        ``(model, preprocess, device, model_name)`` on success. If anything
        raises while loading, the error is printed and four ``None`` values
        are returned instead.
    """
    try:
        chosen_name = CONFIG['MODEL_NAME'] if model_name is None else model_name
        target_device = CONFIG['DEVICE']
        print(f"📥 Loading {chosen_name} on {target_device}...")

        loaded_model, image_transform = clip.load(chosen_name, device=target_device)
        loaded_model.eval()  # inference only: disable dropout etc.

        print("✅ Model loaded successfully")
        return loaded_model, image_transform, target_device, chosen_name
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        return None, None, None, None

def switch_model(model_key):
    """Load a different CLIP model, selected by its ``AVAILABLE_MODELS`` key.

    ``CONFIG['MODEL_NAME']`` is updated only after the new model has loaded
    successfully, so a failed switch leaves the configuration naming the
    model that is actually still in use.

    Args:
        model_key: Key of ``AVAILABLE_MODELS`` (e.g. ``'RN50'``).

    Returns:
        The ``(model, preprocess, device, model_name)`` tuple from
        ``load_clip_model``; four ``None`` values if the key is unknown or
        loading failed.
    """
    if model_key not in AVAILABLE_MODELS:
        print(f"❌ Unknown model: {model_key}")
        return None, None, None, None

    print(f"🔄 Switching to {AVAILABLE_MODELS[model_key]['name']}")
    loaded = load_clip_model(model_key)
    if loaded[0] is not None:
        CONFIG['MODEL_NAME'] = model_key
    return loaded

# Initialize
# Create the gallery/results folders, then load the model named in CONFIG.
setup_directories()
# If loading fails, load_clip_model prints the error and all four names are None.
model, preprocess, device, model_name = load_clip_model()

# Cell 4: Core Search Functions
def extract_patches(image, patch_size=224, stride=112, max_patches=16):
    """Tile an image into (possibly overlapping) square patches.

    Window origins lie on a grid with step ``stride``. One extra origin flush
    with the right/bottom edge is added so the image border is always
    covered. When the grid holds more than ``max_patches`` windows, an evenly
    spaced subset of the row-major grid is kept, so the whole image is
    sampled instead of only its top rows. Along an axis shorter than
    ``patch_size`` a single window spans the whole axis. Every crop is
    resized to ``patch_size x patch_size``.

    Args:
        image: PIL image to tile.
        patch_size: Side length of each window (and of each output crop), px.
        stride: Step between neighbouring window origins, px (min. 1).
        max_patches: Upper bound on the number of crops returned
            (``None`` = no limit).

    Returns:
        ``(crops, positions)``: the resized crops and, for each, its
        ``(x0, y0, x1, y1)`` box in original image coordinates.
    """
    width, height = image.size
    step = max(1, int(stride))  # a zero/negative stride would be invalid for range()

    def axis_starts(length):
        # Grid origins along one axis, plus one flush with the far edge so the
        # trailing (length - patch_size) % step pixels are not skipped.
        last = max(0, length - patch_size)
        starts = list(range(0, last + 1, step))
        if starts[-1] != last:
            starts.append(last)
        return starts

    positions = [
        (x, y, min(x + patch_size, width), min(y + patch_size, height))
        for y in axis_starts(height)
        for x in axis_starts(width)
    ]

    if max_patches is not None and len(positions) > max_patches:
        if max_patches <= 0:
            positions = []
        elif max_patches == 1:
            positions = positions[:1]
        else:
            # Evenly spaced indices over the full grid; spacing is > 1 here,
            # so rounding never selects the same window twice.
            span = len(positions) - 1
            positions = [
                positions[round(k * span / (max_patches - 1))]
                for k in range(max_patches)
            ]

    crops = []
    for box in positions:
        crop = image.crop(box)
        # Edge windows on small images are smaller than patch_size.
        if crop.size != (patch_size, patch_size):
            crop = crop.resize((patch_size, patch_size), Image.LANCZOS)
        crops.append(crop)

    return crops, positions

def compute_similarity(model, preprocess, text_query, images, device, batch_size=None):
    """Compute CLIP cosine similarity between one text query and each image.

    The query is encoded once. Images are encoded in chunks of at most
    ``batch_size`` so a large patch set does not need a single huge forward
    pass.

    Args:
        model: CLIP model exposing ``encode_image`` / ``encode_text``.
        preprocess: Image transform returned with the model by ``clip.load``.
        text_query: Query string.
        images: Iterable of PIL images (e.g. patches from ``extract_patches``).
        device: Torch device the model lives on.
        batch_size: Max images per forward pass. ``None`` uses
            ``CONFIG['BATCH_SIZE']``. A missing or non-positive value encodes
            all images in one pass.

    Returns:
        1-D float32 numpy array with one similarity per image, in input
        order. Returns an empty array when ``images`` is empty or encoding
        fails (the error is printed).
    """
    try:
        images = list(images)
        if not images:
            return np.array([])

        chunk = batch_size if batch_size is not None else CONFIG.get('BATCH_SIZE')
        if not chunk or chunk <= 0:
            chunk = len(images)

        # Tokenize text
        text_inputs = clip.tokenize([text_query]).to(device)

        scores = []
        with torch.no_grad():
            text_features = F.normalize(model.encode_text(text_inputs), dim=-1)

            for start in range(0, len(images), chunk):
                batch = torch.stack(
                    [preprocess(img) for img in images[start:start + chunk]]
                ).to(device)
                image_features = F.normalize(model.encode_image(batch), dim=-1)
                # Unit vectors -> dot product is cosine similarity. Cast to
                # float32 so thresholding is not done at fp16 precision on GPU.
                scores.append(
                    torch.matmul(image_features, text_features.T).squeeze(-1).float().cpu()
                )

        return torch.cat(scores).numpy()

    except Exception as e:
        print(f"Error in similarity computation: {e}")
        return np.array([])

def search_single_image(image_path, model, preprocess, device, query):
    """Search one image for ``query`` by scoring its patches with CLIP.

    Args:
        image_path: Path (or path string) of the image file.
        model: CLIP model from ``load_clip_model``.
        preprocess: Matching CLIP image transform.
        device: Device the model lives on.
        query: Text query.

    Returns:
        ``None`` if no patch scores at or above
        ``CONFIG['SIMILARITY_THRESHOLD']``, or if the image cannot be
        processed (the error is printed). Otherwise a dict with these keys:
        ``image_path`` (Path), ``image`` (RGB PIL image), ``boxes``
        (float32 tensor of ``(x0, y0, x1, y1)`` rows), ``scores`` (float32
        tensor, one per box) and ``query``.
    """
    image_path = Path(image_path)
    try:
        # Decode inside ``with`` so the file handle is closed immediately
        # rather than lingering until garbage collection. convert() returns
        # an independent, fully loaded image.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        patch_images, positions = extract_patches(
            image,
            CONFIG['PATCH_SIZE'],
            CONFIG['STRIDE'],
            CONFIG['MAX_PATCHES']
        )
        if not patch_images:
            return None

        similarities = compute_similarity(model, preprocess, query, patch_images, device)
        if len(similarities) == 0:
            return None

        # Keep only patches at or above the threshold
        keep = similarities >= CONFIG['SIMILARITY_THRESHOLD']
        if not keep.any():
            return None

        kept_positions = [pos for pos, hit in zip(positions, keep) if hit]

        return {
            'image_path': image_path,
            'image': image,
            'boxes': torch.tensor(kept_positions, dtype=torch.float32),
            'scores': torch.tensor(similarities[keep], dtype=torch.float32),
            'query': query
        }

    except Exception as e:
        print(f"Error processing {image_path.name}: {e}")
        return None

def search_images_with_query(query, model, preprocess, device, model_name, gallery_path):
    """Search every image in a gallery folder for ``query`` and print a timing summary.

    Only regular files with a recognised image extension are searched
    (non-recursively). Files are visited in sorted path order, so repeated
    runs process images identically.

    Args:
        query: Text query.
        model: CLIP model from ``load_clip_model``.
        preprocess: Matching CLIP image transform.
        device: Device the model lives on.
        model_name: Model identifier (only used in the log header).
        gallery_path: Folder holding the images.

    Returns:
        List of match dicts from ``search_single_image``. The list is empty if
        the folder is missing, holds no images, or nothing matched.
    """
    import time
    from datetime import datetime

    start_time = time.perf_counter()  # monotonic clock for the duration
    start_datetime = datetime.now()   # wall clock for the printed times

    print(f"🔍 Starting search for: '{query}'")
    print(f"🖥️ Model: {model_name} on {device}")
    print(f"⏰ Start: {start_datetime.strftime('%H:%M:%S')}")
    print("-" * 50)

    gallery_path = Path(gallery_path)
    if not gallery_path.is_dir():
        print(f"⚠️ Gallery folder not found: {gallery_path}")
        return []

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
    # is_file() skips sub-folders that merely carry an image-like suffix.
    image_files = sorted(
        f for f in gallery_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    )

    if not image_files:
        print(f"⚠️ No images found in {gallery_path}")
        return []

    results = []
    total_files = len(image_files)
    use_cuda = torch.cuda.is_available()

    print(f"📸 Processing {total_files} images...")

    for i, img_path in enumerate(image_files):
        # Every 10th image, plus the last one so the counter reaches N/N.
        if i % 10 == 0 or i == total_files - 1:
            print(f"Progress: {i+1}/{total_files}", end='\r')

        result = search_single_image(img_path, model, preprocess, device, query)
        if result:
            results.append(result)

        # Periodically release cached GPU memory
        if use_cuda and i % 20 == 0:
            torch.cuda.empty_cache()

    duration = time.perf_counter() - start_time
    end_datetime = datetime.now()

    if duration >= 60:
        duration_str = f"{int(duration//60)}m {duration%60:.1f}s"
    else:
        duration_str = f"{duration:.1f}s"

    # Print summary
    print(f"\n" + "="*50)
    print(f"📊 SEARCH SUMMARY")
    print(f"="*50)
    print(f"📸 Images processed: {total_files}")
    print(f"✅ Matches found: {len(results)}")
    print(f"⏰ Start: {start_datetime.strftime('%H:%M:%S')}")
    print(f"🏁 End: {end_datetime.strftime('%H:%M:%S')}")
    print(f"⏱️ Duration: {duration_str}")
    print(f"📈 Avg per image: {duration/total_files:.2f}s")
    print(f"="*50)

    return results

def copy_results_to_folder(results, output_folder):
    """Copy every matched image into a new, uniquely named sub-folder.

    The sub-folder is ``search_<YYYYmmdd_HHMMSS>``. A ``_1``, ``_2``, ...
    suffix is added when that name already exists, so two calls in the same
    second (e.g. consecutive batch queries) never mix or overwrite each
    other's files. Each copy is named
    ``<rank>_<original stem>_sim<max score><original suffix>``.

    Args:
        results: Match dicts with ``image_path`` and ``scores`` keys.
        output_folder: Parent folder; created (with parents) if missing.

    Returns:
        ``(search_folder, copied_files)``: the folder created for this call
        and the destination paths actually written. Returns ``(None, [])``
        when ``results`` is empty. Per-file copy errors are printed and
        skipped.
    """
    if not results:
        return None, []

    from datetime import datetime

    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    base_name = f"search_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    search_folder = output_path / base_name
    collision = 0
    while True:
        try:
            # mkdir without exist_ok claims the name atomically.
            search_folder.mkdir()
            break
        except FileExistsError:
            collision += 1
            search_folder = output_path / f"{base_name}_{collision}"

    copied_files = []

    for rank, result in enumerate(results, start=1):
        source_path = None  # keeps the error message valid if the lookup fails
        try:
            source_path = Path(result['image_path'])
            scores = result['scores']
            max_score = float(scores.max()) if len(scores) > 0 else 0.0
            filename = f"{rank:03d}_{source_path.stem}_sim{max_score:.2f}{source_path.suffix}"
            dest_path = search_folder / filename

            shutil.copy2(source_path, dest_path)
            copied_files.append(dest_path)

        except Exception as e:
            print(f"Error copying {source_path}: {e}")

    print(f"📋 Copied {len(copied_files)} files to {search_folder}")
    return search_folder, copied_files

# Cell 5: Interactive Interface
def create_search_interface():
    """Build the interactive search UI and return its root widget.

    The UI has a model dropdown, a query box, a threshold slider and
    Search / Copy Results / Switch Model buttons.

    The active model is kept in closure state, seeded from the module-level
    ``model``, ``preprocess``, ``device`` and ``model_name``. A successful
    switch also updates those module-level names. Other cells (e.g.
    ``run_batch_analysis``) then use the same model, and the previous one is
    no longer referenced.

    Returns:
        ``widgets.VBox`` with the controls and an output area.
    """

    # Pre-select the model that is actually loaded (then CONFIG, then the
    # recommended default) instead of always showing ViT-B/32.
    initial_model = next(
        (key for key in (model_name, CONFIG['MODEL_NAME']) if key in AVAILABLE_MODELS),
        'ViT-B/32'
    )

    # Widgets
    model_selector = widgets.Dropdown(
        options=[(f"{info['name']} - {info['description']}", key)
                for key, info in AVAILABLE_MODELS.items()],
        value=initial_model,
        description='Model:',
        layout=widgets.Layout(width='400px')
    )

    query_input = widgets.Text(
        value='person with weapon',
        placeholder='Enter search query...',
        description='Query:',
        layout=widgets.Layout(width='400px')
    )

    threshold_slider = widgets.FloatSlider(
        value=CONFIG['SIMILARITY_THRESHOLD'],
        min=0.1,
        max=0.8,
        step=0.05,
        description='Threshold:',
        readout_format='.2f'
    )

    search_button = widgets.Button(
        description='🔍 Search',
        button_style='primary',
        layout=widgets.Layout(width='120px')
    )

    copy_button = widgets.Button(
        description='📋 Copy Results',
        button_style='success',
        layout=widgets.Layout(width='120px'),
        disabled=True
    )

    switch_button = widgets.Button(
        description='🔄 Switch Model',
        button_style='info',
        layout=widgets.Layout(width='120px')
    )

    output_area = widgets.Output()

    # State
    search_results = []
    current_model = model
    current_preprocess = preprocess
    current_device = device
    current_model_name = model_name

    def on_switch_model(b):
        # Publish the new model to the notebook-level names too, so batch
        # processing uses it and the old model can be garbage-collected.
        global model, preprocess, device, model_name
        nonlocal current_model, current_preprocess, current_device, current_model_name
        with output_area:
            selected = model_selector.value
            if current_model is not None and selected == current_model_name:
                print(f"ℹ️ {selected} is already loaded")
                return

            new_model, new_preprocess, new_device, new_name = switch_model(selected)
            if new_model is None:
                print("❌ Model switch failed")
                return

            current_model, current_preprocess = new_model, new_preprocess
            current_device, current_model_name = new_device, new_name
            model, preprocess, device, model_name = new_model, new_preprocess, new_device, new_name
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # return the old model's cached blocks
            print("✅ Model switched!")

    def on_search(b):
        nonlocal search_results
        with output_area:
            clear_output(wait=True)

            if current_model is None:
                print("❌ No model loaded")
                return

            query = query_input.value.strip()
            if not query:
                print("⚠️ Enter a query")
                return

            CONFIG['SIMILARITY_THRESHOLD'] = threshold_slider.value

            # Block repeated clicks while the (slow) search runs.
            search_button.disabled = True
            copy_button.disabled = True
            try:
                search_results = search_images_with_query(
                    query, current_model, current_preprocess, current_device,
                    current_model_name, CONFIG['SUSPECTS_GALLERY_PATH']
                )
            finally:
                search_button.disabled = False

            if search_results:
                copy_button.disabled = False
                display_search_results(search_results[:CONFIG['MAX_RESULTS_DISPLAY']])
                if len(search_results) > CONFIG['MAX_RESULTS_DISPLAY']:
                    print(f"\nShowing {CONFIG['MAX_RESULTS_DISPLAY']} of {len(search_results)} results")
            else:
                print("🔍 No matches found")
                copy_button.disabled = True

    def on_copy(b):
        with output_area:
            if search_results:
                folder, files = copy_results_to_folder(search_results, CONFIG['RESULTS_OUTPUT_PATH'])
                print(f"✅ Results saved to: {folder}")
            else:
                print("⚠️ No results to copy")

    # Connect events
    switch_button.on_click(on_switch_model)
    search_button.on_click(on_search)
    copy_button.on_click(on_copy)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>🔍 Forensic Image Search (OpenAI CLIP)</h3>"),
        model_selector,
        query_input,
        threshold_slider,
        widgets.HBox([search_button, copy_button, switch_button]),
        widgets.HTML("<hr>")
    ])

    return widgets.VBox([controls, output_area])

def display_search_results(results):
    """Plot result images with their matching patch boxes and scores.

    Results are laid out two per row. Each red rectangle is a patch that
    passed the similarity threshold, labelled with its score. Unused grid
    cells are blanked.

    Args:
        results: Match dicts as produced by ``search_single_image`` (keys
            ``image``, ``boxes``, ``scores``, ``image_path``). Nothing is
            drawn for an empty list.
    """
    if not results:
        return

    cols = 2
    rows = (len(results) + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # result (a 1x2 grid) handed the whole array to imshow and crashed.
    _, axes = plt.subplots(rows, cols, figsize=CONFIG['FIGURE_SIZE'], squeeze=False)
    flat_axes = axes.ravel()

    for ax, result in zip(flat_axes, results):
        ax.imshow(result['image'])

        boxes = result['boxes']
        scores = result['scores']

        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = (float(v) for v in box)

            rect = patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)

            # Score label just above the box's top-left corner
            ax.text(x1, y1 - 5, f'{float(score):.2f}',
                   color='red', fontweight='bold', fontsize=9)

        ax.set_title(f"{result['image_path'].name}\nMatches: {len(boxes)}", fontsize=10)
        ax.axis('off')

    # Hide empty subplots
    for ax in flat_axes[len(results):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Display interface
# Build the widget tree once and render it in this cell's output.
interface = create_search_interface()
display(interface)

# Cell 6: Batch Processing
def run_batch_analysis(custom_queries=None, include_defaults=True):
    """Run the gallery search for several queries and save each match set.

    Uses the notebook-level ``model``, ``preprocess``, ``device`` and
    ``model_name``. Each query with matches gets its own results sub-folder
    via ``copy_results_to_folder``.

    Args:
        custom_queries: Extra query strings appended after the defaults. A
            single string is accepted as one query. Blank entries and
            duplicates are dropped, keeping first-seen order.
        include_defaults: When False, only ``custom_queries`` are run.

    Returns:
        Dict mapping each query to ``{'results', 'count', 'time'}``. The dict
        is empty when no model is loaded or no queries remain.
    """
    import time

    # Default queries
    default_queries = [
        "person with weapon",
        "weapon",
        "gun",
        "knife",
        "suspicious person",
        "vehicle",
        "mask",
        "backpack"
    ]

    if isinstance(custom_queries, str):
        custom_queries = [custom_queries]  # list("gun") would split into letters
    requested = (default_queries if include_defaults else []) + list(custom_queries or [])
    # Results are keyed by query, so a duplicate would be searched twice but
    # reported once; dict.fromkeys de-duplicates while preserving order.
    queries = list(dict.fromkeys(q.strip() for q in requested if q and q.strip()))

    if model is None:
        print("❌ No model loaded - run Cell 3 (or switch model in the interface) first")
        return {}
    if not queries:
        print("⚠️ No queries to run")
        return {}

    print(f"🚀 Starting batch analysis with {len(queries)} queries")
    batch_start = time.time()

    all_results = {}

    for i, query in enumerate(queries, 1):
        print(f"\n[{i}/{len(queries)}] Processing: '{query}'")

        query_start = time.time()
        results = search_images_with_query(
            query, model, preprocess, device, model_name,
            CONFIG['SUSPECTS_GALLERY_PATH']
        )
        query_time = time.time() - query_start

        all_results[query] = {
            'results': results,
            'count': len(results),
            'time': query_time
        }

        if results:
            folder, files = copy_results_to_folder(results, CONFIG['RESULTS_OUTPUT_PATH'])
            print(f"📁 Saved to: {folder.name}")

    # Summary
    batch_time = time.time() - batch_start
    total_matches = sum(data['count'] for data in all_results.values())

    print(f"\n{'='*50}")
    print(f"📊 BATCH SUMMARY")
    print(f"{'='*50}")

    for query, data in all_results.items():
        status = "✅" if data['count'] > 0 else "⚪"
        print(f"{status} '{query}': {data['count']} matches ({data['time']:.1f}s)")

    print(f"\n🎯 Total: {total_matches} matches in {batch_time/60:.1f} minutes")
    return all_results

# Cell 7: Usage Guide
# Prints a quick reference; the model list mirrors AVAILABLE_MODELS in Cell 1.
print("""
🎯 FORENSIC IMAGE SEARCH - OPENAI CLIP
=====================================

📋 SETUP:
1. Place images in './suspects_gallery' folder
2. Run cells 1-5 in order
3. Use the interface above to search

🔍 AVAILABLE MODELS:
• RN50: Fastest
• RN101: Balanced speed and accuracy
• ViT-B/32: Recommended balance
• ViT-B/16: Higher accuracy
• ViT-L/14: Best accuracy, slowest

🎛️ PARAMETERS:
• Similarity Threshold: 0.1-0.8 (start with 0.25)
• Lower = more results, higher = more precise

🔍 GOOD QUERIES:
• "weapon", "gun", "knife"
• "person", "suspicious person"
• "vehicle", "car"
• "mask", "backpack"

💡 TIPS:
• Use simple, clear terms
• Try different thresholds
• Check timing info for optimization
• Use batch processing for multiple queries:
  batch_results = run_batch_analysis()

🚨 TROUBLESHOOTING:
• Restart kernel if installation fails
• Lower threshold if no matches
• Use RN50 model for speed
• Check image folder path
""")

# Uncomment to run batch analysis:
# batch_results = run_batch_analysis()