# Visual Attribute Blending: Image → Text Embeddings

## The Problem
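Flux was trained on **text embeddings** from CLIP's text encoder, not **image embeddings** from CLIP's vision encoder. Both encoders produce 768-dimensional vectors, but the distributions differ, and the shapes do too: the text conditioning used here is a [77, 768] token sequence, while the image embedding is a single [1, 768] vector. An image embedding therefore cannot simply be passed in place of a text embedding.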

## The Solution
Instead of using image embeddings directly, we:
1. Define a library of visual attributes as text (colors, composition, lighting, texture, mood)
2. Use CLIP to measure how well each attribute describes the image
3. Create a weighted blend of **text embeddings** based on these similarities
4. Output a proper [77, 768] text embedding that Flux can understand
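
The four steps above can be sketched with toy data. This is a minimal illustration that uses NumPy and random vectors in place of real CLIP outputs; the variable names and the reduced hidden size of 8 are assumptions chosen for readability, not part of the notebook's pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(x):
    """Scale vectors to unit length so dot products are cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

n_attributes, seq_len, hidden = 5, 77, 8  # toy sizes; the notebook uses hidden = 768

# Stand-ins for CLIP outputs (random values, for illustration only)
image_feature = normalize(rng.normal(size=(1, hidden)))
attribute_features = normalize(rng.normal(size=(n_attributes, hidden)))
attribute_embeddings = rng.normal(size=(n_attributes, seq_len, hidden))

# Step 2: cosine similarity between the image and each attribute
similarities = (image_feature @ attribute_features.T).squeeze(0)

# Step 3: keep the 3 best attributes and turn their scores into weights.
# Random toy scores can be negative, so they are clipped to a small positive floor.
top = np.argsort(similarities)[::-1][:3]
weights = np.clip(similarities[top], 1e-6, None)
weights = weights / weights.sum()

# Step 4: weighted sum over the attribute axis
blended = np.tensordot(weights, attribute_embeddings[top], axes=1)
print(blended.shape)  # (77, 8)
```

The result keeps the per-token layout of a text embedding; only the attribute axis is collapsed.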

## Why This Works
- CLIP was trained to align images and text in a shared embedding space, so image-text cosine similarity is meaningful
- Similarity scores against the attribute library extract visual features automatically
- We stay in text embedding space (where Flux was trained)
- No manual text description needed!

```mermaid
flowchart LR
    IMG[Input Image]
    
    subgraph CLIP Analysis
        VIS[CLIP Vision<br/>Encoder]
        ATTR[Attribute Library<br/>colors, mood, lighting...]
        SIM[Similarity<br/>Calculation]
    end

    subgraph Text Embedding Generation
        TOP[Top K Attributes]
        TXT[CLIP Text<br/>Encoder]
        BLEND[Weighted<br/>Blend]
    end

    EMB[Text Embedding<br/>77 × 768]

    IMG --> VIS --> SIM
    ATTR --> SIM
    SIM --> TOP --> TXT --> BLEND --> EMB
    
    EMB -->|compatible with| FLUX[FLUX]
```

## Installation

Uncomment if needed:

In [None]:
# !pip install torch torchvision transformers pillow ipywidgets matplotlib --break-system-packages

## Imports and Setup

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
from PIL import Image
import json
import numpy as np
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import matplotlib.pyplot as plt

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

current_dir = Path(os.getcwd())
output_dir = current_dir.parent / "data" / "embeddings" / "CLIP"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {output_dir}")

## Load CLIP Models

We need:
- **Full CLIP model** (vision + text) for similarity calculations
- **Text encoder** to generate the final text embeddings for Flux

In [None]:
print("Loading CLIP models...")

# Full CLIP model (for similarity calculations)
model_name = "openai/clip-vit-large-patch14"
clip_model = CLIPModel.from_pretrained(model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(model_name)

# Text encoder and tokenizer (for generating final embeddings)
text_model = CLIPTextModel.from_pretrained(model_name).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name)

clip_model.eval()
text_model.eval()

print("✓ Models loaded successfully!")
print(f"  Device: {device}")
print(f"  Text embedding dimension: {text_model.config.hidden_size}")
print(f"  Max tokens: {tokenizer.model_max_length}")

## Define Visual Attribute Library

These are the "building blocks" we'll use to describe images.
Each attribute will be converted to a text embedding, then blended based on how well it matches the image.

You can customize this library to focus on attributes relevant to your use case!

In [None]:
# Define attribute library
ATTRIBUTE_LIBRARY = {
    "colors": [
        "red colors",
        "orange colors",
        "yellow colors",
        "green colors",
        "blue colors",
        "purple colors",
        "pink colors",
        "warm colors",
        "cool colors",
        "vibrant colors",
        "muted colors",
        "pastel colors",
        "dark colors",
        "light colors",
    ],
    "composition": [
        "centered composition",
        "horizontal composition",
        "vertical composition",
        "diagonal composition",
        "symmetrical composition",
        "asymmetrical composition",
        "simple composition",
        "complex composition",
        "minimal composition",
        "busy composition",
    ],
    "lighting": [
        "bright lighting",
        "dark lighting",
        "soft lighting",
        "harsh lighting",
        "high contrast",
        "low contrast",
        "dramatic lighting",
        "natural lighting",
        "backlit",
        "evenly lit",
    ],
    "texture": [
        "smooth texture",
        "rough texture",
        "soft texture",
        "detailed texture",
        "blurred texture",
        "sharp details",
    ],
    "mood": [
        "peaceful mood",
        "dramatic mood",
        "energetic mood",
        "calm mood",
        "mysterious mood",
        "cheerful mood",
    ]
}

# Flatten all attributes into a single list
all_attributes = []
for category, attrs in ATTRIBUTE_LIBRARY.items():
    all_attributes.extend(attrs)

print(f"Total attributes defined: {len(all_attributes)}")
print("\nCategories:")
for category, attrs in ATTRIBUTE_LIBRARY.items():
    print(f"  {category}: {len(attrs)} attributes")

## Core Functions

In [None]:
def get_image_features(image_path):
    """
    Extract CLIP vision features from an image.
    These are used for similarity calculations only.
    
    Returns: normalized image features [1, 768]
    """
    image = Image.open(image_path).convert('RGB')
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
        # Normalize for cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    
    return image_features


def get_text_features(texts):
    """
    Extract CLIP text features for similarity calculations.
    
    Args:
        texts: List of text strings
    
    Returns: normalized text features [len(texts), 768]
    """
    inputs = clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
    
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
        # Normalize for cosine similarity
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    
    return text_features


def calculate_similarities(image_features, text_features):
    """
    Calculate cosine similarity between image and each text attribute.
    
    Returns: similarity scores [len(texts)]
    """
    # Cosine similarity (features are already normalized)
    similarities = (image_features @ text_features.T).squeeze(0)
    return similarities


def get_text_embeddings(texts):
    """
    Generate full [77, 768] text embeddings from CLIP text encoder.
    These are the embeddings Flux expects.
    
    Args:
        texts: List of text strings
    
    Returns: text embeddings [len(texts), 77, 768]
    """
    tokens = tokenizer(
        texts,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        outputs = text_model(**tokens)
        embeddings = outputs.last_hidden_state  # [batch, 77, 768]
    
    return embeddings


def blend_embeddings(embeddings, weights):
    """
    Create weighted blend of text embeddings.
    
    Args:
        embeddings: Tensor of shape [n_attributes, 77, 768]
        weights: Tensor of shape [n_attributes] - similarity scores
    
    Returns: Blended embedding [77, 768]
    """
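    # Guard against an empty selection or a non-positive total, which would yield NaNs
    if weights.numel() == 0 or weights.sum() <= 0:
        raise ValueError("weights must be non-empty and have a positive sum")

    # Normalize weights to sum to 1
    weights = weights / weights.sum()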
    
    # Reshape weights for broadcasting: [n_attributes, 1, 1]
    weights = weights.view(-1, 1, 1)
    
    # Weighted sum
    blended = (embeddings * weights).sum(dim=0)  # [77, 768]
    
    return blended


print("✓ Functions defined")

## Main Processing Function

In [None]:
def create_visual_text_embedding(image_path, top_k=15, min_similarity=0.15):
    """
    Create a text embedding informed by visual content.
    
    Args:
        image_path: Path to input image
        top_k: Number of top attributes to use for blending (default: 15)
        min_similarity: Minimum similarity threshold (default: 0.15)
    
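    Returns:
        Dictionary with:
        - prompt: Comma-separated description of the top 5 attributes
        - embedding: [77, 768] nested list (JSON-serializable)
        - shape: [77, 768]
        - top_attributes: List of (attribute, score) tuples
        - method: "visual_attribute_blending"
        - source_image: File name of the input image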
    """
    image_path = Path(image_path)  # accept plain strings as well as Path objects
    print(f"Processing: {image_path.name}")
    print("="*70)
    
    # Step 1: Get image features
    print("\n1. Extracting image features...")
    image_features = get_image_features(image_path)
    print(f"   ✓ Image features: {image_features.shape}")
    
    # Step 2: Get text features for all attributes
    print("\n2. Processing attribute library...")
    text_features = get_text_features(all_attributes)
    print(f"   ✓ Text features: {text_features.shape}")
    
    # Step 3: Calculate similarities
    print("\n3. Calculating similarities...")
    similarities = calculate_similarities(image_features, text_features)
    print(f"   ✓ Similarities calculated: {similarities.shape}")
    print(f"   Similarity range: [{similarities.min():.3f}, {similarities.max():.3f}]")
    
    # Step 4: Select top K attributes
    print(f"\n4. Selecting top {top_k} attributes (min similarity: {min_similarity})...")
    
    # Get top K indices
    top_indices = torch.argsort(similarities, descending=True)[:top_k]
    top_similarities = similarities[top_indices]
    
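    # Filter by minimum similarity
    mask = top_similarities >= min_similarity
    top_indices = top_indices[mask]
    top_similarities = top_similarities[mask]

    # Stop early: an empty selection cannot be tokenized or blended
    if top_indices.numel() == 0:
        raise ValueError(
            f"No attribute reached min_similarity={min_similarity}; try a lower threshold"
        )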
    
    # Get attribute names
    top_attributes = [(all_attributes[idx], similarities[idx].item()) 
                      for idx in top_indices]
    
    print(f"   ✓ Selected {len(top_attributes)} attributes:")
    for attr, score in top_attributes:
        print(f"      {score:.3f} - {attr}")
    
    # Step 5: Get full text embeddings for selected attributes
    print("\n5. Generating text embeddings...")
    selected_texts = [all_attributes[idx] for idx in top_indices]
    text_embeddings = get_text_embeddings(selected_texts)  # [n, 77, 768]
    print(f"   ✓ Text embeddings: {text_embeddings.shape}")
    
    # Step 6: Blend embeddings
    print("\n6. Blending embeddings...")
    blended = blend_embeddings(text_embeddings, top_similarities)
    print(f"   ✓ Blended embedding: {blended.shape}")
    
    # Convert to numpy
    embedding_array = blended.cpu().numpy()
    
    # Create prompt description
    prompt = ", ".join([attr for attr, _ in top_attributes[:5]])
    
    print("\n" + "="*70)
    print("✓ Complete!")
    print(f"Prompt (top 5 attributes): {prompt}")
    
    return {
        "prompt": prompt,
        "embedding": embedding_array.tolist(),
        "shape": list(embedding_array.shape),  # derived from the array, not hard-coded
        "top_attributes": top_attributes,
        "method": "visual_attribute_blending",
        "source_image": str(image_path.name)
    }

print("✓ Main function defined")

## Visualization Function

In [None]:
def visualize_attributes(top_attributes, image_path=None):
    """
    Visualize the top attributes and their similarity scores.
    """
    if not top_attributes:
        print("No attributes to visualize")
        return
    
    # Prepare data
    attrs = [attr for attr, _ in top_attributes]
    scores = [score for _, score in top_attributes]
    
    # Create figure
    if image_path:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        
        # Show image
        img = Image.open(image_path)
        ax1.imshow(img)
        ax1.axis('off')
        ax1.set_title('Input Image', fontsize=14, fontweight='bold')
    else:
        fig, ax2 = plt.subplots(1, 1, figsize=(10, 6))
    
    # Plot bar chart
    y_pos = np.arange(len(attrs))
    ax2.barh(y_pos, scores, color='steelblue')
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(attrs)
    ax2.invert_yaxis()  # Top attribute at the top
    ax2.set_xlabel('CLIP Similarity Score', fontsize=12)
    ax2.set_title('Top Visual Attributes', fontsize=14, fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)
    
    # Add value labels
    for i, v in enumerate(scores):
        ax2.text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=9)
    
    plt.tight_layout()
    plt.show()

print("✓ Visualization function defined")

## Interactive Interface

In [None]:
# File browser
def get_image_files(directory):
    """Get all image files from a directory"""
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
    directory = Path(directory)
    
    if not directory.exists():
        return []
    
    image_files = []
    for file in directory.iterdir():
        if file.is_file() and file.suffix.lower() in image_extensions:
            image_files.append(file)
    
    return sorted(image_files)

# Widgets
image_dir_input = widgets.Text(
    value=str(current_dir.parent / "data" / "input_img"),
    placeholder='Enter path to image folder',
    description='Image Folder:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px')
)

load_button = widgets.Button(
    description='Load Images',
    button_style='info'
)

image_selector = widgets.Dropdown(
    options=[],
    description='Select Image:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px'),
    disabled=True
)

preview_image = widgets.Image(
    format='png',
    width=400,
    height=400
)

top_k_slider = widgets.IntSlider(
    value=15,
    min=5,
    max=30,
    step=1,
    description='Top K attributes:',
    style={'description_width': 'initial'}
)

min_sim_slider = widgets.FloatSlider(
    value=0.15,
    min=0.0,
    max=0.5,
    step=0.05,
    description='Min similarity:',
    style={'description_width': 'initial'},
    readout_format='.2f'
)

process_button = widgets.Button(
    description='Process Image',
    button_style='success',
    disabled=True
)

output_area = widgets.Output()

# State
state = {
    'image_files': [],
    'current_image': None,
    'last_result': None
}

# Callbacks
def on_load_images(b):
    with output_area:
        clear_output()
        image_dir = Path(image_dir_input.value)
        
        if not image_dir.exists():
            print(f"❌ Directory not found: {image_dir}")
            return
        
        image_files = get_image_files(image_dir)
        
        if not image_files:
            print(f"❌ No images found in: {image_dir}")
            return
        
        state['image_files'] = image_files
        image_selector.options = [(f.name, f) for f in image_files]
        image_selector.disabled = False
        process_button.disabled = False
        
        print(f"✓ Loaded {len(image_files)} images from {image_dir}")

def on_image_selected(change):
    if change['new'] is None:
        return
    
    image_path = change['new']
    state['current_image'] = image_path
    
    # Load and display preview
    with open(image_path, 'rb') as f:
        preview_image.value = f.read()

def on_process(b):
    with output_area:
        clear_output(wait=True)
        
        if state['current_image'] is None:
            print("❌ No image selected")
            return
        
        image_path = state['current_image']
        
        try:
            # Process image
            result = create_visual_text_embedding(
                image_path,
                top_k=top_k_slider.value,
                min_similarity=min_sim_slider.value
            )
            
            state['last_result'] = result
            
            # Save to JSON
            output_filename = f"{image_path.stem}_visual_blend.json"
            output_path = output_dir / output_filename
            
            # Remove top_attributes from saved JSON (too verbose)
            save_data = {k: v for k, v in result.items() if k != 'top_attributes'}
            
            with open(output_path, 'w') as f:
                json.dump(save_data, f)
            
            print(f"\n✓ Saved to: {output_path}")
            print(f"  File size: {os.path.getsize(output_path) / 1024:.2f} KB")
            
            # Visualize
            print("\n" + "="*70)
            visualize_attributes(result['top_attributes'], image_path)
            
        except Exception as e:
            print(f"❌ Error: {e}")
            import traceback
            traceback.print_exc()

# Connect callbacks
load_button.on_click(on_load_images)
image_selector.observe(on_image_selected, names='value')
process_button.on_click(on_process)

# Display interface
print("="*70)
print("Visual Attribute Blending Interface")
print("="*70)

display(
    widgets.VBox([
        widgets.HTML("<h3>1. Load Images</h3>"),
        widgets.HBox([image_dir_input, load_button]),
        widgets.HTML("<h3>2. Select Image</h3>"),
        image_selector,
        preview_image,
        widgets.HTML("<h3>3. Configure Processing</h3>"),
        top_k_slider,
        min_sim_slider,
        widgets.HTML("<h3>4. Process</h3>"),
        process_button,
        output_area
    ])
)

## Usage Instructions

1. **Load Images**: Enter the path to your image folder and click "Load Images"
2. **Select Image**: Choose an image from the dropdown
3. **Configure Processing**:
   - **Top K attributes**: How many top-scoring attributes to blend (more = more comprehensive)
   - **Min similarity**: Minimum threshold for including an attribute (higher = more selective)
4. **Process**: Click "Process Image" to create the blended text embedding
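
How the two settings in step 3 interact can be sketched in isolation. The helper below mirrors the order used in `create_visual_text_embedding` (rank, truncate to the top K, then apply the threshold), but `select_attributes` is a hypothetical name and the scores are made up.

```python
import numpy as np

def select_attributes(similarities, top_k=15, min_similarity=0.15):
    """Return indices of the top_k scores that also meet min_similarity, best first."""
    similarities = np.asarray(similarities, dtype=float)
    order = np.argsort(similarities)[::-1][:top_k]       # rank, then truncate
    return order[similarities[order] >= min_similarity]  # then threshold

scores = [0.27, 0.12, 0.22, 0.18, 0.09]  # made-up similarity scores

print(select_attributes(scores, top_k=3, min_similarity=0.15))  # [0 2 3]
print(select_attributes(scores, top_k=3, min_similarity=0.20))  # [0 2]
print(select_attributes(scores, top_k=3, min_similarity=0.50))  # []
```

The last call shows that a threshold above every score leaves nothing to blend, so a lower **Min similarity** is needed in that case.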

The output will show:
- Processing steps and selected attributes
- Visualization showing similarity scores
- Saved JSON file path

## Output Format

The generated JSON files contain:
- `prompt`: Comma-separated list of the top 5 attributes
- `embedding`: Nested list of shape [77, 768], in the text-embedding format Flux expects
- `shape`: [77, 768]
- `method`: "visual_attribute_blending"
- `source_image`: Original image filename

The `top_attributes` scores are returned in memory but left out of the saved file.
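
A saved file can be checked before it is handed to another notebook. The sketch below uses only the standard library; `load_embedding` is a hypothetical helper, and the zero-filled file it reads is a stand-in written on the spot so the example runs without real outputs.

```python
import json
import tempfile
from pathlib import Path

def load_embedding(path, expected_shape=(77, 768)):
    """Load a saved embedding file and check it against its declared shape."""
    with open(path) as f:
        data = json.load(f)
    rows = data["embedding"]
    actual = (len(rows), len(rows[0]))
    if actual != tuple(data["shape"]) or actual != tuple(expected_shape):
        raise ValueError(f"Unexpected embedding shape: {actual}")
    return data

# Zero-filled stand-in that follows the output format described above
example = {
    "prompt": "warm colors, soft lighting",
    "embedding": [[0.0] * 768 for _ in range(77)],
    "shape": [77, 768],
    "method": "visual_attribute_blending",
    "source_image": "example.png",
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "example_visual_blend.json"
    path.write_text(json.dumps(example))
    loaded = load_embedding(path)

print(loaded["method"], loaded["shape"])  # visual_attribute_blending [77, 768]
```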

## How It Works

### The Process:

1. **Extract Image Features**: Get vision embedding from CLIP (used only for similarity)
2. **Calculate Similarities**: Compare image to each attribute text
3. **Select Top Attributes**: Keep only the highest-scoring attributes
4. **Generate Text Embeddings**: Get full [77, 768] embeddings for each selected attribute
5. **Weighted Blend**: Combine embeddings weighted by similarity scores
6. **Output**: Final [77, 768] text embedding that Flux can use
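
Step 5 divides each score by the total. When the selected scores sit close together, as CLIP cosine similarities tend to, the resulting weights are nearly uniform. A softmax with a temperature is one possible alternative if sharper weighting is wanted; it is not what the notebook does, and `softmax_weights`, its temperature value, and the scores below are assumptions for illustration.

```python
import numpy as np

def sum_normalized(scores):
    """Weights as in blend_embeddings: each score divided by the total."""
    scores = np.asarray(scores, dtype=float)
    return scores / scores.sum()

def softmax_weights(scores, temperature=0.02):
    """Hypothetical alternative: softmax over scores; lower temperature is sharper."""
    scores = np.asarray(scores, dtype=float)
    logits = (scores - scores.max()) / temperature  # subtract max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

scores = [0.26, 0.24, 0.22, 0.20]  # made-up scores in a narrow band

print(np.round(sum_normalized(scores), 3))   # [0.283 0.261 0.239 0.217]
print(np.round(softmax_weights(scores), 3))  # [0.644 0.237 0.087 0.032]
```

Both weightings sum to 1, so either can be passed to a weighted sum of the selected embeddings; the temperature would need tuning against generated images.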

### Key Insight:

Instead of trying to convert image embeddings to text embeddings (there is no direct conversion: a single pooled image vector does not determine a 77-token sequence), we:
- Use CLIP's vision-text alignment to measure visual properties
- Stay entirely in text embedding space for the final output
- Automatically extract visual features without manual description

### Customization:

- **Attribute Library**: Modify `ATTRIBUTE_LIBRARY` to focus on specific visual aspects
- **Top K**: Higher values = more attributes blended = more comprehensive, but weaker matches dilute the strongest ones
- **Min Similarity**: Higher threshold = more selective = stronger attributes only
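
Extending the library can be sketched as follows. The shortened dictionary stands in for the notebook's `ATTRIBUTE_LIBRARY`, the `medium` category is a hypothetical addition, and `flatten_unique` is an assumed helper that replaces the plain flattening loop so a phrase listed in two categories is not counted twice in the blend.

```python
# Stand-in for the notebook's ATTRIBUTE_LIBRARY, shortened for the example
attribute_library = {
    "colors": ["warm colors", "cool colors"],
    "lighting": ["soft lighting", "high contrast"],
}

# Hypothetical new category; "high contrast" deliberately repeats an existing phrase
attribute_library["medium"] = ["oil painting", "watercolor", "high contrast"]

def flatten_unique(library):
    """Flatten all categories into one list, dropping repeats but keeping order."""
    seen = set()
    flat = []
    for attrs in library.values():
        for attr in attrs:
            if attr not in seen:
                seen.add(attr)
                flat.append(attr)
    return flat

all_attributes = flatten_unique(attribute_library)
print(all_attributes)
# ['warm colors', 'cool colors', 'soft lighting', 'high contrast', 'oil painting', 'watercolor']
```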

### Comparison with Traditional Approaches:

- **CLIP Interrogation**: Generates a single text description → single embedding
- **This Method**: Blends multiple visual attributes → weighted combination
- **Direct Image Embedding**: Uses vision encoder → incompatible with Flux ❌

### Limitations:

- Won't capture specific objects or scenes ("elephant", "sunset")
- Focuses on visual attributes (colors, composition, mood)
- Results depend on the attribute library
- Best for style transfer and visual property extraction

## Batch Processing (Optional)

Process multiple images at once:

In [None]:
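def batch_process_images(image_dir, top_k=15, min_similarity=0.15):
    """
    Process all images in a directory and save one JSON file per image.

    Images that raise an error are reported and skipped.

    Returns: list of result dictionaries for the images that succeeded
    """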
    image_dir = Path(image_dir)
    image_files = get_image_files(image_dir)
    
    print(f"Found {len(image_files)} images in {image_dir}")
    print("="*70)
    
    results = []
    
    for i, image_path in enumerate(image_files, 1):
        print(f"\n[{i}/{len(image_files)}] Processing {image_path.name}...")
        
        try:
            result = create_visual_text_embedding(
                image_path,
                top_k=top_k,
                min_similarity=min_similarity
            )
            
            # Save
            output_filename = f"{image_path.stem}_visual_blend.json"
            output_path = output_dir / output_filename
            
            save_data = {k: v for k, v in result.items() if k != 'top_attributes'}
            
            with open(output_path, 'w') as f:
                json.dump(save_data, f)
            
            print(f"✓ Saved to: {output_path}")
            results.append(result)
            
        except Exception as e:
            print(f"❌ Error processing {image_path.name}: {e}")
    
    print("\n" + "="*70)
    print(f"✓ Batch processing complete! Processed {len(results)}/{len(image_files)} images")
    
    return results

# Uncomment to run batch processing:
# results = batch_process_images(
#     image_dir=current_dir.parent / "data" / "input_img",
#     top_k=15,
#     min_similarity=0.15
# )

## Test with Flux

To verify the embeddings work with Flux, you can load them in your `Image_generation.ipynb` notebook.

The embeddings are saved in the same format and location as your other CLIP embeddings, so they should load directly into your existing pipeline!