In [None]:
# Cell 1: Configuration and Setup
import os
import shutil
from pathlib import Path

import ipywidgets as widgets
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output, display

# Configuration Settings
# Single module-level settings table read by every later cell.  It is built
# with keyword arguments so each key is written once, as an identifier.
CONFIG = dict(
    # --- Paths ---
    SUSPECTS_GALLERY_PATH='../../datasets/images/objects/raw',  # folder scanned for input images
    RESULTS_OUTPUT_PATH='../../datasets/images/objects/detections',  # folder matched images are copied to
    MODEL_CHECKPOINT='../GroundingDINO/weights/groundingdino_swint_ogc.pth',  # checkpoint given to torch.load
    MODEL_CONFIG='../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',  # file given to SLConfig.fromfile

    # --- Detection parameters ---
    CONFIDENCE_THRESHOLD=0.35,  # detections at or below this score are dropped after predict()
    BOX_THRESHOLD=0.3,          # passed to predict() as box_threshold
    TEXT_THRESHOLD=0.25,        # passed to predict() as text_threshold

    # --- Display settings ---
    MAX_RESULTS_DISPLAY=10,     # how many matches the interface plots at once
    FIGURE_SIZE=(12, 8),        # figsize handed to plt.subplots
)

print("✅ Configuration loaded successfully")
print(f"📁 Suspects gallery: {CONFIG['SUSPECTS_GALLERY_PATH']}")
print(f"📁 Results output: {CONFIG['RESULTS_OUTPUT_PATH']}")

# Cell 2: Install and Import Dependencies
# Probes for the groundingdino package and installs the stack only when the
# import fails.  The `!pip ...` lines are IPython shell escapes, so this cell
# only runs inside a notebook/IPython session, not as a plain .py script.
# NOTE: `torch` was already imported in Cell 1, before any install happens here.
try:
    import groundingdino
    print("✅ GroundingDINO already installed")
except ImportError:
    print("⚠️ Installing GroundingDINO and dependencies...")
    !pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu  # CPU-only version
    !pip install groundingdino-py
    !pip install supervision
    !pip install transformers
    !pip install ipywidgets
    print("📦 Installation complete - using CPU-optimized PyTorch")

# Additional imports after installation
# These names (build_model, SLConfig, clean_state_dict, load_image, predict)
# are used by the model-loading and search functions defined in later cells.
# NOTE(review): get_phrases_from_posmap, annotate and `sv` are not referenced
# in the visible part of this notebook - confirm before removing.
try:
    from groundingdino.models import build_model
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from groundingdino.util.inference import annotate, load_image, predict
    import supervision as sv
    print("✅ All dependencies imported successfully")

    # Check PyTorch device compatibility
    # Informational only: nothing is stored here; the device is chosen later
    # by check_device_compatibility().
    print(f"🔧 PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"🚀 CUDA available: {torch.cuda.get_device_name(0)}")
    else:
        print("🖥️ Using CPU mode (CUDA not available)")

except ImportError as e:
    # The cell does not re-raise: a failed import only prints guidance, so the
    # later cells will hit NameError on the missing names.
    print(f"❌ Import error: {e}")
    print("🔧 Troubleshooting steps:")
    print("1. Restart kernel and run this cell again")
    print("2. Check if all packages installed correctly")
    print("3. Try installing dependencies manually:")
    print("   !pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu")
    print("   !pip install groundingdino-py supervision transformers")

# Cell 3: Initialize Model and Directories
def setup_directories():
    """Make sure the gallery and results folders named in CONFIG exist."""
    for config_key in ('SUSPECTS_GALLERY_PATH', 'RESULTS_OUTPUT_PATH'):
        # exist_ok=True keeps re-running the cell harmless.
        os.makedirs(CONFIG[config_key], exist_ok=True)
    print(f"📁 Created directories: {CONFIG['SUSPECTS_GALLERY_PATH']}, {CONFIG['RESULTS_OUTPUT_PATH']}")

def check_device_compatibility():
    """Report the PyTorch/CUDA setup and return the device name to use.

    Returns:
        'cuda' when torch reports CUDA as available, otherwise 'cpu'.
    """
    print("🔧 Checking device compatibility...")

    # PyTorch build in use
    print(f"PyTorch version: {torch.__version__}")

    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")

    # Guard clause: the CPU path needs no further probing.
    if not has_cuda:
        print("⚠️ CUDA not available - using CPU mode")
        print("Note: CPU inference will be slower but should work fine")
        return 'cpu'

    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    return 'cuda'

def load_groundingdino_model():
    """Build the GroundingDINO network from CONFIG and load its checkpoint.

    Returns:
        The model in eval mode on the detected device, with the device name
        stored on ``model.device``; ``None`` when anything goes wrong (the
        failure is printed, never raised).
    """
    try:
        # Pick cuda/cpu before building anything.
        device = check_device_compatibility()

        # Model definition comes from the config file; weights are loaded below.
        model_args = SLConfig.fromfile(CONFIG['MODEL_CONFIG'])
        model_args.device = device
        net = build_model(model_args)

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        print("📥 Loading model checkpoint...")
        state = torch.load(CONFIG['MODEL_CHECKPOINT'], map_location=device)
        net.load_state_dict(clean_state_dict(state['model']), strict=False)

        net = net.to(device)
        net.eval()

        # search_images_with_query reads this attribute back via getattr().
        net.device = device

        print(f"✅ Model loaded successfully on {device}")
        return net

    except FileNotFoundError as e:
        print(f"❌ Model files not found: {e}")
        print("Please ensure model files are available:")
        print(f"  - {CONFIG['MODEL_CONFIG']}")
        print(f"  - {CONFIG['MODEL_CHECKPOINT']}")
        for hint in (
            "\nTo download GroundingDINO model files:",
            "1. Visit: https://github.com/IDEA-Research/GroundingDINO",
            "2. Download the pre-trained weights",
            "3. Update the CONFIG paths accordingly",
        ):
            print(hint)
        return None

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        for hint in (
            "\nTroubleshooting steps:",
            "1. Check if all dependencies are installed",
            "2. Verify model file paths are correct",
            "3. Ensure sufficient memory available",
        ):
            print(hint)
        return None

# Initialize
# Folders are created first, then the model is loaded.  `model` is the
# module-level handle the later cells read; load_groundingdino_model()
# returns None on failure, which those cells test for before searching.
setup_directories()
model = load_groundingdino_model()

# Cell 4: Core Search Functions
def process_image_batch(image_batch, model, query, device):
    """Run the text-prompted detector over pre-loaded images and keep confident hits.

    Args:
        image_batch: iterable of ``(img_path, image_source, image)`` tuples as
            produced by ``search_images_with_query`` (path, the first value
            returned by ``load_image``, the second value returned by it).
        model: detector handed to ``predict``.
        query: caption / natural-language prompt handed to ``predict``.
        device: device name the image is moved to and that is passed to
            ``predict``.

    Returns:
        A list with one dict per image that has at least one detection scoring
        above ``CONFIG['CONFIDENCE_THRESHOLD']``.  Each dict has the keys
        ``image_path``, ``image_source``, ``boxes``, ``confidence_scores``,
        ``phrases`` and ``query``.  Images that fail are reported and skipped.
    """
    batch_results = []

    for img_path, image_source, image in image_batch:
        try:
            # Keep the input on the same device as the model.
            if hasattr(image, 'to'):
                image = image.to(device)

            # Inference only: no gradients needed, which saves memory.
            with torch.no_grad():
                boxes, logits, phrases = predict(
                    model=model,
                    image=image,
                    caption=query,
                    box_threshold=CONFIG['BOX_THRESHOLD'],
                    text_threshold=CONFIG['TEXT_THRESHOLD'],
                    device=device
                )

            if len(logits) == 0:
                continue

            # Second, stricter filter on top of predict()'s own thresholds.
            keep = logits > CONFIG['CONFIDENCE_THRESHOLD']
            if not keep.any():
                continue

            batch_results.append({
                'image_path': img_path,
                'image_source': image_source,
                'boxes': boxes[keep],
                'confidence_scores': logits[keep],
                'phrases': [phrase for phrase, kept in zip(phrases, keep) if kept],
                'query': query
            })

        except Exception as e:
            # Best effort: one bad image must not abort the whole search, but
            # the failure is reported instead of being silently dropped.
            print(f"⚠️ Error processing {img_path}: {e}")
            continue

    return batch_results

def search_images_with_query(query, model, gallery_path, batch_size=8):
    """
    Search every image in a gallery folder for objects matching a text query.

    Args:
        query: natural-language prompt forwarded to ``process_image_batch``.
        model: loaded detector; ``None`` means "not loaded" and yields ``[]``.
        gallery_path: folder (str or Path) scanned non-recursively for images.
        batch_size: number of images loaded per round; values below 1 are
            treated as 1.

    Returns:
        List of per-image result dicts from ``process_image_batch`` (only
        images with matches).  Empty when the gallery is missing, is not a
        directory, holds no supported images, or the model is not loaded.
    """
    results = []
    gallery_path = Path(gallery_path)

    if not gallery_path.exists():
        print(f"❌ Gallery path {gallery_path} does not exist")
        return results

    # A plain file would make iterdir() raise NotADirectoryError below.
    if not gallery_path.is_dir():
        print(f"❌ Gallery path {gallery_path} is not a directory")
        return results

    # Explicit None check: truthiness of an arbitrary model object is not a
    # reliable "is it loaded" signal.
    if model is None:
        print("❌ Model not loaded. Please check model initialization.")
        return results

    # range() needs a positive step.
    batch_size = max(1, int(batch_size))

    # Supported image extensions
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    # Regular files only (a sub-folder named "x.jpg" is not an image); sorted
    # so the result order is reproducible between runs.
    image_files = sorted(
        f for f in gallery_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    )

    if not image_files:
        print(f"⚠️ No images found in {gallery_path}")
        return results

    # load_groundingdino_model() stores the device on the model object.
    device = getattr(model, 'device', 'cpu')
    total_files = len(image_files)
    processed_count = 0  # images that could be loaded
    error_count = 0      # images that failed to load

    print(f"🔍 Processing {total_files} images for: '{query}'")
    print(f"🖥️ Device: {device} | Batch size: {batch_size}")

    # Process images in batches
    for i in range(0, total_files, batch_size):
        batch_files = image_files[i:i + batch_size]
        batch_data = []

        # Load batch of images
        for img_path in batch_files:
            try:
                image_source, image = load_image(str(img_path))
            except Exception:
                error_count += 1
                continue
            batch_data.append((img_path, image_source, image))
            processed_count += 1

        # Process the batch
        if batch_data:
            results.extend(process_image_batch(batch_data, model, query, device))

        # Progress update
        progress = min(i + batch_size, total_files)
        print(f"📊 Progress: {progress}/{total_files} | Matches: {len(results)}")

        # Clear cache periodically (every third batch)
        if torch.cuda.is_available() and i % (batch_size * 3) == 0:
            torch.cuda.empty_cache()

    # Final summary
    print(f"\n✅ Complete: {processed_count} processed, {error_count} errors, {len(results)} matches")
    return results

def copy_results_to_folder(results, output_folder):
    """Copy every matched image into a fresh timestamped sub-folder.

    Args:
        results: list of result dicts; each needs ``image_path`` (str or Path)
            and ``confidence_scores`` (an array/tensor offering ``.max()``).
        output_folder: base folder (str or Path); created, including missing
            parents, when absent.

    Returns:
        ``(search_folder, copied_files)`` - the ``search_<timestamp>`` folder
        and the list of destination paths that were actually written.
        Records that cannot be copied are reported and skipped.
    """
    from datetime import datetime

    output_path = Path(output_folder)

    # Create subfolder with timestamp; parents=True also creates a nested or
    # not-yet-existing output folder instead of raising FileNotFoundError.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    search_folder = output_path / f"search_{timestamp}"
    search_folder.mkdir(parents=True, exist_ok=True)

    copied_files = []

    for index, result in enumerate(results, start=1):
        # Pre-bind so the error message below never hits an unbound name when
        # the record itself is malformed.
        source_path = None
        try:
            source_path = Path(result['image_path'])
            scores = result['confidence_scores']
            max_conf = float(scores.max()) if len(scores) > 0 else 0.0

            # Descriptive filename: running number, original stem, best score.
            filename = f"{index:03d}_{source_path.stem}_conf{max_conf:.2f}{source_path.suffix}"
            dest_path = search_folder / filename

            shutil.copy2(source_path, dest_path)
            copied_files.append(dest_path)

        except Exception as e:
            label = source_path if source_path is not None else f"result {index}"
            print(f"❌ Error copying {label}: {e}")

    print(f"📋 Copied {len(copied_files)} files to {search_folder}")
    return search_folder, copied_files

# Cell 5: Interactive Query Interface
def create_search_interface():
    """Assemble the ipywidgets search panel and return it as one VBox.

    The panel holds a query box, confidence and batch-size sliders, three
    buttons (search / copy / clear) and an output area.  The handlers read the
    module-level ``model`` and ``CONFIG``; the search handler writes the
    slider value back into ``CONFIG['CONFIDENCE_THRESHOLD']``.
    """

    def make_button(text, kind, **extra):
        # All three buttons share the same fixed width.
        return widgets.Button(
            description=text,
            button_style=kind,
            layout=widgets.Layout(width='150px'),
            **extra
        )

    # --- Inputs -----------------------------------------------------------
    query_input = widgets.Text(
        value='person with red shirt',
        placeholder='Enter your search query (e.g., "person with weapon", "suspicious vehicle")',
        description='Query:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='500px')
    )

    confidence_slider = widgets.FloatSlider(
        value=CONFIG['CONFIDENCE_THRESHOLD'],
        min=0.1,
        max=0.9,
        step=0.05,
        description='Min Confidence:',
        style={'description_width': 'initial'}
    )

    batch_size_slider = widgets.IntSlider(
        value=8,
        min=1,
        max=16,
        step=1,
        description='Batch Size:',
        style={'description_width': 'initial'}
    )

    # --- Buttons ----------------------------------------------------------
    search_button = make_button('🔍 Search Gallery', 'primary')
    # Copy stays disabled until a search has produced matches.
    copy_button = make_button('📋 Copy Results', 'success', disabled=True)
    clear_button = make_button('🗑️ Clear Results', 'warning')

    # Everything the handlers print or plot lands in this widget.
    output_area = widgets.Output()

    # Matches of the latest search, shared between the three handlers.
    state = {'results': []}

    def run_search(_button):
        with output_area:
            clear_output(wait=True)

            if not model:
                print("❌ Model not loaded. Please run the model initialization cell first.")
                return

            query = query_input.value.strip()
            if not query:
                print("⚠️ Please enter a search query")
                return

            # The slider value becomes the global threshold used by the search.
            CONFIG['CONFIDENCE_THRESHOLD'] = confidence_slider.value
            batch_size = batch_size_slider.value

            print(f"🚀 Starting search: '{query}' | Confidence: {CONFIG['CONFIDENCE_THRESHOLD']:.2f} | Batch: {batch_size}")
            print("-" * 50)

            found = search_images_with_query(
                query, model, CONFIG['SUSPECTS_GALLERY_PATH'], batch_size
            )
            state['results'] = found

            if not found:
                print("🔍 No matches found for your query")
                copy_button.disabled = True
                return

            copy_button.disabled = False
            display_results(found[:CONFIG['MAX_RESULTS_DISPLAY']])

            if len(found) > CONFIG['MAX_RESULTS_DISPLAY']:
                print(f"\n📝 Showing first {CONFIG['MAX_RESULTS_DISPLAY']} results out of {len(found)} total matches")

    def copy_matches(_button):
        with output_area:
            if not state['results']:
                print("⚠️ No results to copy")
                return
            print("\n📋 Copying results to output folder...")
            folder, _files = copy_results_to_folder(state['results'], CONFIG['RESULTS_OUTPUT_PATH'])
            print(f"✅ Results saved to: {folder}")

    def reset_results(_button):
        state['results'] = []
        copy_button.disabled = True
        with output_area:
            clear_output()
            print("🗑️ Results cleared")

    # --- Wiring -----------------------------------------------------------
    for button, handler in (
        (search_button, run_search),
        (copy_button, copy_matches),
        (clear_button, reset_results),
    ):
        button.on_click(handler)

    # --- Layout -----------------------------------------------------------
    slider_row = widgets.HBox([confidence_slider, batch_size_slider])
    button_row = widgets.HBox([search_button, copy_button, clear_button])
    controls = widgets.VBox([
        widgets.HTML("<h3>🔍 Forensic Image Search Interface</h3>"),
        query_input,
        slider_row,
        button_row,
        widgets.HTML("<hr>")
    ])

    return widgets.VBox([controls, output_area])

def display_results(results):
    """Plot each result image in a two-column grid with its detection boxes.

    Args:
        results: list of result dicts as built by ``process_image_batch``
            (``image_source``, ``boxes``, ``confidence_scores``,
            ``image_path``).  An empty list is a no-op.

    Returns:
        None.  The figure is shown with ``plt.show()``.
    """
    if not results:
        return

    cols = 2
    rows = (len(results) + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so one result or one
    # row needs no special-casing; flatten gives a simple 1-D sequence.
    _fig, axes = plt.subplots(rows, cols, figsize=CONFIG['FIGURE_SIZE'], squeeze=False)
    axes = axes.flatten()

    for ax, result in zip(axes, results):
        # Display image
        image = np.array(result['image_source'])
        ax.imshow(image)

        img_h, img_w = image.shape[:2]
        boxes = result['boxes']
        confidences = result['confidence_scores']

        for box, conf in zip(boxes, confidences):
            # GroundingDINO's predict() returns boxes as normalized
            # (center_x, center_y, width, height); convert to the pixel
            # top-left corner + size that Rectangle expects.
            cx, cy, bw, bh = (float(v) for v in box)
            x1 = (cx - bw / 2) * img_w
            y1 = (cy - bh / 2) * img_h

            rect = patches.Rectangle(
                (x1, y1), bw * img_w, bh * img_h,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)

            # Confidence label just above the box
            ax.text(x1, y1 - 5, f'{float(conf):.2f}',
                    color='red', fontweight='bold', fontsize=10)

        ax.set_title(f"{result['image_path'].name}\nMatches: {len(boxes)}", fontsize=10)
        ax.axis('off')

    # Hide the unused cell when the number of results is odd
    for spare_ax in axes[len(results):]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Display the interface
# Build the widget panel once and render it in the cell output; `interface`
# stays bound at module level.
interface = create_search_interface()
display(interface)

# Cell 6: Batch Processing Functions (Enhanced)
def batch_search_multiple_queries(queries_list, model, gallery_path, output_base_path, batch_size=8):
    """
    Run several text queries against the gallery and save each query's matches.

    For every query the gallery is searched; when there are matches they are
    copied into ``<output_base_path>/query_<query with ' ' and '/' replaced
    by '_'>/``.  A summary table is printed at the end.

    Returns:
        dict mapping each query to a dict with ``results``, ``output_folder``
        (``None`` without matches), ``file_count`` and ``match_count``.
    """
    all_results = {}

    print(f"🚀 Starting batch analysis with {len(queries_list)} queries")
    print(f"📁 Gallery: {gallery_path}")
    print("=" * 60)

    for position, query in enumerate(queries_list, 1):
        print(f"\n[{position}/{len(queries_list)}] Query: '{query}'")

        hits = search_images_with_query(query, model, gallery_path, batch_size)

        # Guard clause: record an empty entry and move on.
        if not hits:
            all_results[query] = {
                'results': [],
                'output_folder': None,
                'file_count': 0,
                'match_count': 0
            }
            print("⚪ No matches found")
            continue

        # Query-specific output folder
        folder_label = query.replace(' ', '_').replace('/', '_')
        query_folder = Path(output_base_path) / f"query_{folder_label}"
        folder, files = copy_results_to_folder(hits, query_folder)
        all_results[query] = {
            'results': hits,
            'output_folder': folder,
            'file_count': len(files),
            'match_count': len(hits)
        }
        print(f"📁 Saved {len(files)} files to: {folder.name}")

    # Summary report
    print("\n" + "="*60)
    print("📊 BATCH SEARCH SUMMARY")
    print("="*60)

    for query, data in all_results.items():
        matches, files = data['match_count'], data['file_count']
        status = "✅" if matches > 0 else "⚪"
        print(f"{status} '{query}': {matches} images, {files} files saved")

    total_matches = sum(entry['match_count'] for entry in all_results.values())
    total_files = sum(entry['file_count'] for entry in all_results.values())

    print(f"\n🎯 TOTAL: {total_matches} matched images, {total_files} files copied")
    return all_results

# Enhanced batch processing with common forensic queries
def run_forensic_batch_analysis(custom_queries=None, batch_size=8):
    """Run the built-in forensic query set, plus optional extras, over the gallery.

    Args:
        custom_queries: optional list of extra queries appended after the
            defaults.
        batch_size: forwarded to ``batch_search_multiple_queries``.

    Returns:
        The dict from ``batch_search_multiple_queries``, or ``None`` when the
        module-level ``model`` is not loaded.
    """

    # Default forensic queries
    default_queries = [
        "person with weapon",
        "person holding gun",
        "person with knife",
        "suspicious vehicle",
        "person wearing mask",
        "person running",
        "bag or backpack",
        "group of people",
        "person with phone",
        "person in dark clothing"
    ]

    # Start from the defaults; extend only when extras were supplied.
    queries = default_queries
    if custom_queries:
        queries = default_queries + custom_queries
        print(f"📋 Using {len(default_queries)} default + {len(custom_queries)} custom queries")
    else:
        print(f"📋 Using {len(default_queries)} default forensic queries")

    # Guard clause: nothing to do without a model.
    if not model:
        print("❌ Model not loaded. Cannot run batch analysis.")
        return None

    print("🔍 Starting comprehensive forensic analysis...")
    return batch_search_multiple_queries(
        queries,
        model,
        CONFIG['SUSPECTS_GALLERY_PATH'],
        CONFIG['RESULTS_OUTPUT_PATH'],
        batch_size
    )

# Quick test with reduced output
def quick_forensic_search(query="person with weapon", batch_size=8):
    """Run one query against the configured gallery as a quick smoke test.

    Returns:
        ``None`` when the module-level ``model`` is not loaded, ``[]`` when
        nothing matched, otherwise the list of result dicts.
    """
    if not model:
        print("❌ Model not loaded")
        return None

    print(f"🔍 Quick search: '{query}'")
    matches = search_images_with_query(query, model, CONFIG['SUSPECTS_GALLERY_PATH'], batch_size)

    # Handle the empty case first so the success path reads straight through.
    if not matches:
        print("⚪ No matches found")
        return []

    print(f"📋 Found {len(matches)} matches - ready for detailed analysis")
    return matches

# Uncomment to run batch analysis
# batch_results = run_forensic_batch_analysis(batch_size=8)

# Uncomment for quick test
# quick_results = quick_forensic_search("person holding gun", batch_size=8)

# Cell 7: Usage Instructions and Tips
# Usage guide.  The folder names are taken from CONFIG so the printed
# instructions always match the paths the search and copy functions use
# (the gallery folder for input, the results folder for output).
print(f"""
🎯 FORENSIC IMAGE SEARCH SYSTEM - USAGE GUIDE
============================================

📋 SETUP CHECKLIST:
1. ✅ Place suspect images in the gallery folder: {CONFIG['SUSPECTS_GALLERY_PATH']}
2. ✅ Ensure GroundingDINO model files are available
3. ✅ Run all cells in order (1-6)

🔧 CUDA/CPU COMPATIBILITY:
• System automatically detects GPU/CPU availability
• CPU mode: Slower but works on all systems
• GPU mode: Faster but requires CUDA-compatible PyTorch
• If getting CUDA errors, the system will fallback to CPU mode

🔍 SEARCH TIPS:
• Use specific, descriptive queries: "person with red jacket" instead of just "person"
• Try multiple variations: "weapon", "gun", "knife", "suspicious object"
• Adjust confidence threshold based on your needs (lower = more results, higher = more precise)
• Common forensic queries:
  - "person with weapon"
  - "suspicious vehicle"
  - "person wearing mask"
  - "bag or backpack"
  - "person running"
  - "group of people"

⚙️ CONFIGURATION:
• Modify CONFIG dictionary in Cell 1 to adjust paths and parameters
• Confidence threshold: 0.35 (recommended for forensic work)
• Results are automatically timestamped and organized

📁 OUTPUT STRUCTURE:
{CONFIG['RESULTS_OUTPUT_PATH']}/
├── search_20240611_143022/
│   ├── 001_suspect1_conf0.87.jpg
│   ├── 002_suspect5_conf0.72.jpg
│   └── ...

🚨 TROUBLESHOOTING:
• "CUDA not enabled" → System will auto-switch to CPU mode
• "Model not loaded" → Check model file paths in CONFIG or download model files
• "No images found" → Verify images are in {CONFIG['SUSPECTS_GALLERY_PATH']}
• "Out of memory" → Reduce batch size or use CPU mode
• Low confidence scores → Try different query phrasing or lower threshold

📥 MODEL FILES:
If model files are missing, download from:
• GroundingDINO GitHub: https://github.com/IDEA-Research/GroundingDINO
• Pre-trained weights and config files needed

For batch processing of multiple queries, use the functions in Cell 6.
""")

In [1]:
# Cell 1: Configuration and Setup
import os
import shutil
import numpy as np
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Configuration Settings
# Single module-level settings table for the Hugging Face based pipeline,
# built with keyword arguments so each key is written once as an identifier.
CONFIG = dict(
    # --- Model selection (Hugging Face) ---
    MODEL_NAME='IDEA-Research/grounding-dino-base',  # one of the model_id values in AVAILABLE_MODELS
    DEVICE='auto',  # 'auto', 'cpu' or 'cuda'; Cell 2 replaces 'auto' with the detected device

    # --- Paths ---
    SUSPECTS_GALLERY_PATH='../../datasets/images/objects/raw',  # folder scanned for input images
    RESULTS_OUTPUT_PATH='../../datasets/images/objects/detections',  # folder matched images are copied to

    # --- Detection parameters ---
    CONFIDENCE_THRESHOLD=0.35,  # scores below this are dropped after post-processing
    BOX_THRESHOLD=0.35,         # passed to the processor's post-processing as box_threshold
    TEXT_THRESHOLD=0.25,        # passed to the processor's post-processing as text_threshold

    # --- Processing / display settings ---
    BATCH_SIZE=8,               # default batch size for processing
    MAX_RESULTS_DISPLAY=10,     # maximum results to display at once
    FIGURE_SIZE=(12, 8),        # size of result visualization
)

# Model variants that switch_model() accepts, keyed by short name.
AVAILABLE_MODELS = {
    'grounding-dino-tiny': dict(
        name='GroundingDINO-Tiny',
        model_id='IDEA-Research/grounding-dino-tiny',
        description='Fastest, smallest model - good for quick testing',
        performance='Lower accuracy, fastest speed',
    ),
    'grounding-dino-base': dict(
        name='GroundingDINO-Base',
        model_id='IDEA-Research/grounding-dino-base',
        description='Best balance of speed and accuracy - recommended',
        performance='52.5 AP on COCO, good speed',
    ),
}

print("✅ Configuration loaded successfully")
print(f"📁 Suspects gallery: {CONFIG['SUSPECTS_GALLERY_PATH']}")
print(f"📁 Results output: {CONFIG['RESULTS_OUTPUT_PATH']}")
print(f"🤖 Selected model: {CONFIG['MODEL_NAME']}")

# Cell 2: Install and Import Dependencies
# Probes for the transformers package and installs the stack only when the
# import fails.  The `!pip ...` lines are IPython shell escapes, so this cell
# only runs inside a notebook/IPython session, not as a plain .py script.
try:
    import transformers
    print("✅ Transformers already installed")
except ImportError:
    print("⚠️ Installing required packages...")
    !pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
    !pip install transformers
    !pip install ipywidgets
    !pip install Pillow
    !pip install matplotlib
    !pip install opencv-python
    print("📦 Installation complete")

# Import required libraries
# torch, AutoProcessor, AutoModelForZeroShotObjectDetection and Image are
# used by the model-loading and search functions defined in later cells.
# NOTE(review): `requests` is not referenced in the visible part of this
# notebook - confirm before removing.
try:
    import torch
    from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
    import requests
    from PIL import Image
    print("✅ All dependencies imported successfully")

    # Check PyTorch device compatibility
    print(f"🔧 PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"🚀 CUDA available: {torch.cuda.get_device_name(0)}")
        default_device = "cuda"
    else:
        print("🖥️ Using CPU mode (CUDA not available)")
        default_device = "cpu"

    # Update config with detected device
    # Only the 'auto' placeholder is replaced; an explicit 'cpu'/'cuda' choice
    # in CONFIG is left untouched.  Because this sits inside the try block,
    # CONFIG['DEVICE'] stays 'auto' if any import above failed.
    if CONFIG['DEVICE'] == 'auto':
        CONFIG['DEVICE'] = default_device
        print(f"📍 Auto-detected device: {CONFIG['DEVICE']}")

except ImportError as e:
    # The cell does not re-raise: a failed import only prints guidance, so the
    # later cells will hit NameError on the missing names.
    print(f"❌ Import error: {e}")
    print("🔧 Troubleshooting steps:")
    print("1. Restart kernel and run this cell again")
    print("2. Check if all packages installed correctly")

# Cell 3: Initialize Model and Directories
def setup_directories():
    """Make sure the gallery and results folders named in CONFIG exist."""
    for config_key in ('SUSPECTS_GALLERY_PATH', 'RESULTS_OUTPUT_PATH'):
        # exist_ok=True keeps re-running the cell harmless.
        os.makedirs(CONFIG[config_key], exist_ok=True)
    print(f"📁 Created directories: {CONFIG['SUSPECTS_GALLERY_PATH']}, {CONFIG['RESULTS_OUTPUT_PATH']}")

def load_grounding_dino_model(model_name=None):
    """Load a GroundingDINO model and its processor from Hugging Face.

    Args:
        model_name: Hugging Face model id; defaults to ``CONFIG['MODEL_NAME']``.

    Returns:
        ``(model, processor, device, model_name)`` with the model in eval mode
        on ``device``, or ``(None, None, None, None)`` when loading fails (the
        failure is printed, never raised).
    """
    try:
        # Use provided model name or default from config
        if model_name is None:
            model_name = CONFIG['MODEL_NAME']

        device = CONFIG['DEVICE']
        # CONFIG documents 'auto' as a valid value, but `.to('auto')` is not a
        # real device - resolve it here in case it was not resolved earlier.
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        elif str(device).startswith('cuda') and not torch.cuda.is_available():
            print("⚠️ CUDA requested but not available - falling back to CPU")
            device = 'cpu'

        print(f"📥 Loading model: {model_name}")
        print(f"🖥️ Target device: {device}")

        # Load processor and model
        print("⏳ Loading processor...")
        processor = AutoProcessor.from_pretrained(model_name)

        print("⏳ Loading model weights...")
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name)

        # Move to device and switch to inference mode
        print(f"📍 Moving model to {device}...")
        model = model.to(device)
        model.eval()

        print("✅ Model loaded successfully!")
        print(f"   📊 Model: {model_name}")
        print(f"   🖥️ Device: {device}")

        return model, processor, device, model_name

    except Exception as e:
        # Top-level boundary for the notebook: report and hand back Nones so
        # the calling cell can keep running and test for a missing model.
        print(f"❌ Error loading model: {e}")
        print("\n🔧 Troubleshooting steps:")
        print("1. Check internet connection (models download from Hugging Face)")
        print("2. Verify model name is correct")
        print("3. Try switching to 'grounding-dino-tiny' for faster loading")
        print("4. Restart kernel if memory issues occur")
        return None, None, None, None

def switch_model(model_key):
    """Select another entry of AVAILABLE_MODELS and load it.

    Args:
        model_key: a key of ``AVAILABLE_MODELS``.

    Returns:
        The tuple from ``load_grounding_dino_model()``, or four ``None``
        values when ``model_key`` is unknown.
    """
    # Guard clause: reject unknown keys before touching CONFIG.
    if model_key not in AVAILABLE_MODELS:
        print(f"❌ Unknown model: {model_key}")
        print(f"Available models: {list(AVAILABLE_MODELS.keys())}")
        return None, None, None, None

    # The loader reads its default model id from CONFIG['MODEL_NAME'].
    CONFIG['MODEL_NAME'] = AVAILABLE_MODELS[model_key]['model_id']
    print(f"🔄 Switched to: {AVAILABLE_MODELS[model_key]['name']}")
    return load_grounding_dino_model()

# Initialize
# Folders are created first, then the model is loaded.  The four module-level
# names are what the later cells pass around; all four are None when
# load_grounding_dino_model() failed.
setup_directories()
model, processor, device, model_name = load_grounding_dino_model()

# Cell 4: Core Search Functions
def process_image_batch(image_paths, model, processor, query, device, batch_size=4):
    """Run zero-shot detection on each image path and keep confident hits.

    Images are processed one at a time to limit memory use.

    Args:
        image_paths: sequence of ``Path`` objects to open.
        model: Hugging Face detection model, called as ``model(**inputs)``.
        processor: matching processor (builds inputs and post-processes).
        query: text prompt; it is lower-cased and dot-terminated before use.
        device: device the inputs are moved to.
        batch_size: accepted for interface compatibility; not used here.

    Returns:
        A list with one dict per image that has at least one detection scoring
        at or above ``CONFIG['CONFIDENCE_THRESHOLD']``.  Each dict has the keys
        ``image_path``, ``image``, ``boxes``, ``confidence_scores``,
        ``labels`` and ``query``.  Images that fail are reported and skipped.
    """
    import time

    batch_results = []

    # Normalize the prompt once (lower-case, trailing period) - it is the
    # same for every image.
    text_query = query.lower()
    if not text_query.endswith('.'):
        text_query += '.'

    for i, img_path in enumerate(image_paths):
        try:
            # Progress indicator
            if i % 5 == 0:
                print(f"Processing {i+1}/{len(image_paths)}: {img_path.name[:30]}...", end='\r')

            # The context manager releases the file handle promptly; the
            # converted RGB image is what is kept and used afterwards.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")

            # Process inputs
            inputs = processor(images=image, text=text_query, return_tensors="pt")
            inputs = inputs.to(device)

            # Run inference without gradient tracking
            with torch.no_grad():
                outputs = model(**inputs)

            # Post-process results
            results = processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=CONFIG['BOX_THRESHOLD'],
                text_threshold=CONFIG['TEXT_THRESHOLD'],
                target_sizes=[image.size[::-1]]  # (height, width)
            )

            # Filter by confidence threshold
            if results and len(results) > 0:
                result = results[0]  # First (and only) image in batch

                if 'scores' in result and len(result['scores']) > 0:
                    keep = result['scores'] >= CONFIG['CONFIDENCE_THRESHOLD']

                    if keep.any():
                        # Use text_labels if available, otherwise fall back to labels
                        label_key = 'text_labels' if 'text_labels' in result else 'labels'
                        filtered_labels = [
                            label for label, kept in zip(result[label_key], keep) if kept
                        ]

                        batch_results.append({
                            'image_path': img_path,
                            'image': image,
                            'boxes': result['boxes'][keep],
                            'confidence_scores': result['scores'][keep],
                            'labels': filtered_labels,
                            'query': query
                        })

            # Drop the local references after each image; a matched image
            # stays alive through the result dict.
            del inputs, outputs, results, image
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            # Print error but continue processing
            print(f"\n⚠️ Error processing {img_path.name}: {str(e)[:50]}...")
            continue

        # Small break every 10 images to prevent system overload
        if i % 10 == 0 and i > 0:
            time.sleep(0.1)

    return batch_results

def search_images_with_query(query, model, processor, device, model_name, gallery_path, batch_size=4):
    """
    Search every image in a gallery folder for objects matching a text query.

    Args:
        query: natural-language prompt forwarded to ``process_image_batch``.
        model: loaded detection model; ``None`` means "not loaded".
        processor: matching processor; ``None`` means "not loaded".
        device: device name forwarded to ``process_image_batch``.
        model_name: only used in the progress output.
        gallery_path: folder (str or Path) scanned non-recursively for images.
        batch_size: forwarded to ``process_image_batch``.

    Returns:
        List of per-image result dicts (only images with matches).  Empty when
        the gallery is missing, is not a directory, holds no supported images,
        the model/processor is not loaded, or processing failed as a whole.
    """
    results = []
    gallery_path = Path(gallery_path)

    if not gallery_path.exists():
        print(f"❌ Gallery path {gallery_path} does not exist")
        return results

    # A plain file would make iterdir() raise NotADirectoryError below.
    if not gallery_path.is_dir():
        print(f"❌ Gallery path {gallery_path} is not a directory")
        return results

    # Explicit None checks: truthiness of arbitrary model/processor objects
    # is not a reliable "is it loaded" signal.
    if model is None or processor is None:
        print("❌ Model or processor not loaded. Please check model initialization.")
        return results

    # Supported image extensions
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    # Regular files only (a sub-folder named "x.jpg" is not an image); sorted
    # so the result order is reproducible between runs.
    image_files = sorted(
        f for f in gallery_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    )

    if not image_files:
        print(f"⚠️ No images found in {gallery_path}")
        return results

    total_files = len(image_files)

    print(f"🔍 Processing {total_files} images for: '{query}'")
    print(f"🖥️ Device: {device} | Processing one by one for stability")
    print(f"🤖 Model: {model_name}")

    # Process images one by one with progress tracking
    try:
        print("⏳ Starting image processing...")
        results = process_image_batch(image_files, model, processor, query, device, batch_size)

        # Show final progress
        print(f"\n📊 Final: {total_files}/{total_files} processed | {len(results)} matches found")

    except Exception as e:
        # Boundary for the interactive UI: report and return what we have.
        print(f"❌ Processing error: {e}")
        print("💡 Try reducing batch size or switching to grounding-dino-tiny")

    # Final summary
    print(f"✅ Complete: {len(results)} images with matches found")
    return results

def copy_results_to_folder(results, output_folder):
    """Copy matched images into a new timestamped subfolder of ``output_folder``.

    Args:
        results: Iterable of result dicts. Each needs ``'image_path'`` (path of
            the source image, ``Path`` or string) and ``'confidence_scores'``
            (a sized collection of scores; objects with a ``.max()`` method
            and plain sequences are both accepted).
        output_folder: Root folder; it and any missing parents are created.

    Returns:
        Tuple ``(search_folder, copied_files)``: the ``search_<timestamp>``
        folder that was created and the list of destination paths that were
        copied successfully. Entries that fail are reported and skipped.
    """
    from datetime import datetime

    output_path = Path(output_folder)

    # Create subfolder with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    search_folder = output_path / f"search_{timestamp}"
    # parents=True: the output root itself may not exist yet.
    search_folder.mkdir(parents=True, exist_ok=True)

    copied_files = []

    for i, result in enumerate(results):
        # Pre-bind so the error message below is safe even when the
        # 'image_path' lookup is the thing that fails.
        source_path = None
        try:
            source_path = Path(result['image_path'])

            scores = result['confidence_scores']
            if len(scores) > 0:
                best = scores.max() if hasattr(scores, 'max') else max(scores)
                max_conf = float(best)
            else:
                max_conf = 0.0

            # Create descriptive filename: rank, original stem, best confidence
            filename = f"{i+1:03d}_{source_path.stem}_conf{max_conf:.2f}{source_path.suffix}"
            dest_path = search_folder / filename

            shutil.copy2(source_path, dest_path)
            copied_files.append(dest_path)

        except Exception as e:
            # Best-effort copy: one bad entry must not abort the whole export.
            print(f"❌ Error copying {source_path}: {e}")

    print(f"📋 Copied {len(copied_files)} files to {search_folder}")
    return search_folder, copied_files

# Cell 5: Interactive Query Interface
def create_search_interface():
    """Build the ipywidgets search panel for forensic analysts.

    The panel consists of a model dropdown, a query box, confidence and
    batch-size sliders, and four buttons (search, copy, clear, switch model).
    Button callbacks print into a shared output area and keep the latest
    search results and the active model/processor/device in closure state.

    Returns:
        A ``widgets.VBox`` containing the controls followed by the output area.
    """

    def _action_button(label, style, **extra):
        # Every action button shares the same fixed width.
        return widgets.Button(
            description=label,
            button_style=style,
            layout=widgets.Layout(width='150px'),
            **extra
        )

    # --- Input widgets -------------------------------------------------
    dropdown_entries = []
    for key, info in AVAILABLE_MODELS.items():
        dropdown_entries.append((f"{info['name']} - {info['description']}", key))

    model_selector = widgets.Dropdown(
        options=dropdown_entries,
        value='grounding-dino-base',
        description='Model:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='500px')
    )

    # Small default and upper bound to keep memory use in check
    batch_size_slider = widgets.IntSlider(
        value=4,
        min=1,
        max=8,
        step=1,
        description='Batch Size:',
        style={'description_width': 'initial'}
    )

    query_input = widgets.Text(
        value='person with weapon',
        placeholder='Enter your search query (e.g., "person with weapon", "suspicious vehicle")',
        description='Query:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='500px')
    )

    confidence_slider = widgets.FloatSlider(
        value=CONFIG['CONFIDENCE_THRESHOLD'],
        min=0.1,
        max=0.9,
        step=0.05,
        description='Min Confidence:',
        style={'description_width': 'initial'}
    )

    # --- Buttons -------------------------------------------------------
    search_button = _action_button('🔍 Search Gallery', 'primary')
    # Copy stays disabled until a search has produced matches
    copy_button = _action_button('📋 Copy Results', 'success', disabled=True)
    clear_button = _action_button('🗑️ Clear Results', 'warning')
    switch_model_button = _action_button('🔄 Switch Model', 'info')

    # All callback output is routed here
    output_area = widgets.Output()

    # --- Closure state: latest results and the active model bundle -----
    found = []
    active_model = model
    active_processor = processor
    active_device = device
    active_name = model_name

    def _handle_switch(_button):
        nonlocal active_model, active_processor, active_device, active_name
        with output_area:
            choice = model_selector.value
            print(f"🔄 Switching to: {AVAILABLE_MODELS[choice]['name']}")
            new_model, new_processor, new_device, new_name = switch_model(choice)
            if not (new_model and new_processor):
                print("❌ Failed to switch model")
                return
            active_model, active_processor = new_model, new_processor
            active_device, active_name = new_device, new_name
            print("✅ Model switched successfully!")

    def _handle_search(_button):
        nonlocal found
        with output_area:
            clear_output(wait=True)

            if not (active_model and active_processor):
                print("❌ Model not loaded. Please switch to a valid model first.")
                return

            query = query_input.value.strip()
            if not query:
                print("⚠️ Please enter a search query")
                return

            # The slider value is written back into the shared configuration
            CONFIG['CONFIDENCE_THRESHOLD'] = confidence_slider.value
            batch_size = batch_size_slider.value

            print(f"🚀 Starting search: '{query}'")
            print(f"📊 Confidence: {CONFIG['CONFIDENCE_THRESHOLD']:.2f}")
            print("💡 Processing images individually for stability")
            print("-" * 50)

            found = search_images_with_query(
                query, active_model, active_processor, active_device, active_name,
                CONFIG['SUSPECTS_GALLERY_PATH'], batch_size
            )

            if not found:
                print("🔍 No matches found for your query")
                copy_button.disabled = True
                return

            copy_button.disabled = False
            display_results(found[:CONFIG['MAX_RESULTS_DISPLAY']])

            if len(found) > CONFIG['MAX_RESULTS_DISPLAY']:
                print(f"\n📝 Showing first {CONFIG['MAX_RESULTS_DISPLAY']} results out of {len(found)} total matches")

    def _handle_copy(_button):
        with output_area:
            if not found:
                print("⚠️ No results to copy")
                return
            print("\n📋 Copying results to output folder...")
            folder, _copied = copy_results_to_folder(found, CONFIG['RESULTS_OUTPUT_PATH'])
            print(f"✅ Results saved to: {folder}")

    def _handle_clear(_button):
        nonlocal found
        found = []
        copy_button.disabled = True
        with output_area:
            clear_output()
            print("🗑️ Results cleared")

    # --- Wire callbacks ------------------------------------------------
    switch_model_button.on_click(_handle_switch)
    search_button.on_click(_handle_search)
    copy_button.on_click(_handle_copy)
    clear_button.on_click(_handle_clear)

    # --- Layout --------------------------------------------------------
    title = widgets.HTML("<h3>🔍 Forensic Image Search Interface (Hugging Face)</h3>")
    slider_row = widgets.HBox([confidence_slider, batch_size_slider])
    button_row = widgets.HBox([search_button, copy_button, clear_button, switch_model_button])
    divider = widgets.HTML("<hr>")

    controls = widgets.VBox([title, model_selector, query_input, slider_row, button_row, divider])

    return widgets.VBox([controls, output_area])

def display_results(results):
    """Display search results with bounding boxes in a two-column grid.

    Args:
        results: Sequence of result dicts. Each entry is read for ``'image'``
            (shown with ``imshow``), ``'image_path'`` (its ``.name`` goes into
            the subplot title), ``'boxes'`` (iterated; each item must support
            ``.tolist()`` yielding four numbers drawn as ``[x1, y1, x2, y2]``
            in the image's axes coordinates) and ``'confidence_scores'``
            (paired with the boxes one-to-one).

    Returns:
        None. Does nothing when ``results`` is empty.
    """
    if not results:
        return

    cols = 2
    rows = (len(results) + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so flattening yields a
    # uniform 1-D sequence. Without it, a single result left ``axes`` as a
    # 2-element array and ``ax.imshow`` failed on the array itself.
    fig, axes = plt.subplots(rows, cols, figsize=CONFIG['FIGURE_SIZE'], squeeze=False)
    axes = axes.flatten()

    for ax, result in zip(axes, results):
        # Display image
        ax.imshow(result['image'])

        # Draw bounding boxes
        boxes = result['boxes']
        confidences = result['confidence_scores']

        for box, conf in zip(boxes, confidences):
            # Convert from [x1, y1, x2, y2] to matplotlib rectangle
            x1, y1, x2, y2 = box.tolist()

            # Create rectangle (anchor corner plus width/height)
            rect = patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)

            # Add confidence text just above the box's top-left corner
            ax.text(x1, y1 - 5, f'{float(conf):.2f}',
                    color='red', fontweight='bold', fontsize=10)

        ax.set_title(f"{result['image_path'].name}\nMatches: {len(boxes)}", fontsize=10)
        ax.axis('off')

    # Hide empty subplots (odd number of results leaves one unused cell)
    for spare_ax in axes[len(results):]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Build the search interface once and render it in the notebook output.
# `interface` stays bound at module level so later cells can re-display it.
interface = create_search_interface()
display(interface)

# Cell 6: Batch Processing Functions (Enhanced)
def batch_search_multiple_queries(queries_list, model, processor, device, model_name, gallery_path, output_base_path, batch_size=8):
    """
    Run several queries against the same gallery and export each query's matches.

    Args:
        queries_list: Sequence of query strings, processed in order.
        model, processor, device, model_name: Forwarded to
            ``search_images_with_query`` for every query.
        gallery_path: Folder containing the images to search.
        output_base_path: Root folder; matches for each query are copied into
            ``<output_base_path>/query_<sanitised query>/``.
        batch_size: Forwarded to ``search_images_with_query``.

    Returns:
        Dict keyed by query text. Each value holds ``'results'``,
        ``'output_folder'`` (``None`` when nothing matched), ``'file_count'``
        and ``'match_count'``.
    """
    import re

    all_results = {}

    print(f"🚀 Starting batch analysis with {len(queries_list)} queries")
    print(f"📁 Gallery: {gallery_path}")
    print(f"🤖 Model: {model_name}")
    print("=" * 60)

    for i, query in enumerate(queries_list, 1):
        print(f"\n[{i}/{len(queries_list)}] Query: '{query}'")

        results = search_images_with_query(query, model, processor, device, model_name, gallery_path, batch_size)

        if results:
            # Keep word characters, dots and dashes; everything else (spaces,
            # both path separators, characters some filesystems reject)
            # becomes "_" so the query always maps to one valid folder name.
            safe_name = re.sub(r'[^\w.\-]', '_', query)
            query_folder = Path(output_base_path) / f"query_{safe_name}"
            # Create the full folder chain up front so copying works even when
            # the output base does not exist yet.
            query_folder.mkdir(parents=True, exist_ok=True)

            folder, files = copy_results_to_folder(results, query_folder)
            all_results[query] = {
                'results': results,
                'output_folder': folder,
                'file_count': len(files),
                'match_count': len(results)
            }
            print(f"📁 Saved {len(files)} files to: {folder.name}")
        else:
            all_results[query] = {
                'results': [],
                'output_folder': None,
                'file_count': 0,
                'match_count': 0
            }
            print("⚪ No matches found")

    # Summary report
    print("\n" + "="*60)
    print("📊 BATCH SEARCH SUMMARY")
    print("="*60)

    total_matches = 0
    total_files = 0

    for query, data in all_results.items():
        matches = data['match_count']
        files = data['file_count']
        total_matches += matches
        total_files += files

        status = "✅" if matches > 0 else "⚪"
        print(f"{status} '{query}': {matches} images, {files} files saved")

    print(f"\n🎯 TOTAL: {total_matches} matched images, {total_files} files copied")
    return all_results

# Enhanced batch processing with common forensic queries
def run_forensic_batch_analysis(custom_queries=None, batch_size=8, model_to_use=None):
    """Run the default forensic queries (plus optional custom ones) over the gallery.

    Args:
        custom_queries: Optional extra queries. Accepts any iterable of
            strings, or a single string which is treated as one query.
        batch_size: Forwarded to ``batch_search_multiple_queries``.
        model_to_use: Optional ``(model, processor, device, model_name)``
            tuple. When falsy, the module-level ``model``, ``processor``,
            ``device`` and ``model_name`` are used.

    Returns:
        The dict returned by ``batch_search_multiple_queries``, or ``None``
        when the model or processor is not loaded.
    """

    # Use provided model or current global model
    if model_to_use:
        current_model, current_processor, current_device, current_model_name = model_to_use
    else:
        current_model, current_processor, current_device, current_model_name = model, processor, device, model_name

    # Default forensic queries
    default_queries = [
        "person with weapon",
        "person holding gun",
        "person with knife",
        "suspicious vehicle",
        "person wearing mask",
        "person running",
        "bag or backpack",
        "group of people",
        "person with phone",
        "person in dark clothing"
    ]

    # Combine with custom queries if provided
    if custom_queries:
        # A bare string is one query, not an iterable of characters.
        if isinstance(custom_queries, str):
            custom_queries = [custom_queries]
        # list() lets tuples, sets and generators be concatenated safely.
        custom_queries = list(custom_queries)
        # Batch results are keyed by query text, so a repeated query would only
        # redo the search and overwrite its own entry; dict.fromkeys drops
        # duplicates while keeping first-seen order.
        queries = list(dict.fromkeys(default_queries + custom_queries))
        print(f"📋 Using {len(default_queries)} default + {len(custom_queries)} custom queries")
    else:
        queries = default_queries
        print(f"📋 Using {len(default_queries)} default forensic queries")

    if not current_model or not current_processor:
        print("❌ Model or processor not loaded. Cannot run batch analysis.")
        return None

    print("🔍 Starting comprehensive forensic analysis...")
    batch_results = batch_search_multiple_queries(
        queries,
        current_model,
        current_processor,
        current_device,
        current_model_name,
        CONFIG['SUSPECTS_GALLERY_PATH'],
        CONFIG['RESULTS_OUTPUT_PATH'],
        batch_size
    )
    return batch_results

# Quick test with reduced output
def quick_forensic_search(query="person with weapon", batch_size=8, model_to_use=None):
    """Run one query against the configured gallery, mainly for smoke testing.

    ``model_to_use`` may be a ``(model, processor, device, model_name)`` tuple;
    when it is falsy the module-level model objects are used instead.

    Returns ``None`` if the model or processor is missing, an empty list when
    nothing matched, otherwise the list of matches from the search.
    """
    bundle = model_to_use if model_to_use else (model, processor, device, model_name)
    active_model, active_processor, active_device, active_name = bundle

    if not (active_model and active_processor):
        print("❌ Model or processor not loaded")
        return None

    print(f"🔍 Quick search: '{query}'")
    print(f"🤖 Using model: {active_name}")

    matches = search_images_with_query(
        query,
        active_model,
        active_processor,
        active_device,
        active_name,
        CONFIG['SUSPECTS_GALLERY_PATH'],
        batch_size,
    )

    if not matches:
        print("⚪ No matches found")
        return []

    print(f"📋 Found {len(matches)} matches - ready for detailed analysis")
    return matches

# Model switching utilities
def list_available_models():
    """Print every entry of AVAILABLE_MODELS: name, key, performance note and description."""
    print("🤖 Available GroundingDINO Models:")
    print("-" * 50)
    for model_key, details in AVAILABLE_MODELS.items():
        entry_lines = (
            f"🔹 {details['name']} ({model_key})",
            f"   📊 {details['performance']}",
            f"   📝 {details['description']}",
            "",  # trailing blank line separates entries
        )
        print("\n".join(entry_lines))

# Print the model catalogue as soon as this cell runs
list_available_models()

# Optional: uncomment to run the full default query set over the gallery
# (copies matches into the configured results folder)
# batch_results = run_forensic_batch_analysis(batch_size=8)

# Optional: uncomment for a single-query smoke test (returns matches, copies nothing)
# quick_results = quick_forensic_search("person holding gun", batch_size=8)

# Cell 7: Usage Instructions and Tips
# The gallery and output locations are interpolated from CONFIG so the printed
# instructions always name the folders the search functions actually use.
print(f"""
🎯 FORENSIC IMAGE SEARCH SYSTEM - HUGGING FACE VERSION
====================================================

🆕 NEW FEATURES:
• 🤖 Automatic model downloads from Hugging Face
• 🔄 Easy model switching between Tiny and Base variants
• 📦 No manual model file downloads required
• 🚀 Improved performance and reliability

📋 SETUP CHECKLIST:
1. ✅ Place suspect images in the '{CONFIG['SUSPECTS_GALLERY_PATH']}' folder
2. ✅ Run all cells in order (1-6)
3. ✅ Models download automatically on first use

🤖 AVAILABLE MODELS:
• GroundingDINO-Tiny: Fastest, good for quick testing
• GroundingDINO-Base: Best balance (recommended)
• Both models run locally after download

🔧 DEVICE COMPATIBILITY:
• System automatically detects GPU/CPU availability
• Models work on both CPU and GPU
• CPU mode: Slower but works on all systems
• GPU mode: Faster with CUDA support

🔍 SEARCH TIPS:
• Use specific, descriptive queries: "person with red jacket"
• Try multiple variations: "weapon", "gun", "knife", "suspicious object"
• Adjust confidence threshold (lower = more results, higher = more precise)
• Common forensic queries:
  - "person with weapon"
  - "suspicious vehicle"
  - "person wearing mask"
  - "bag or backpack"
  - "person running"

⚙️ INTERFACE FEATURES:
• Model selector dropdown for easy switching
• Confidence and batch size sliders
• Real-time search with progress tracking
• Copy results to organized folders
• Clear visual results with bounding boxes

📁 OUTPUT STRUCTURE:
{CONFIG['RESULTS_OUTPUT_PATH']}/
├── search_20240611_143022/
│   ├── 001_suspect1_conf0.87.jpg
│   ├── 002_suspect5_conf0.72.jpg
│   └── ...

🚨 TROUBLESHOOTING:
• "Model loading failed" → Check internet connection (first download)
• "No images found" → Verify images are in the '{CONFIG['SUSPECTS_GALLERY_PATH']}' folder
• "Out of memory" → Reduce batch size or use Tiny model
• Slow performance → Try GPU if available, or reduce batch size

🔄 MODEL SWITCHING:
• Use the dropdown menu to select different models
• Click "Switch Model" to change models
• Tiny model: Faster, lower accuracy
• Base model: Better accuracy, slightly slower

For batch processing of multiple queries, use the functions in Cell 6.
""")

✅ Configuration loaded successfully
📁 Suspects gallery: ../../datasets/images/objects/raw
📁 Results output: ../../datasets/images/objects/detections
🤖 Selected model: IDEA-Research/grounding-dino-base
✅ Transformers already installed
✅ All dependencies imported successfully
🔧 PyTorch version: 2.7.1+cpu
🖥️ Using CPU mode (CUDA not available)
📍 Auto-detected device: cpu
📁 Created directories: ../../datasets/images/objects/raw, ../../datasets/images/objects/detections
📥 Loading model: IDEA-Research/grounding-dino-base
🖥️ Target device: cpu
⏳ Loading processor...
⏳ Loading model weights...
📍 Moving model to cpu...
✅ Model loaded successfully!
   📊 Model: IDEA-Research/grounding-dino-base
   🖥️ Device: cpu


VBox(children=(VBox(children=(HTML(value='<h3>🔍 Forensic Image Search Interface (Hugging Face)</h3>'), Dropdow…

🤖 Available GroundingDINO Models:
--------------------------------------------------
🔹 GroundingDINO-Tiny (grounding-dino-tiny)
   📊 Lower accuracy, fastest speed
   📝 Fastest, smallest model - good for quick testing

🔹 GroundingDINO-Base (grounding-dino-base)
   📊 52.5 AP on COCO, good speed
   📝 Best balance of speed and accuracy - recommended


🎯 FORENSIC IMAGE SEARCH SYSTEM - HUGGING FACE VERSION

🆕 NEW FEATURES:
• 🤖 Automatic model downloads from Hugging Face
• 🔄 Easy model switching between Tiny and Base variants
• 📦 No manual model file downloads required
• 🚀 Improved performance and reliability

📋 SETUP CHECKLIST:
1. ✅ Place suspect images in the './suspects_gallery' folder
2. ✅ Run all cells in order (1-6)
3. ✅ Models download automatically on first use

🤖 AVAILABLE MODELS:
• GroundingDINO-Tiny: Fastest, good for quick testing
• GroundingDINO-Base: Best balance (recommended)
• Both models run locally after download

🔧 DEVICE COMPATIBILITY:
• System automatically detects GPU/CP