# Photo to CLIP Embedding Converter

This notebook allows you to select images from a folder and convert them to CLIP embeddings.
The embeddings are saved as JSON files for use in your Flux pipeline.

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision transformers pillow ipywidgets --break-system-packages

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import json
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

In [None]:
# Configuration: embeddings are written to
# <parent of the working directory>/data/embeddings/CLIP (created if missing).
current_dir = Path.cwd()
output_dir = current_dir.parent.joinpath("data", "embeddings", "CLIP")
output_dir.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {output_dir}")
print(f"Current directory: {current_dir}")

In [None]:
# Load the CLIP checkpoint together with its matching image/text processor.
print("Loading CLIP model...")
model_name = "openai/clip-vit-large-patch14"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
print("Model loaded successfully!")

In [None]:
def get_image_files(directory):
    """Return the image files directly inside *directory*, sorted by path.

    Only regular files whose extension (case-insensitive) is one of
    .jpg, .jpeg, .png, .gif, .bmp or .webp are returned; subdirectories are
    not searched.

    Args:
        directory: Folder to scan (str or Path).

    Returns:
        A sorted list of Path objects. Empty if *directory* does not exist
        or is not a directory.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
    directory = Path(directory)

    # is_dir() rather than exists(): a path that points at a regular file
    # would otherwise make iterdir() raise NotADirectoryError.
    if not directory.is_dir():
        return []

    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    )


def extract_clip_image_embedding(image_path, return_tokens=True):
    """
    Extract CLIP image embeddings from an image file.

    Relies on the module-level ``model``, ``processor`` and ``model_name``.

    Args:
        image_path: Path (or str) to the image file.
        return_tokens: If True, return the vision tower's per-token hidden
                       states (``last_hidden_state``), shape
                       [num_tokens, hidden_size] -- [257, 1024] for
                       clip-vit-large-patch14 (CLS token + 16x16 patches).
                       These are neither projected into CLIP's joint
                       embedding space nor normalized.
                       If False, return the projected, L2-normalized pooled
                       image embedding, shape [projection_dim] -- [768] for
                       clip-vit-large-patch14.

    Returns:
        Dictionary with keys "source_image" (file name), "embedding"
        (nested list of floats), "shape", "type" and "model".
    """
    image_path = Path(image_path)  # accept str as well as Path

    # Close the file as soon as the pixels are decoded; convert() returns an
    # independent in-memory image, so nothing refers to the handle afterwards.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        if return_tokens:
            # Full token sequence from the vision encoder: [1, tokens, hidden]
            vision_outputs = model.vision_model(**inputs)
            image_embeds = vision_outputs.last_hidden_state.squeeze(0)
        else:
            # Pooled + projected image features: [1, projection_dim]
            image_features = model.get_image_features(**inputs)
            # L2-normalize so cosine similarity reduces to a dot product
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            image_embeds = image_features.squeeze(0)

    shape = list(image_embeds.shape)

    return {
        "source_image": image_path.name,
        # Nested Python lists so the result is JSON-serializable
        "embedding": image_embeds.cpu().numpy().tolist(),
        "shape": shape,
        "type": "clip_image_tokens" if return_tokens else "clip_image_pooled",
        "model": model_name,
    }


def save_embedding_json(embedding_data, output_path):
    """Write *embedding_data* to *output_path* as compact JSON, then report it.

    An existing file at *output_path* is overwritten.
    """
    with open(output_path, mode='w') as handle:
        json.dump(embedding_data, handle)
    print("Saved: " + str(output_path))

In [None]:
# Widget for selecting image directory (defaults to ./example_photos)
image_dir_input = widgets.Text(
    value=str(current_dir / "example_photos"),
    placeholder='Enter path to image folder',
    description='Image Folder:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px')
)

load_button = widgets.Button(
    description='Load Images',
    button_style='info',
    tooltip='Load images from the specified folder'
)

# Filled by on_load_images with (file name, Path) pairs; disabled until then
image_selector = widgets.Dropdown(
    options=[],
    description='Select Image:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px'),
    disabled=True
)

# Value is the `return_tokens` flag passed to extract_clip_image_embedding.
# For clip-vit-large-patch14 the token embeddings are the 1024-d vision hidden
# states (CLS + 16x16 patches); only the pooled output is projected to 768-d.
embedding_type = widgets.RadioButtons(
    options=[
        ('Token embeddings [257, 1024] - Full detail', True),
        ('Pooled embedding [768] - Single vector', False)
    ],
    description='Embedding Type:',
    style={'description_width': 'initial'}
)

preview_image = widgets.Image(
    format='png',
    width=400,
    height=400
)

convert_button = widgets.Button(
    description='Convert to CLIP Embedding',
    button_style='success',
    tooltip='Extract CLIP embedding and save as JSON',
    disabled=True
)

# Log panel shared by all callbacks
output_text = widgets.Output()

# Current state, mutated by the callbacks:
#   image_files   -- Paths found by the last successful load
#   current_image -- Path chosen in the dropdown (None until selected)
state = {
    'image_files': [],
    'current_image': None
}


def on_load_images(b):
    """Button callback: scan the folder typed into ``image_dir_input``.

    On success the dropdown is filled with the folder's images and the
    convert button is enabled. If the path is not a directory or contains no
    supported images, the dropdown is emptied, the convert button disabled
    and the cached ``state`` cleared so no stale selection survives.

    Args:
        b: The clicked Button (unused; required by ``on_click``).
    """
    with output_text:
        clear_output()
        image_dir = Path(image_dir_input.value)

        # is_dir() (not exists()) so a path that points at a regular file is
        # reported here instead of being handed to iterdir().
        if not image_dir.is_dir():
            print(f"❌ Directory not found: {image_dir}")
            image_files = []
        else:
            image_files = get_image_files(image_dir)
            if not image_files:
                print(f"❌ No images found in: {image_dir}")

        if not image_files:
            image_selector.options = []
            image_selector.disabled = True
            convert_button.disabled = True
            state['image_files'] = []
            state['current_image'] = None
            return

        state['image_files'] = image_files
        # Setting options fires on_image_selected for the new first entry
        image_selector.options = [(f.name, f) for f in image_files]
        image_selector.disabled = False
        convert_button.disabled = False

        print(f"✅ Loaded {len(image_files)} images from {image_dir}")


def on_image_selected(change):
    """Dropdown observer: remember the chosen image and show it in the preview.

    Args:
        change: traitlets change dict; ``change['new']`` is the selected Path,
            or None when the dropdown has been emptied.
    """
    image_path = change['new']
    if image_path is None:
        return

    state['current_image'] = image_path

    # The Image widget was created with format='png'; label the bytes with
    # their real type instead ('.jpg' maps to the 'jpeg' MIME subtype).
    suffix = image_path.suffix.lower().lstrip('.')
    preview_image.format = 'jpeg' if suffix == 'jpg' else suffix

    with open(image_path, 'rb') as f:
        preview_image.value = f.read()

    with output_text:
        print(f"Selected: {image_path.name}")


def on_convert(b):
    """Button callback: embed the selected image and save it as JSON.

    The embedding kind comes from ``embedding_type`` (True = token
    embeddings, False = pooled). The result is written to
    ``output_dir / "<image stem>_from_image.json"``; an existing file with
    that name is overwritten.

    Args:
        b: The clicked Button (unused; required by ``on_click``).
    """
    with output_text:
        clear_output()

        image_path = state['current_image']
        if image_path is None:
            print("❌ No image selected")
            return

        return_tokens = embedding_type.value

        print(f"🔄 Processing: {image_path.name}")
        # The exact shape depends on the model and is printed after success,
        # so no dimensions are hard-coded here.
        print(f"   Type: {'Token embeddings' if return_tokens else 'Pooled embedding'}")

        try:
            embedding_data = extract_clip_image_embedding(image_path, return_tokens)
            # Stem only: cat.jpg and cat.png map to the same output file
            output_path = output_dir / f"{image_path.stem}_from_image.json"
            save_embedding_json(embedding_data, output_path)
        except Exception as e:
            # UI boundary: report any failure in the output panel rather
            # than raising out of the button callback.
            print(f"❌ Error: {type(e).__name__}: {e}")
            return

        print("✅ Success!")
        print(f"   Shape: {embedding_data['shape']}")
        print(f"   Saved to: {output_path}")


# Wire each widget to its handler (the dropdown is observed on its value).
image_selector.observe(on_image_selected, names="value")
load_button.on_click(on_load_images)
convert_button.on_click(on_convert)

In [None]:
# Print a banner with the output location, then render the widget layout.
print("\n".join([
    "=" * 70,
    "Photo to CLIP Embedding Converter",
    "=" * 70,
    f"Output directory: {output_dir}",
    "=" * 70,
]))

display(widgets.VBox(children=[
    widgets.HBox(children=[image_dir_input, load_button]),
    image_selector,
    preview_image,
    embedding_type,
    convert_button,
    output_text,
]))

## Usage Instructions

1. **Specify Image Folder**: Enter the path to your folder containing example photos
2. **Load Images**: Click "Load Images" to scan the folder
3. **Select Image**: Choose an image from the dropdown menu
4. **Choose Embedding Type**:
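   - **Token embeddings [257, 1024]**: Full token-level embeddings. There is one CLS token plus one token per image patch, analogous to the 77 per-token text CLIP embeddings. Note that these are 1024-dimensional vision hidden states, not 768-dimensional like the text tokens.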
   - **Pooled embedding [768]**: Single vector representation
5. **Convert**: Click "Convert to CLIP Embedding" to process and save

The output JSON files will be saved in: `../data/embeddings/CLIP/`

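Filename format: `{original_image_name}_from_image.json`. The original extension is dropped, so `cat.jpg` and `cat.png` write to the same file. Re-converting an image (e.g. with the other embedding type) also overwrites the existing file.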

## Example: Load and Use the Embedding

After converting images, you can load and use the embeddings in your Flux pipeline:

In [None]:
# Example: Load a saved embedding
def load_embedding(json_path):
    """Read an embedding JSON written by this notebook and return it as a tensor.

    The stored metadata (source image, type, shape, model) is printed.
    """
    with open(json_path, 'r') as f:
        record = json.load(f)

    tensor = torch.tensor(record['embedding'])

    for label, key in (
        ("Loaded", "source_image"),
        ("Type", "type"),
        ("Shape", "shape"),
        ("Model", "model"),
    ):
        print(f"{label}: {record[key]}")

    return tensor

# Uncomment to test loading:
# embedding = load_embedding(output_dir / "your_image_from_image.json")

## Notes

- **Token Embeddings vs Pooled**: 
  - Token embeddings preserve spatial/patch information (good for detailed control)
  - Pooled embeddings are more compact (good for overall style/content)
  
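- **Compatibility**: The embeddings come from `openai/clip-vit-large-patch14` (CLIP-L). The pooled image embedding is the projected, L2-normalized 768-d vector. The token embeddings are unprojected 1024-d vision hidden states, so they are not directly interchangeable with 768-d text token embeddings.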

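- **File Size**: Token embeddings are far larger than pooled embeddings. Token embeddings hold 257 × 1024 ≈ 263k floats, which is several MB of JSON text. Pooled embeddings hold 768 floats, roughly 15–20 KB.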

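- **Integration**: You can blend these image embeddings with text embeddings or use them directly in your Flux pipeline. The 1024-d token embeddings need a projection to 768-d before they can be combined with 768-d text embeddings.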