# Semantic Image Search with CLIP

In this notebook, you will learn how to use **CLIP** (Contrastive Language-Image Pre-training) to search through art collections using natural language.

## What is CLIP?

CLIP is a neural network trained by OpenAI that learns to connect images and text. It can:

- **Understand images** by converting them into numerical representations (embeddings)
- **Understand text** by converting descriptions into the same embedding space
- **Match** images and text by measuring how similar their embeddings are

This allows us to search for images using natural language descriptions like:
- "a painting of a stormy sea"
- "portrait of a woman in red"
- "winter landscape with snow"
- "flowers in a vase"

## Two Modes

This notebook supports two modes:

1. **Pre-calculated Mode** (Recommended for most laptops)
   - Uses pre-computed image embeddings
   - Only needs to compute text embeddings (fast on any device)
   - Perfect for workshop settings

2. **Full Mode** (For powerful GPUs)
   - Computes image embeddings on-the-fly
   - Allows processing your own images
   - Requires a CUDA GPU for reasonable speed

---

## Part 1: Setup

First, let's install and import the required libraries.

In [None]:
# Install CLIP if not already installed
# Uncomment the line below if you need to install

# !pip install git+https://github.com/openai/CLIP.git pillow torch torchvision numpy tqdm

In [None]:
import os
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from IPython.display import display, HTML

# Import CLIP
try:
    import clip
    print("CLIP loaded successfully!")
except ImportError:
    print("CLIP not installed. Please run: pip install git+https://github.com/openai/CLIP.git")

# Check for GPU
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No GPU detected. Using CPU (this will be slower for image encoding).")

In [None]:
# Set up paths
PROJECT_ROOT = Path("../").resolve()
DATA_DIR = PROJECT_ROOT / "data"
IMAGES_DIR = PROJECT_ROOT / "images"
EMBEDDINGS_FILE = DATA_DIR / "clip_embeddings.npz"

print(f"Project root: {PROJECT_ROOT}")
print(f"Embeddings file: {EMBEDDINGS_FILE}")
print(f"Embeddings exist: {EMBEDDINGS_FILE.exists()}")

---

## Part 2: Choose Your Mode

Select the mode based on your hardware:

| Mode | Requirements | Speed | Use when... |
|------|--------------|-------|-------------|
| `precalculated` | Pre-computed embeddings file | Fast | Workshop setting, any laptop |
| `full` | CUDA GPU with 4GB+ VRAM | Slower | You have a powerful GPU |
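
The choice in the table can be written down as a small helper. This is only a sketch: `suggest_mode` and its parameters are hypothetical names that are not part of this notebook, and the 4 GB cutoff simply mirrors the table above.

```python
def suggest_mode(has_cuda: bool, vram_gb: float, embeddings_exist: bool) -> str:
    """Suggest 'precalculated' or 'full' following the table above (hypothetical helper)."""
    # Full mode is only worthwhile with a CUDA GPU that has at least 4 GB of VRAM.
    if has_cuda and vram_gb >= 4.0:
        return "full"
    # Otherwise fall back to the pre-computed embeddings when they are available.
    if embeddings_exist:
        return "precalculated"
    # No suitable GPU and no embeddings file: full mode still runs on CPU, just slowly.
    return "full"


print(suggest_mode(has_cuda=False, vram_gb=0.0, embeddings_exist=True))
```

In the notebook, the arguments could come from `torch.cuda.is_available()`, `torch.cuda.get_device_properties(0).total_memory / 1e9`, and `EMBEDDINGS_FILE.exists()`.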

In [None]:
# ============================================================
# CONFIGURATION: Choose your mode
# ============================================================

# Options: 'precalculated' or 'full'
MODE = 'precalculated'  # <-- CHANGE THIS if you have a powerful GPU

# CLIP model to use (must match pre-calculated embeddings if using that mode)
# Options: 'ViT-B/32' (fastest), 'ViT-B/16' (better), 'ViT-L/14' (best but slowest)
MODEL_NAME = 'ViT-B/32'

# Device: 'cuda' for GPU, 'cpu' for CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ============================================================

print(f"Mode: {MODE}")
print(f"Model: {MODEL_NAME}")
print(f"Device: {DEVICE}")

---

## Part 3: Load CLIP Model

Load the CLIP model. The first run downloads the model weights (~350 MB for ViT-B/32).

In [None]:
print(f"Loading CLIP model '{MODEL_NAME}' on {DEVICE}...")
model, preprocess = clip.load(MODEL_NAME, device=DEVICE)
model.eval()
print("Model loaded!")

# Get embedding dimension
with torch.no_grad():
    dummy_text = clip.tokenize(["test"]).to(DEVICE)
    dummy_embedding = model.encode_text(dummy_text)
    EMBEDDING_DIM = dummy_embedding.shape[1]
    print(f"Embedding dimension: {EMBEDDING_DIM}")

---

## Part 4: Load or Compute Image Embeddings

Depending on your mode, either load pre-calculated embeddings or compute them from images.
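
The pre-calculated mode reads three keys from a `.npz` file: `embeddings`, `filenames`, and `model_name`. As a rough illustration of that layout, the sketch below writes and re-reads a compatible file with `np.savez`. The arrays are random toy data and the temporary path is made up; how the workshop's actual file was produced is not shown in this notebook.

```python
import tempfile
from pathlib import Path

import numpy as np

# Toy stand-ins: 3 images with 512-dimensional embeddings (random values, not real CLIP output).
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(3, 512)).astype(np.float32)
# L2-normalize each row so that a later dot product equals cosine similarity.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
filenames = np.array(["a.jpg", "b.jpg", "c.jpg"])

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "clip_embeddings.npz"
    # Keys match what the loading cell expects: embeddings, filenames, model_name.
    np.savez(path, embeddings=embeddings, filenames=filenames, model_name="ViT-B/32")

    # Read the arrays back while the temporary file still exists.
    with np.load(path) as data:
        loaded_embeddings = data["embeddings"]
        loaded_filenames = data["filenames"]
        loaded_model = str(data["model_name"])

print(loaded_embeddings.shape, list(loaded_filenames), loaded_model)
```

Storing the model name next to the vectors is what lets the loading cell warn when `MODEL_NAME` does not match the file.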

In [None]:
# Variables to store embeddings and filenames
image_embeddings = None
image_filenames = None
images_base_dir = None

if MODE == 'precalculated':
    # Load pre-calculated embeddings
    if not EMBEDDINGS_FILE.exists():
        print(f"ERROR: Pre-calculated embeddings not found at {EMBEDDINGS_FILE}")
        print("Please ask the workshop instructor for the embeddings file,")
        print("or switch to MODE = 'full' if you have a GPU.")
    else:
        print(f"Loading pre-calculated embeddings from {EMBEDDINGS_FILE}...")
        data = np.load(EMBEDDINGS_FILE, allow_pickle=True)
        
        image_embeddings = data['embeddings']
        image_filenames = data['filenames']
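        # NpzFile only gained dict-style .get() in newer NumPy releases, so check the key list
        saved_model = str(data['model_name']) if 'model_name' in data.files else 'unknown'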
        
        print(f"Loaded {len(image_filenames)} image embeddings")
        print(f"Embeddings shape: {image_embeddings.shape}")
        print(f"Embeddings were computed with model: {saved_model}")
        
        if saved_model != MODEL_NAME:
            print(f"WARNING: Current model ({MODEL_NAME}) differs from saved ({saved_model})!")
            print("Results may be less accurate. Consider using the same model.")
        
        # Convert to torch tensor for faster computation
        image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32).to(DEVICE)
        
        # Load the index to find the images directory
        index_file = EMBEDDINGS_FILE.with_suffix('.json')
        if index_file.exists():
            with open(index_file) as f:
                index_data = json.load(f)
            print(f"Index loaded: {index_data.get('num_images', '?')} images")

In [None]:
if MODE == 'full':
    from tqdm.notebook import tqdm
    
    # Set the images directory
    # Change this to your downloaded images folder
    images_base_dir = IMAGES_DIR / "all_images"  # <-- CHANGE THIS to your images folder
    
    if not images_base_dir.exists():
        print(f"Images directory not found: {images_base_dir}")
        print("Please download images first using the 01_api_and_data notebook,")
        print("or update the path above.")
    else:
        # Find all images
        image_files = []
        for ext in ['.jpg', '.jpeg', '.png', '.webp']:
            image_files.extend(images_base_dir.rglob(f'*{ext}'))
            image_files.extend(images_base_dir.rglob(f'*{ext.upper()}'))
        image_files = sorted(set(image_files))
        
        print(f"Found {len(image_files)} images")
        
        if len(image_files) > 0:
            # Compute embeddings
            print("Computing image embeddings (this may take a while)...")
            
            all_embeddings = []
            all_filenames = []
            batch_size = 32 if DEVICE == 'cuda' else 8
            
            for i in tqdm(range(0, len(image_files), batch_size)):
                batch_files = image_files[i:i + batch_size]
                batch_images = []
                batch_names = []
                
                for img_path in batch_files:
                    try:
                        image = Image.open(img_path).convert('RGB')
                        image_tensor = preprocess(image)
                        batch_images.append(image_tensor)
                        batch_names.append(str(img_path.relative_to(images_base_dir)))
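                    except Exception as e:
                        # Skip unreadable files, but report them instead of failing silently
                        print(f"Skipping {img_path.name}: {e}")
                        continue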
                
                if batch_images:
                    batch_tensor = torch.stack(batch_images).to(DEVICE)
                    with torch.no_grad():
                        embeddings = model.encode_image(batch_tensor)
                        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
                    all_embeddings.append(embeddings)
                    all_filenames.extend(batch_names)
            
            image_embeddings = torch.cat(all_embeddings, dim=0)
            image_filenames = np.array(all_filenames)
            
            print(f"Computed embeddings for {len(image_filenames)} images")
            print(f"Embeddings shape: {image_embeddings.shape}")

---

## Part 5: Text Embedding and Semantic Search

Now the fun part! We'll create functions to:
1. Convert your search query into an embedding
2. Find the most similar images using cosine similarity

### Understanding Cosine Similarity

Cosine similarity measures how similar two vectors are by calculating the cosine of the angle between them:

- **1.0** = Identical direction (very similar)
- **0.0** = Perpendicular (unrelated)
- **-1.0** = Opposite direction (very different)

Raw CLIP outputs are not unit length, so the code normalizes every embedding first. Once both vectors have length 1, cosine similarity reduces to a simple dot product.
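
To make the link between cosine similarity and the dot product concrete, here is a minimal NumPy sketch on two hand-picked 2-D vectors. The vectors and the `cosine_similarity` helper are illustrative only and are not used elsewhere in the notebook.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


a = np.array([3.0, 4.0])
b = np.array([4.0, 3.0])

# Scale both vectors to unit length (L2 normalization).
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)

full = cosine_similarity(a, b)
dot_only = float(np.dot(a_unit, b_unit))
print(full, dot_only)  # both are 24/25 = 0.96 up to floating-point rounding
```

Normalizing once up front means a whole collection can be scored against a query with a single matrix product.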

In [None]:
def encode_text(text_query: str) -> torch.Tensor:
    """
    Convert a text query into a CLIP embedding.
    
    Args:
        text_query: Natural language description
    
    Returns:
        Normalized embedding tensor
    """
    with torch.no_grad():
        text_tokens = clip.tokenize([text_query]).to(DEVICE)
        text_embedding = model.encode_text(text_tokens)
        # Normalize
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
    return text_embedding


def encode_text_ensemble(text_queries: list) -> torch.Tensor:
    """
    Encode multiple text queries and average their embeddings.
    
    This technique (prompt ensembling) often gives better results
    than using a single query.
    
    Args:
        text_queries: List of related descriptions
    
    Returns:
        Averaged, normalized embedding tensor
    """
    with torch.no_grad():
        text_tokens = clip.tokenize(text_queries).to(DEVICE)
        text_embeddings = model.encode_text(text_tokens)
        # Average the embeddings
        mean_embedding = text_embeddings.mean(dim=0, keepdim=True)
        # Normalize
        mean_embedding = mean_embedding / mean_embedding.norm(dim=-1, keepdim=True)
    return mean_embedding


def search_images(text_embedding: torch.Tensor, top_k: int = 10) -> list:
    """
    Find the most similar images to a text embedding.
    
    Args:
        text_embedding: The query embedding
        top_k: Number of results to return
    
    Returns:
        List of (filename, similarity_score) tuples
    """
    # Match dtypes: on CUDA the CLIP model returns float16, while
    # pre-calculated embeddings are loaded as float32
    text_embedding = text_embedding.to(image_embeddings.dtype)
    
    # Compute cosine similarity (dot product of normalized vectors).
    # squeeze(1) keeps a 1-D result even when there is only one image.
    similarities = (image_embeddings @ text_embedding.T).squeeze(1)
    
    # Get top-k indices
    top_indices = similarities.argsort(descending=True)[:top_k]
    
    results = []
    for idx in top_indices:
        filename = str(image_filenames[idx.item()])
        score = similarities[idx].item()
        results.append((filename, score))
    
    return results


def semantic_search(query: str, top_k: int = 10, use_ensemble: bool = False) -> list:
    """
    Perform semantic search using a natural language query.
    
    Args:
        query: Natural language search query
        top_k: Number of results to return
        use_ensemble: If True, create variations of the query for better results
    
    Returns:
        List of (filename, similarity_score) tuples
    """
    if use_ensemble:
        # Create query variations for better matching
        queries = [
            query,
            f"a painting of {query}",
            f"an artwork depicting {query}",
            f"a photo of {query}",
            f"{query}, fine art"
        ]
        text_embedding = encode_text_ensemble(queries)
    else:
        text_embedding = encode_text(query)
    
    return search_images(text_embedding, top_k)

---

## Part 6: Display Search Results

Helper functions to display images from search results.

In [None]:
def find_image_path(filename: str) -> "Path | None":
    """
    Find the full path to an image file.
    
    Searches in common locations based on the mode.
    Returns None if the file cannot be found.
    """
    # Try different base directories
    possible_bases = [
        images_base_dir,
        IMAGES_DIR,
        IMAGES_DIR / "all_images",
        PROJECT_ROOT / "downloaded_data"
    ]
    
    for base in possible_bases:
        if base is None:
            continue
        full_path = base / filename
        if full_path.exists():
            return full_path
    
    # Try searching recursively
    for base in possible_bases:
        if base is None or not base.exists():
            continue
        matches = list(base.rglob(Path(filename).name))
        if matches:
            return matches[0]
    
    return None


def display_results(results: list, max_display: int = 5, show_scores: bool = True):
    """
    Display search results as images.
    
    Args:
        results: List of (filename, score) tuples from search
        max_display: Maximum number of images to show
        show_scores: Whether to display similarity scores
    """
    displayed = 0
    
    for filename, score in results[:max_display]:
        img_path = find_image_path(filename)
        
        print(f"\n{'='*60}")
        if show_scores:
            print(f"Similarity: {score:.4f}")
        print(f"File: {filename}")
        
        if img_path and img_path.exists():
            try:
                img = Image.open(img_path)
                # Resize for display
                max_size = 500
                ratio = min(max_size / img.width, max_size / img.height)
                if ratio < 1:
                    new_size = (int(img.width * ratio), int(img.height * ratio))
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
                display(img)
                displayed += 1
            except Exception as e:
                print(f"Could not display image: {e}")
        else:
            print(f"Image file not found locally: {filename}")
            print("(Image embeddings exist but images may not be downloaded)")
    
    if displayed == 0:
        print("\nNo images could be displayed.")
        print("Make sure you have downloaded the images using the 01_api_and_data notebook.")

---

## Exercise: Semantic Image Search

**Your task:** Change the `SEARCH_QUERY` variable to search for different subjects!

### Tips for Better Queries

| Instead of... | Try... | Why |
|---------------|--------|-----|
| `water` | `shoreline with waves` | More specific, less ambiguous |
| `woman` | `portrait of a young woman` | Provides context |
| `forest` | `dense pine forest in summer` | Adds detail |
| `sad` | `melancholic expression, somber mood` | Describes the feeling |

### Suggested Queries to Try

**Nature:**
- `"a lake surrounded by mountains"`
- `"stormy sea with waves"`
- `"birch trees in autumn"`
- `"snowy winter landscape"`

**People:**
- `"portrait of an elderly man"`
- `"children playing outdoors"`
- `"peasants working in a field"`
- `"woman reading a book"`

**Subjects:**
- `"flowers in a vase, still life"`
- `"sailing ships at sea"`
- `"mythological scene"`
- `"interior of a church"`

In [None]:
# ============================================================
# EXERCISE: Change the search query below!
# ============================================================

SEARCH_QUERY = "shoreline with waves and beach"  # <-- CHANGE THIS!

# Number of results to return
TOP_K = 10

# Use prompt ensembling for potentially better results
USE_ENSEMBLE = True

# ============================================================

print(f"Searching for: '{SEARCH_QUERY}'")
print(f"Using ensemble: {USE_ENSEMBLE}")
print()

In [None]:
# Perform the search
if image_embeddings is not None:
    results = semantic_search(SEARCH_QUERY, top_k=TOP_K, use_ensemble=USE_ENSEMBLE)
    
    print(f"Top {len(results)} results:")
    print()
    for i, (filename, score) in enumerate(results, 1):
        print(f"{i:2}. {score:.4f} - {filename}")
else:
    print("No image embeddings loaded. Please check the setup above.")

In [None]:
# Display the results
if image_embeddings is not None and results:
    display_results(results, max_display=5)

---

## Part 7: Advanced - Custom Prompt Engineering

For more control, you can create your own list of prompts. This is useful when:
- A single query doesn't capture what you're looking for
- You want to combine multiple concepts
- You want to exclude certain things (though CLIP doesn't support negation well)
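
Because negation inside a prompt tends not to work, one workaround that is sometimes tried is to score images against a positive and a negative prompt separately and subtract the two. The sketch below uses toy 2-D unit vectors instead of real CLIP embeddings. `rank_with_negative`, `unit`, and the `weight` parameter are hypothetical names, and this is not part of the notebook's pipeline.

```python
import numpy as np


def unit(v):
    """Return v scaled to unit length."""
    v = np.asarray(v, dtype=np.float64)
    return v / np.linalg.norm(v)


def rank_with_negative(image_emb, pos_emb, neg_emb, weight=0.5):
    """Rank images by sim(positive) - weight * sim(negative); inputs are unit vectors."""
    scores = image_emb @ pos_emb - weight * (image_emb @ neg_emb)
    # argsort is ascending, so negate the scores to get best-first order.
    return np.argsort(-scores), scores


# Three toy "images" and two toy "prompts".
images = np.stack([unit([1.0, 1.0]), unit([1.0, 0.0]), unit([0.0, 1.0])])
pos = unit([1.0, 1.0])  # direction standing in for what we want
neg = unit([0.0, 1.0])  # direction standing in for what we want to avoid

order, scores = rank_with_negative(images, pos, neg)
print(order, scores)
```

With `weight=0` the first image ranks highest. With the penalty applied, the second image overtakes it because it has no component along the negative direction. With real embeddings, `weight` would need tuning by eye.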

In [None]:
# ============================================================
# ADVANCED: Create custom prompts
# ============================================================

# Define multiple prompts that describe what you're looking for
CUSTOM_PROMPTS = [
    "a painting of water near the shore",
    "beach scene with ocean waves",
    "coastal landscape with sea",
    "shoreline at sunset",
    "rocky beach with water"
]

# ============================================================

if image_embeddings is not None:
    print("Using custom prompts:")
    for p in CUSTOM_PROMPTS:
        print(f"  - {p}")
    print()
    
    # Encode and average the prompts
    custom_embedding = encode_text_ensemble(CUSTOM_PROMPTS)
    
    # Search
    custom_results = search_images(custom_embedding, top_k=10)
    
    print(f"Top {len(custom_results)} results:")
    for i, (filename, score) in enumerate(custom_results, 1):
        print(f"{i:2}. {score:.4f} - {filename}")

In [None]:
# Display custom search results
if image_embeddings is not None:
    display_results(custom_results, max_display=5)

---

## Part 8: Compare Different Queries

See how different phrasings affect search results.
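
Beyond eyeballing the top hit, one simple way to quantify how much two phrasings agree is the overlap between their result sets. The sketch below computes a Jaccard overlap on made-up result lists in the `(filename, score)` format returned by `semantic_search`; `topk_overlap` is a hypothetical helper.

```python
def topk_overlap(results_a, results_b):
    """Jaccard overlap between the filenames of two result lists."""
    files_a = {filename for filename, _ in results_a}
    files_b = {filename for filename, _ in results_b}
    if not files_a and not files_b:
        return 1.0  # two empty lists agree trivially
    return len(files_a & files_b) / len(files_a | files_b)


# Made-up results; in the notebook these would come from semantic_search(...).
results_sea = [("a.jpg", 0.31), ("b.jpg", 0.29), ("c.jpg", 0.27)]
results_ocean = [("b.jpg", 0.30), ("c.jpg", 0.28), ("d.jpg", 0.26)]

overlap = topk_overlap(results_sea, results_ocean)
print(overlap)  # 2 shared files out of 4 distinct files
```

A value near 1 suggests the two phrasings retrieve largely the same images, while a value near 0 suggests CLIP treats them as different concepts.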

In [None]:
# Compare different ways to search for the same concept
QUERIES_TO_COMPARE = [
    "water",
    "sea",
    "ocean waves",
    "shoreline with beach"
]

if image_embeddings is not None:
    print("Comparing queries - showing top result for each:\n")
    
    for query in QUERIES_TO_COMPARE:
        # Use a separate name so the exercise's `results` is not overwritten
        query_results = semantic_search(query, top_k=1, use_ensemble=False)
        if query_results:
            filename, score = query_results[0]
            print(f"'{query}'")
            print(f"  -> {score:.4f} - {filename}")
            print()

---

## Part 9: Save Search Results

Export your search results for further analysis.
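
The saved file is plain JSON with the keys `query`, `model`, and `results`, so it can be read back with the standard library. The following sketch shows one possible reader. `load_search_results` is a hypothetical helper, and the example data is made up to match the layout that `save_search_results` writes.

```python
import json
import tempfile
from pathlib import Path


def load_search_results(path):
    """Read a results file and return (query, [(filename, similarity), ...])."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    pairs = [(r["filename"], r["similarity"]) for r in data["results"]]
    return data["query"], pairs


# Round trip with made-up data in the same layout as the saved files.
example = {
    "query": "stormy sea",
    "model": "ViT-B/32",
    "results": [{"filename": "a.jpg", "similarity": 0.31}],
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "search_results_stormy_sea.json"
    path.write_text(json.dumps(example, indent=2), encoding="utf-8")
    query, pairs = load_search_results(path)

print(query, pairs)
```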

In [None]:
def save_search_results(results: list, query: str, output_dir: Path = None):
    """
    Save search results to a JSON file.
    """
    if output_dir is None:
        output_dir = DATA_DIR
    
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Create safe filename from query
    safe_query = "".join(c for c in query if c.isalnum() or c in ' _-')[:50]
    filename = f"search_results_{safe_query.replace(' ', '_')}.json"
    
    output_path = output_dir / filename
    
    data = {
        'query': query,
        'model': MODEL_NAME,
        'results': [
            {'filename': f, 'similarity': s}
            for f, s in results
        ]
    }
    
    with open(output_path, 'w') as f:
        json.dump(data, f, indent=2)
    
    print(f"Saved {len(results)} results to {output_path}")
    return output_path

In [None]:
# Save your search results
if image_embeddings is not None and 'results' in globals():
    save_search_results(results, SEARCH_QUERY)

---

## Part 10: Batch Search and Filtering

Find images matching a query and copy them to a folder (like the original experiment).

In [None]:
import shutil

def filter_images_by_query(
    query: str,
    output_dir: Path,
    threshold: float = 0.25,
    max_images: int = 100,
    use_ensemble: bool = True,
    copy_files: bool = True
):
    """
    Find and optionally copy all images matching a query above a threshold.
    
    Args:
        query: Search query
        output_dir: Directory to copy matching images to
        threshold: Minimum cosine similarity (range -1 to 1; CLIP
            text-image scores are usually far below 1)
        max_images: Maximum number of images to return
        use_ensemble: Use prompt ensembling
        copy_files: Whether to copy files or just list them
    
    Returns:
        List of matching (filename, score) tuples
    """
    # Get all results above threshold
    results = semantic_search(query, top_k=len(image_filenames), use_ensemble=use_ensemble)
    
    # Filter by threshold
    matches = [(f, s) for f, s in results if s >= threshold][:max_images]
    
    print(f"Found {len(matches)} images with similarity >= {threshold}")
    
    if copy_files and matches:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        
        copied = 0
        for filename, score in matches:
            src_path = find_image_path(filename)
            if src_path and src_path.exists():
                dst_path = output_dir / Path(filename).name
                shutil.copy2(src_path, dst_path)
                copied += 1
        
        print(f"Copied {copied} images to {output_dir}")
    
    return matches

In [None]:
# ============================================================
# EXERCISE: Filter and collect images
# ============================================================

FILTER_QUERY = "shoreline with waves and beach"  # <-- CHANGE THIS!
SIMILARITY_THRESHOLD = 0.25  # Adjust: higher = stricter matching
MAX_IMAGES = 50
OUTPUT_FOLDER = IMAGES_DIR / "filtered_shorelines"  # <-- CHANGE THIS!

# Set to True to actually copy files, False to just preview
COPY_FILES = False  # <-- Change to True when ready

# ============================================================

if image_embeddings is not None:
    matches = filter_images_by_query(
        query=FILTER_QUERY,
        output_dir=OUTPUT_FOLDER,
        threshold=SIMILARITY_THRESHOLD,
        max_images=MAX_IMAGES,
        copy_files=COPY_FILES
    )
    
    print("\nTop matches:")
    for f, s in matches[:10]:
        print(f"  {s:.4f} - {f}")

---

## Summary

In this notebook, you learned:

1. **What CLIP is** and how it connects images and text
2. **How to use pre-calculated embeddings** for fast search on any device
3. **How cosine similarity works** for measuring image-text similarity
4. **Prompt engineering techniques** for better search results
5. **How to filter large image collections** using semantic search

### Key Takeaways

- **More specific queries** generally give better results than single words
- **Prompt ensembling** (averaging multiple descriptions) improves robustness
- **Threshold tuning** is important - start low and increase until results look good
- CLIP understands **visual concepts** beyond what metadata keywords capture
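
As a small aid for threshold tuning, it can help to count how many images survive at several candidate thresholds before copying anything. The sketch below does this on made-up scores. `count_matches` is a hypothetical helper, and the numbers are not real CLIP output.

```python
def count_matches(scores, thresholds):
    """Return {threshold: number of scores >= threshold}."""
    return {t: sum(1 for s in scores if s >= t) for t in thresholds}


# Made-up similarity scores standing in for the output of semantic_search.
scores = [0.31, 0.29, 0.27, 0.24, 0.22, 0.19, 0.15]
counts = count_matches(scores, [0.15, 0.20, 0.25, 0.30])

for threshold, n in counts.items():
    print(f"threshold {threshold:.2f}: {n} matches")
```

A sharp drop in the count between two thresholds is often a reasonable place to start inspecting results by eye.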

### Limitations

- CLIP doesn't handle **negation** well ("not a portrait" won't work)
- Results depend on **training data** (CLIP was trained on internet images)
- **Fine-grained distinctions** (e.g., specific art styles) may be challenging

### Next Steps

- Try different CLIP models (ViT-L/14 is more accurate but slower)
- Combine CLIP search with metadata filtering
- Use CLIP for image clustering or visualization
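
As a starting point for combining CLIP search with metadata filtering, the sketch below keeps only those results whose metadata matches given field values. Everything here is an assumption for illustration: `filter_by_metadata` is a hypothetical helper, and the `artist` and `century` fields are invented, since the real field names depend on your collection.

```python
def filter_by_metadata(results, metadata, **required):
    """Keep (filename, score) pairs whose metadata matches every required key/value."""
    kept = []
    for filename, score in results:
        record = metadata.get(filename, {})
        if all(record.get(key) == value for key, value in required.items()):
            kept.append((filename, score))
    return kept


# Made-up search results and metadata records.
example_results = [("a.jpg", 0.31), ("b.jpg", 0.29), ("c.jpg", 0.27)]
example_metadata = {
    "a.jpg": {"artist": "Artist A", "century": 19},
    "b.jpg": {"artist": "Artist B", "century": 20},
    "c.jpg": {"artist": "Artist A", "century": 20},
}

print(filter_by_metadata(example_results, example_metadata, century=20))
```

Because the input order is preserved, the filtered list stays sorted by similarity.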