# Photo to CLIP Embedding Converter

This notebook allows you to select images from a folder and convert them to CLIP embeddings.
The embeddings are saved as JSON files for use in your Flux pipeline.

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision transformers pillow ipywidgets --break-system-packages

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import json
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

In [None]:
# Configuration
# Embeddings are written to <parent of working dir>/data/embeddings/CLIP,
# which is created on first run (including any missing parent folders).
current_dir = Path.cwd()
output_dir = current_dir.parent.joinpath("data", "embeddings", "CLIP")
output_dir.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {output_dir}")
print(f"Current directory: {current_dir}")

In [None]:
# Load CLIP model
# Both objects are loaded once from the same checkpoint and reused by the
# cells below as the module-level names `model` and `processor`.
model_name = "openai/clip-vit-large-patch14"

print("Loading CLIP model...")
model = CLIPModel.from_pretrained(model_name)          # vision + text towers and projections
processor = CLIPProcessor.from_pretrained(model_name)  # image preprocessing for the model
print("Model loaded successfully!")

In [None]:
def get_image_files(directory):
    """
    List the image files directly inside a directory.

    Only regular files whose extension (compared case-insensitively) is one
    of .jpg, .jpeg, .png, .gif, .bmp or .webp are returned. Sub-directories
    are not searched.

    Args:
        directory: Folder to scan, as a ``str`` or ``Path``

    Returns:
        Sorted list of ``Path`` objects; an empty list when ``directory``
        does not exist or is not a directory.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
    directory = Path(directory)

    # is_dir() is False both for missing paths and for paths that point at a
    # regular file, so iterdir() below can no longer raise NotADirectoryError.
    if not directory.is_dir():
        return []

    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    )


def extract_clip_image_embedding(image_path, num_tokens=77):
    """
    Extract a CLIP image-token embedding sequence from an image file.

    The image is run through the CLIP vision encoder, every token's hidden
    state is passed through the visual projection layer, and the sequence is
    cut to its first ``num_tokens`` tokens so that it has the same layout as
    a text embedding.

    With the ViT-Large checkpoint loaded by this notebook the vision encoder
    produces 1024-dimensional hidden states, which are projected to 768
    dimensions, so the default result has shape [77, 768].

    Args:
        image_path: Path to the image file (``str`` or ``Path``)
        num_tokens: Number of leading tokens to keep (CLS token first,
            then patch tokens). Defaults to 77.

    Returns:
        Dictionary with embedding data:
            ``prompt``    - image filename without extension
            ``embedding`` - nested list of floats, one row per kept token
            ``shape``     - actual [tokens, dims] of ``embedding``
    """
    # Accept plain strings as well as Path objects (``.stem`` is used below)
    image_path = Path(image_path)

    # Open inside a context manager so the file handle is released right
    # away; convert() returns a fully loaded copy that outlives the file.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Process the image
    inputs = processor(images=image, return_tensors="pt")

    # Get image embeddings
    with torch.no_grad():
        # Vision encoder hidden states [1, 257, 1024] -> [257, 1024]
        vision_outputs = model.vision_model(**inputs)
        image_embeds = vision_outputs.last_hidden_state.squeeze(0)

        # Apply visual projection to every token: [257, 1024] -> [257, 768]
        # NOTE(review): CLIP normally applies this projection to the
        # post-layernorm CLS token only; here it is applied to the raw hidden
        # states of all tokens - confirm this is what the pipeline expects.
        image_embeds = model.visual_projection(image_embeds)

        # Keep the CLS token and the patch tokens that follow it
        image_embeds = image_embeds[:num_tokens, :]

    # Convert to list for JSON serialization
    embedding_list = image_embeds.cpu().numpy().tolist()

    return {
        "prompt": str(image_path.stem),  # Use image filename without extension
        "embedding": embedding_list,
        # Report the real tensor shape so the metadata can never disagree
        # with the data (e.g. when fewer than num_tokens tokens exist)
        "shape": list(image_embeds.shape)
    }


def save_embedding_json(embedding_data, output_path):
    """
    Write an embedding record to disk as JSON and report where it went.

    Args:
        embedding_data: JSON-serializable object to store
        output_path: Destination file; overwritten if it already exists
    """
    with open(output_path, mode='w') as json_file:
        json.dump(embedding_data, fp=json_file)

    print(f"Saved: {output_path}")

In [None]:
# Current state shared by the callbacks defined below
state = dict(image_files=[], current_image=None)

# Widget for selecting image directory (pre-filled with the default input folder)
image_dir_input = widgets.Text(
    description='Image Folder:',
    value=str(current_dir.parent / "data" / "input_img"),
    placeholder='Enter path to image folder',
    layout=widgets.Layout(width='600px'),
    style={'description_width': 'initial'},
)

# Scans the folder typed into `image_dir_input`
load_button = widgets.Button(
    description='Load Images',
    tooltip='Load images from the specified folder',
    button_style='info',
)

# Starts empty and disabled; it is filled once a folder has been loaded
image_selector = widgets.Dropdown(
    description='Select Image:',
    options=[],
    disabled=True,
    layout=widgets.Layout(width='600px'),
    style={'description_width': 'initial'},
)

# Shows the currently selected image
preview_image = widgets.Image(
    format='png',
    height=400,
    width=400,
)

# Disabled until at least one image is available
convert_button = widgets.Button(
    description='Convert to CLIP Embedding',
    tooltip='Extract CLIP embedding and save as JSON',
    button_style='success',
    disabled=True,
)

# Status messages printed by the callbacks are captured here
output_text = widgets.Output()


def _reset_image_selection():
    """Empty the dropdown, disable dependent controls and forget cached paths."""
    state['image_files'] = []
    state['current_image'] = None
    image_selector.options = []
    image_selector.disabled = True
    convert_button.disabled = True


def on_load_images(b):
    """
    Load images from the directory typed into the folder text box.

    On success the dropdown is filled with the image files found and the
    convert button is enabled. When the folder is missing, is not a
    directory, or contains no images, a message is shown and the selection
    widgets and cached state are reset.

    Args:
        b: The clicked button (unused; required by the on_click signature)
    """
    with output_text:
        clear_output()
        # Tolerate stray whitespace and "~" in the typed path
        image_dir = Path(image_dir_input.value.strip()).expanduser()

        # is_dir() also rejects a path that points at a file, which would
        # otherwise make the directory scan raise inside this callback.
        if not image_dir.is_dir():
            # \u274c is the "cross mark" emoji, escaped to stay encoding-safe
            print(f"\u274c Directory not found: {image_dir}")
            _reset_image_selection()
            return

        image_files = get_image_files(image_dir)

        if not image_files:
            print(f"\u274c No images found in: {image_dir}")
            _reset_image_selection()
            return

        state['image_files'] = image_files
        image_selector.options = [(f.name, f) for f in image_files]
        image_selector.disabled = False
        convert_button.disabled = False
        # Keep the cached selection in step with the dropdown even when
        # assigning new options does not fire a change event.
        state['current_image'] = image_selector.value

        # \u2705 is the "check mark" emoji
        print(f"\u2705 Loaded {len(image_files)} images from {image_dir}")


def on_image_selected(change):
    """
    Update the preview and cached selection when the dropdown value changes.

    Args:
        change: Change notification dict from the dropdown; ``change['new']``
            is the newly selected image ``Path`` or ``None`` when the
            dropdown has been emptied.
    """
    image_path = change['new']
    if image_path is None:
        return

    # Read first, so a vanished or unreadable file is reported in the output
    # area instead of raising inside the observer callback.
    try:
        with open(image_path, 'rb') as f:
            image_bytes = f.read()
    except OSError as e:
        state['current_image'] = None
        with output_text:
            # \u274c is the "cross mark" emoji, escaped to stay encoding-safe
            print(f"\u274c Could not read {image_path.name}: {e}")
        return

    state['current_image'] = image_path

    # Declare the real image type instead of always claiming PNG;
    # "jpg" is not a valid format name, so map it to "jpeg".
    suffix = image_path.suffix.lower().lstrip('.') or 'png'
    preview_image.format = 'jpeg' if suffix == 'jpg' else suffix
    preview_image.value = image_bytes

    with output_text:
        print(f"Selected: {image_path.name}")


def on_convert(b):
    """
    Convert the selected image to a CLIP embedding and save it as JSON.

    The output file is named ``<image stem>_from_image.json`` and written to
    ``output_dir``. Progress and errors are printed into the output area;
    exceptions raised during extraction or saving are reported there rather
    than propagated.

    Args:
        b: The clicked button (unused; required by the on_click signature)
    """
    # Status emoji are written as escapes so they cannot be mangled by a
    # wrong file encoding: \u274c cross mark, \u2705 check mark,
    # \U0001F504 "working" arrows.
    with output_text:
        clear_output()

        if state['current_image'] is None:
            print("\u274c No image selected")
            return

        image_path = state['current_image']

        print(f"\U0001F504 Processing: {image_path.name}")
        print(f"   Output shape: [77, 768]")

        try:
            # Extract embedding
            embedding_data = extract_clip_image_embedding(image_path)

            # Generate output filename from the image name without extension
            input_stem = image_path.stem
            output_filename = f"{input_stem}_from_image.json"
            output_path = output_dir / output_filename

            # Save to JSON
            save_embedding_json(embedding_data, output_path)

            print(f"\u2705 Success!")
            print(f"   Shape: {embedding_data['shape']}")
            print(f"   Saved to: {output_path}")

        except Exception as e:
            # Deliberate catch-all: show the failure in the UI and keep the
            # widget usable for another attempt.
            print(f"\u274c Error: {e}")


# Wire the UI events to their handlers:
# dropdown value changes refresh the preview, button clicks run the actions
image_selector.observe(on_image_selected, names='value')
convert_button.on_click(on_convert)
load_button.on_click(on_load_images)

In [None]:
# Display the interface
# Banner first (one line per item), then the widgets stacked top to bottom
print(
    "=" * 70,
    "Photo to CLIP Embedding Converter",
    "=" * 70,
    f"Output directory: {output_dir}",
    f"Output shape: [77, 768]",
    "=" * 70,
    sep="\n",
)

display(
    widgets.VBox(
        children=[
            widgets.HBox(children=[image_dir_input, load_button]),  # folder box + load button
            image_selector,
            preview_image,
            convert_button,
            output_text,
        ]
    )
)

## Usage Instructions

1. **Specify Image Folder**: Enter the path to your folder containing example photos
2. **Load Images**: Click "Load Images" to scan the folder
3. **Select Image**: Choose an image from the dropdown menu
4. **Convert**: Click "Convert to CLIP Embedding" to process and save

The embeddings will be saved with shape **[77, 768]** to match the format of text CLIP embeddings.

Output JSON files will be saved in: `../data/embeddings/CLIP/`

Filename format: `{original_image_name}_from_image.json`

## Example: Load and Use the Embedding

After converting images, you can load and use the embeddings in your Flux pipeline:

In [None]:
# Example: Load a saved embedding
def load_embedding(json_path):
    """
    Read a saved embedding JSON file back into a tensor.

    The file is expected to contain the keys ``prompt``, ``embedding`` and
    ``shape``; the stored prompt and shape are printed for reference.

    Args:
        json_path: Path of the JSON file to read

    Returns:
        ``torch.Tensor`` built from the file's ``embedding`` entry
    """
    with open(json_path, mode='r') as source:
        record = json.loads(source.read())

    tensor = torch.tensor(record['embedding'])

    print(f"Loaded: {record['prompt']}")
    print(f"Shape: {record['shape']}")

    return tensor

# Uncomment to test loading:
# embedding = load_embedding(output_dir / "your_image_from_image.json")

## Notes

- **Output Format**: All embeddings are saved with shape **[77, 768]** to match text CLIP embedding format
  
- **Dimension Projection**: The ViT-Large vision encoder produces 1024-dimensional hidden states, which are projected to 768 dimensions using the visual projection layer to match the text encoder's output dimension.

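- **Token Truncation**: The notebook keeps the first 77 tokens of the CLIP vision encoder's projected output, which normally has 257 tokens (1 CLS token + 16×16 = 256 patch tokens for a 224×224 input with 14-pixel patches). The kept tokens are the CLS token and the first 76 patch tokens. Assuming the usual row-by-row patch order, those 76 patches cover only about the top 4¾ rows of the 16-row patch grid, so the lower part of the image is not represented in the saved embedding.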

- **Compatibility**: The embeddings use the same CLIP model (`clip-vit-large-patch14`) and projection to the shared 768-dimensional space for consistency with text embeddings

- **File Size**: Each embedding file is approximately 750KB

- **JSON Format**: Output matches the format used by text embeddings:
  - `prompt`: Image filename (without extension)
  - `embedding`: List of 77 token vectors, each with 768 dimensions
  - `shape`: [77, 768]