# Multimodal Search Engine: Data Processing Pipeline

This notebook is designed to run on **Google Colab** to leverage GPU acceleration for processing images and text.

**Objective:**
1.  Download the Unsplash Lite dataset.
2.  Generate embeddings for images using **two models**:
    *   **SIGLIP**: For text-to-image semantic search.
    *   **DINOv2**: For image-to-image visual similarity search.
3.  Store the embeddings and metadata in a **Qdrant** vector database with two collections:
    *   `text_visual_index` (SIGLIP)
    *   `pure_visual_index` (DINOv2)
4.  Save the Qdrant storage to Google Drive for later use in the backend application.

**Important:**
After running this notebook, your Google Drive will contain `My Drive/qdrant_backups/qdrant_checkpoint.zip`, a zip archive of the `qdrant_storage` folder. Download this file and extract it into your project's `backend/` directory (so that `backend/qdrant_storage/` exists) to serve the search engine.

In [None]:
import google.colab.drive as drive

# Mount Google Drive to save checkpoints and the final database.
# Use Colab's canonical absolute mount point: the backup paths used later
# ("/content/drive/My Drive/...") are absolute, so the mount location must not
# depend on the current working directory.
drive.mount("/content/drive")

## 1. Setup Environment
Install necessary libraries and download the dataset.

In [None]:
# Install dependencies
!pip install qdrant-client sentence-transformers pandas tqdm -q

# Download Unsplash Lite Dataset
!curl -O https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip

# Unzip the dataset
!unzip -d ./unsplash-research-dataset-lite-latest ./unsplash-research-dataset-lite-latest.zip

## 2. Load Data & Model
Load the dataset metadata into Pandas and initialize the SIGLIP and DINOv2 models used for embedding generation.

In [None]:
import numpy as np
import pandas as pd
import glob
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
import torch

# --- Load Dataset Metadata ---
path = './unsplash-research-dataset-lite-latest/'
documents = ['photos', 'keywords', 'collections', 'conversions', 'colors']
datasets = {}

# Read the photos table, which may be split across several files ("photos.csv*").
# The file list is sorted because glob order is not guaranteed: a row's position
# in the concatenated frame becomes its Qdrant point ID and decides where an
# interrupted run resumes, so it must be identical on every run.
files = sorted(glob.glob(path + documents[0] + ".csv*"))
if not files:
    raise FileNotFoundError(
        f"No '{documents[0]}' CSV files found in {path!r}; did the download/unzip step succeed?"
    )

subsets = [pd.read_csv(filename, sep='\t', header=0) for filename in files]
datasets[documents[0]] = pd.concat(subsets, axis=0, ignore_index=True)
print(f"Loaded {len(datasets['photos'])} photo records.")

# --- Load Models ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. SIGLIP (Text-to-Image)
print("Loading SIGLIP model...")
siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-256").to(device)
siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-256")
# Vector size of the SIGLIP collection. Assumes get_image_features/get_text_features
# return vectors as wide as the vision hidden size for this checkpoint — verify if
# the model is swapped.
siglip_dim = siglip_model.config.vision_config.hidden_size

# 2. DINOv2 (Image-to-Image)
print("Loading DINOv2 model...")
dino_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
dino_model = AutoModel.from_pretrained('facebook/dinov2-large').to(device)
# DINOv2 embeddings are mean-pooled last_hidden_state vectors, so their width is hidden_size.
dino_dim = dino_model.config.hidden_size

print("Models loaded.")
print(f"SIGLIP Dimension: {siglip_dim}")
print(f"DINOv2 Dimension: {dino_dim}")

## 3. Initialize Vector Database
Set up Qdrant (restoring the database from the Drive checkpoint if one exists) and define helper functions for checkpointing (saving progress to Drive) and downloading images.

In [None]:
import aiohttp
import asyncio
from qdrant_client import QdrantClient, models
import shutil
import os
from tqdm.notebook import tqdm
from PIL import Image
from io import BytesIO

# --- Configuration ---
# Local (Colab VM) folder holding the embedded Qdrant database.
QDRANT_PATH = "./qdrant_storage"

# Checkpoint destination on the mounted Google Drive.
DRIVE_BACKUP_PATH = "/content/drive/My Drive/qdrant_backups"
CHECKPOINT_ZIP_FILE = "qdrant_checkpoint.zip"
DRIVE_BACKUP_FILE = os.path.join(DRIVE_BACKUP_PATH, CHECKPOINT_ZIP_FILE)

# Base name (shutil adds ".zip") of the archive built locally before upload.
TEMP_ZIP_PATH = "./qdrant_checkpoint_temp"

# One Qdrant collection per embedding model.
COLLECTION_SIGLIP, COLLECTION_DINO = "text_visual_index", "pure_visual_index"

# Create the Drive backup folder up front; a no-op when it already exists.
os.makedirs(DRIVE_BACKUP_PATH, exist_ok=True)

# --- Helper: Save Checkpoint ---
def save_checkpoint(client, drive_path):
    """Zip the local Qdrant storage folder and copy the archive to Google Drive.

    The given client is closed before archiving, and a fresh client is opened on
    the same storage afterwards (presumably so the on-disk files are not being
    written while they are zipped).

    Args:
        client: Open ``QdrantClient`` on ``QDRANT_PATH``. It is closed by this call.
        drive_path: Destination path of the checkpoint zip (normally ``DRIVE_BACKUP_FILE``).

    Returns:
        A new ``QdrantClient`` opened on ``QDRANT_PATH``. Callers must use it in
        place of ``client``.

    Ordinary errors (for example Drive I/O failures) are reported as a warning so
    ingestion can continue. ``KeyboardInterrupt``/``SystemExit`` propagate.
    """
    tqdm.write(f"\n--- Checkpointing: Saving database to {drive_path} ---")
    try:
        client.close()
        local_zip = shutil.make_archive(base_name=TEMP_ZIP_PATH, format='zip', root_dir=".", base_dir=QDRANT_PATH)
        # Copy under a temporary name first, then rename over the previous checkpoint,
        # so an interrupted copy never leaves a truncated zip as the only backup.
        staging_path = drive_path + ".partial"
        shutil.move(local_zip, staging_path)
        os.replace(staging_path, drive_path)
        tqdm.write(f"--- Checkpoint successful: {os.path.basename(drive_path)} updated in Drive. ---")
    except Exception as e:
        tqdm.write(f"[Warning] Checkpoint failed: {e}")
    # Reopen outside of `finally`: a `return` inside `finally` would silently
    # discard KeyboardInterrupt/SystemExit raised while archiving.
    return QdrantClient(path=QDRANT_PATH)

# --- Helper: Download Image ---
HEADERS = {"User-Agent": "Mozilla/5.0"}
# Total time budget in seconds for fetching one image (connection + body).
TOTAL_DOWNLOAD_TIMEOUT = 15.0

async def download_image(session, row, semaphore):
    """Fetch and decode one image.

    Args:
        session: ``aiohttp.ClientSession`` used for the request.
        row: Record with a ``'photo_image_url'`` entry (a row of the photos DataFrame).
        semaphore: ``asyncio.Semaphore`` bounding the number of concurrent downloads.

    Returns:
        ``(PIL.Image in RGB mode, row)`` on success, or ``(None, row)`` if the request,
        the HTTP status, or decoding fails. Failures are deliberately swallowed so
        that one bad image does not abort a batch; callers skip ``None`` results.
    """
    async with semaphore:
        try:
            timeout = aiohttp.ClientTimeout(total=TOTAL_DOWNLOAD_TIMEOUT)
            async with session.get(row['photo_image_url'], timeout=timeout, headers=HEADERS) as response:
                response.raise_for_status()
                content = await response.read()
            # Decode after the response is released. NOTE(review): photo_image_url
            # presumably points at the original-resolution file, so decoded images can
            # be large in memory — consider requesting a smaller rendition.
            img = Image.open(BytesIO(content)).convert("RGB")
            return (img, row)
        except Exception:
            return (None, row)

# --- Restore or Initialize Database ---
print("--- Initializing Database ---")
start_index = 0

# Close a client left open by an earlier run of this notebook (e.g. an interrupted
# ingestion) before its storage folder is deleted or reopened below: an embedded
# Qdrant storage folder can only be opened by one client at a time.
_previous_client = globals().get("client")
if _previous_client is not None:
    _previous_client.close()

# The Drive checkpoint is the source of truth: if present, it replaces any local
# storage, including progress made after the last checkpoint.
if os.path.exists(DRIVE_BACKUP_FILE):
    print(f"Found checkpoint at {DRIVE_BACKUP_FILE}. Restoring...")
    shutil.rmtree(QDRANT_PATH, ignore_errors=True)
    shutil.copy(DRIVE_BACKUP_FILE, "local_checkpoint.zip")
    # The archive contains the storage folder itself, so extract into ".".
    shutil.unpack_archive("local_checkpoint.zip", ".")
    os.remove("local_checkpoint.zip")
    print("Restore complete.")
else:
    print("No checkpoint found. Starting fresh.")

client = QdrantClient(path=QDRANT_PATH)


def _check_vector_size(collection_name, info, expected_dim):
    """Raise ValueError if an existing collection stores vectors of a different size than the loaded model."""
    actual_dim = getattr(info.config.params.vectors, "size", None)
    if actual_dim is not None and actual_dim != expected_dim:
        raise ValueError(
            f"Collection '{collection_name}' stores {actual_dim}-d vectors but the loaded model "
            f"produces {expected_dim}-d vectors. Remove the stale checkpoint or load the matching model."
        )


# 1. Setup SIGLIP Collection
if client.collection_exists(COLLECTION_SIGLIP):
    info = client.get_collection(COLLECTION_SIGLIP)
    _check_vector_size(COLLECTION_SIGLIP, info, siglip_dim)
    # Resume from the stored point count. Point IDs are row positions and rows whose
    # download failed are never stored, so this can be below the last processed row;
    # re-processing those rows is harmless because upserts overwrite points by ID.
    start_index = info.points_count or 0
    print(f"Resuming SIGLIP collection from index: {start_index}")
else:
    print(f"Creating collection '{COLLECTION_SIGLIP}'...")
    client.create_collection(
        collection_name=COLLECTION_SIGLIP,
        vectors_config=models.VectorParams(size=siglip_dim, distance=models.Distance.COSINE)
    )

# 2. Setup DINOv2 Collection
if client.collection_exists(COLLECTION_DINO):
    info = client.get_collection(COLLECTION_DINO)
    _check_vector_size(COLLECTION_DINO, info, dino_dim)
    print(f"DINOv2 collection found with {info.points_count} points.")
else:
    print(f"Creating collection '{COLLECTION_DINO}'...")
    client.create_collection(
        collection_name=COLLECTION_DINO,
        vectors_config=models.VectorParams(size=dino_dim, distance=models.Distance.COSINE)
    )

print("Database initialized.")

## 4. Ingestion Pipeline
This loop processes images in batches:
1.  Downloads images concurrently.
2.  Encodes them with SIGLIP and DINOv2 (on the GPU when available).
3.  Upserts the vectors and metadata into both Qdrant collections.
4.  Periodically saves checkpoints to Drive.

In [None]:
CONCURRENCY_LIMIT = 10
BATCH_SIZE = 64
CHECKPOINT_EVERY_N_BATCHES = 5
batch_counter = 0

total_remaining = len(datasets["photos"]) - start_index
num_batches = (total_remaining // BATCH_SIZE) + 1

print(f"Starting ingestion from index {start_index}. ({total_remaining} images remaining)")

main_progress_bar = tqdm(range(start_index, len(datasets["photos"]), BATCH_SIZE), total=num_batches, desc="Total Batches")

async def ingestion_loop():
    global client, batch_counter
    semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)

    async with aiohttp.ClientSession() as session:
        for i in main_progress_bar:
            batch_df = datasets["photos"].iloc[i:i + BATCH_SIZE]
            if batch_df.empty: continue

            # 1. Download
            tasks = [download_image(session, row, semaphore) for _, row in batch_df.iterrows()]
            
            images_to_process = []
            rows_to_process = []
            
            for future in asyncio.as_completed(tasks):
                img, row = await future
                if img:
                    images_to_process.append(img)
                    rows_to_process.append(row)

            if not images_to_process: continue

            # 2. Encode with SIGLIP
            with torch.no_grad():
                inputs_siglip = siglip_processor(images=images_to_process, padding="max_length", return_tensors="pt").to(device)
                vectors_siglip = siglip_model.get_image_features(**inputs_siglip).cpu().numpy()

            # 3. Encode with DINOv2
            with torch.no_grad():
                inputs_dino = dino_processor(images=images_to_process, return_tensors="pt").to(device)
                outputs_dino = dino_model(**inputs_dino)
                # Average pool the last hidden state
                vectors_dino = outputs_dino.last_hidden_state.mean(dim=1).cpu().numpy()

            # 4. Prepare Points
            points_siglip = []
            points_dino = []
            
            for idx, row in enumerate(rows_to_process):
                payload = {
                    "url": row['photo_image_url'],
                    "description": str(row['photo_description']),
                    "unsplash_id": row['photo_id']
                }
                
                # SIGLIP Point
                points_siglip.append(models.PointStruct(
                    id=row.name, 
                    vector=vectors_siglip[idx].tolist(), 
                    payload=payload
                ))
                
                # DINOv2 Point (Same ID, Same Payload)
                points_dino.append(models.PointStruct(
                    id=row.name, 
                    vector=vectors_dino[idx].tolist(), 
                    payload=payload
                ))

            # 5. Upsert to Both Collections
            if points_siglip:
                client.upsert(collection_name=COLLECTION_SIGLIP, points=points_siglip, wait=False)
            
            if points_dino:
                client.upsert(collection_name=COLLECTION_DINO, points=points_dino, wait=False)

            batch_counter += 1

            # 6. Checkpoint
            if batch_counter % CHECKPOINT_EVERY_N_BATCHES == 0:
                client = save_checkpoint(client, DRIVE_BACKUP_FILE)

# Run the loop
await ingestion_loop()

# Final Save
print("Ingestion finished. Saving final checkpoint...")
client = save_checkpoint(client, DRIVE_BACKUP_FILE)
client.close()
print(f"Database saved to: {DRIVE_BACKUP_FILE}")

## 5. Test Search
Verify the database works by running a text query.

In [None]:
# --- Test Script ---
search_query = "a lion facing the camera"

print(f"Loading database from {QDRANT_PATH}...")
client = QdrantClient(path=QDRANT_PATH)
try:
    # Verify Counts
    info_siglip = client.get_collection(COLLECTION_SIGLIP)
    info_dino = client.get_collection(COLLECTION_DINO)
    print(f"SIGLIP Collection Count: {info_siglip.points_count}")
    print(f"DINOv2 Collection Count: {info_dino.points_count}")

    # Test Text Search (SIGLIP)
    print(f"\nEncoding query: '{search_query}'")
    with torch.no_grad():
        # SigLIP text inputs are padded to max_length, as in the HF SigLIP usage examples.
        inputs = siglip_processor(text=[search_query], padding="max_length", return_tensors="pt").to(device)
        query_vector = siglip_model.get_text_features(**inputs).cpu().numpy()[0]

    print("Searching SIGLIP collection...")
    # query_points replaces the deprecated `search` method, which newer qdrant-client releases drop.
    search_results = client.query_points(
        collection_name=COLLECTION_SIGLIP,
        query=query_vector.tolist(),
        limit=5,
    ).points

    print("\n--- TOP 5 RESULTS (Text-to-Image) ---")
    for i, result in enumerate(search_results):
        print(f"\nResult {i+1} (Score: {result.score:.4f}):")
        if result.payload:
            print(f"  URL: {result.payload.get('url')}")
            print(f"  Desc: {result.payload.get('description')}")
finally:
    # Always release the storage folder so later cells can reopen QDRANT_PATH.
    client.close()

## 6. Test Image-to-Image Search (Visual Similarity)
Now we test the DINOv2 index. We will take an image and find other images that look visually similar (shapes, colors, composition).

In [None]:
import requests

# 1. Pick a query image
# We'll just take the first image from our dataframe as the "query"
query_row = datasets['photos'].iloc[0]
query_image_url = query_row['photo_image_url']
# The ingestion loop uses the row position as the point ID.
query_point_id = int(query_row.name)

print(f"Query Image URL: {query_image_url}")
print(f"Description: {query_row['photo_description']}")

# 2. Download the image (with a timeout so a stalled request cannot hang the cell)
response = requests.get(query_image_url, headers=HEADERS, timeout=TOTAL_DOWNLOAD_TIMEOUT)
response.raise_for_status()
query_img = Image.open(BytesIO(response.content)).convert("RGB")

# 3. Encode with DINOv2, using the same mean pooling as the ingestion pipeline
print("Encoding query image with DINOv2...")
with torch.no_grad():
    inputs = dino_processor(images=query_img, return_tensors="pt").to(device)
    outputs = dino_model(**inputs)
    query_vector_dino = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]

# 4. Search the 'pure_visual_index', excluding the query image's own point
# (it would otherwise typically come back as the top hit).
print("Searching DINOv2 collection...")
client = QdrantClient(path=QDRANT_PATH)
try:
    search_results_visual = client.query_points(
        collection_name=COLLECTION_DINO,
        query=query_vector_dino.tolist(),
        query_filter=models.Filter(must_not=[models.HasIdCondition(has_id=[query_point_id])]),
        limit=5,
    ).points
finally:
    # Release the storage folder even if the search fails.
    client.close()

# 5. Display Results
print("\n--- TOP 5 VISUALLY SIMILAR IMAGES ---")
for i, result in enumerate(search_results_visual):
    print(f"\nResult {i+1} (Score: {result.score:.4f}):")
    if result.payload:
        print(f"  URL: {result.payload.get('url')}")
        print(f"  Desc: {result.payload.get('description')}")