# Embedding Manipulation

Load and manipulate embeddings from CLIP or T5 models.

This notebook allows you to:
- Load saved embeddings (CLIP or T5)
- Apply transformations: **Scale**, **Invert**, **Zero Range**
- **Visualize before/after comparison** with heatmaps and statistics
- Save modified embeddings for use in image generation

**Zero Range** lets you use a range slider to select which token positions to keep; every position outside the selected range is zeroed out. The range is half-open: positions `start` through `end - 1` are kept. This is useful because:
- The **first token** (position 0) is typically a start token with less semantic meaning
- The **last token** often has special emphasis in the embedding
- **Padding tokens** (positions after real tokens) can significantly affect generation

```mermaid
flowchart LR
    subgraph Input
        T5[T5 Embedding<br/>JSON]
        CLIP[CLIP Embedding<br/>JSON]
    end

    subgraph Manipulations
        S[Scale]
        I[Invert]
        Z[Zero Range<br/>keep start:end]
    end

    subgraph Visualize
        VIZ[Before/After<br/>Comparison]
    end

    subgraph Output
        MOD[Modified<br/>Embedding JSON]
    end

    T5 --> S & I & Z
    CLIP --> S & I & Z
    S & I & Z --> VIZ
    VIZ --> MOD

    MOD -->|use in| GEN[Image Generation]
```

In [None]:
import numpy as np
import json
import os
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

# Embedding folders are resolved relative to the parent of the working directory
current_dir = Path.cwd()
CLIP_EMBEDDINGS_DIR, T5_EMBEDDINGS_DIR = (
    current_dir.parent / "data/embeddings" / model_folder
    for model_folder in ("CLIP", "T5")
)

# Shared notebook state, filled in by the widget callbacks.
# zero_range holds the applied (start, end) keep range for visualization.
current_embedding = current_prompt = current_model_type = None
modified_embedding = zero_range = None

print(
    "✓ Setup complete!",
    f"CLIP embeddings: {CLIP_EMBEDDINGS_DIR}",
    f"T5 embeddings: {T5_EMBEDDINGS_DIR}",
    sep="\n",
)

In [None]:
# Function definitions

def update_file_list(change):
    """Repopulate the file dropdown for the newly selected model type.

    Only ``change['new']`` is read: ``'CLIP'`` selects the CLIP folder and
    any other value selects the T5 folder. The dropdown is emptied when the
    chosen folder does not exist.
    """
    if change['new'] == 'CLIP':
        source_dir = CLIP_EMBEDDINGS_DIR
    else:
        source_dir = T5_EMBEDDINGS_DIR

    if not source_dir.exists():
        embedding_file_dropdown.options = []
        return

    # Recursive glob so JSON files inside subdirectories are listed too;
    # entries are shown relative to the model folder.
    relative_names = [
        str(path.relative_to(source_dir))
        for path in source_dir.glob('**/*.json')
    ]
    relative_names.sort()
    embedding_file_dropdown.options = relative_names

def load_embedding(b):
    """Load the selected embedding JSON file and reset the manipulation state.

    Button callback (``b`` is the clicked button and is not used). Reads the
    file chosen in the dropdowns, stores the array, prompt and model type in
    the module-level state, clears any previous modification, and resizes the
    keep-range slider to the number of rows in the embedding.

    All fallible work (file read, JSON parse, array statistics) is done
    before any global is assigned, so a failed load leaves the previously
    loaded embedding untouched. Errors are reported in ``load_output``
    rather than raised.
    """
    global current_embedding, current_prompt, current_model_type, modified_embedding, zero_range

    with load_output:
        load_output.clear_output()

        model_type = model_type_dropdown.value
        filename = embedding_file_dropdown.value

        if not filename:
            print("❌ No file selected!")
            return

        embedding_dir = CLIP_EMBEDDINGS_DIR if model_type == 'CLIP' else T5_EMBEDDINGS_DIR
        filepath = embedding_dir / filename

        try:
            # JSON is UTF-8 by specification; do not rely on the locale default
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            embedding = np.array(data['embedding'])
            prompt = data.get('prompt', 'Unknown')

            # Compute everything that can fail (0-d or empty arrays) up front
            num_tokens = embedding.shape[0]
            value_min = embedding.min()
            value_max = embedding.max()

            # Commit the new state only once the file is known to be usable
            current_embedding = embedding
            current_prompt = prompt
            current_model_type = model_type
            modified_embedding = None  # Reset modified embedding
            zero_range = None  # Reset zero range

            # Update the range slider max value based on embedding shape
            zero_range_slider.max = num_tokens
            zero_range_slider.value = (0, num_tokens)  # Default: keep all
            # Refresh the label explicitly: the slider observer only runs when
            # the value changes, and may have run while the old max was still
            # in place, which would leave a stale description behind.
            update_range_label({'new': zero_range_slider.value})

            print(f"✓ Loaded {model_type} embedding!")
            print(f"  File: {filename}")
            print(f"  Prompt: '{current_prompt}'")
            print(f"  Shape: {current_embedding.shape}")
            print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
            print(f"  Value range: [{value_min:.4f}, {value_max:.4f}]")

        except Exception as e:
            # Top-level UI boundary: report the problem instead of raising
            print(f"❌ Error loading embedding: {e}")

def update_params(change):
    """Show the parameter controls that belong to the selected operation.

    Only ``change['new']`` is read. An operation that is not listed below
    leaves the parameter box unchanged.
    """
    controls_by_operation = {
        'Scale': [scale_slider],
        'Invert': [],  # Invert has no parameters
        'Zero Range': [zero_range_slider, zero_range_label],
    }

    selected = change['new']
    if selected in controls_by_operation:
        params_box.children = controls_by_operation[selected]

def update_range_label(change):
    """Update the label describing which positions are zeroed and which are kept.

    ``change['new']`` is the slider's ``(start, end)`` pair, treated as the
    half-open keep range ``[start, end)``. Positions are compared against the
    slider's current ``max`` to decide what lies outside the range.
    """
    start, end = change['new']
    max_val = zero_range_slider.max

    zeroed_parts = []
    if start > 0:
        zeroed_parts.append(f"positions 0-{start-1}")
    if end < max_val:
        zeroed_parts.append(f"positions {end}-{max_val-1}")

    if not zeroed_parts:
        zero_range_label.value = "<b>Keeping all positions</b> (no zeroing)"
        return

    # The slider allows start == end, i.e. an empty keep range; say so instead
    # of printing a reversed span such as "positions 5-4".
    kept = f"positions {start}-{end-1}" if end > start else "none"
    zero_range_label.value = f"<b>Zeroing:</b> {', '.join(zeroed_parts)} | <b>Keeping:</b> {kept}"

def apply_manipulation(b):
    """Apply the selected operation to a copy of the loaded embedding.

    Button callback (``b`` is the clicked button and is not used). The result
    is stored in the global ``modified_embedding``; the loaded embedding is
    never mutated. ``zero_range`` is set to the applied ``(start, end)`` keep
    range for 'Zero Range' and cleared for 'Scale' and 'Invert'. Progress and
    summary information are printed to ``manipulation_output``.
    """
    global modified_embedding, zero_range

    with manipulation_output:
        manipulation_output.clear_output()

        if current_embedding is None:
            print("❌ No embedding loaded! Load an embedding first.")
            return

        manipulation = manipulation_dropdown.value

        # Create a copy
        modified = current_embedding.copy()

        if manipulation == 'Scale':
            factor = scale_slider.value
            modified = modified * factor
            zero_range = None  # Clear zero range for other manipulations
            print(f"✓ Scaled embedding by {factor}x")
            print(f"  Original range: [{current_embedding.min():.4f}, {current_embedding.max():.4f}]")
            print(f"  Modified range: [{modified.min():.4f}, {modified.max():.4f}]")

        elif manipulation == 'Invert':
            modified = -modified
            zero_range = None  # Clear zero range for other manipulations
            print(f"✓ Inverted embedding values")
            print(f"  Original range: [{current_embedding.min():.4f}, {current_embedding.max():.4f}]")
            print(f"  Modified range: [{modified.min():.4f}, {modified.max():.4f}]")

        elif manipulation == 'Zero Range':
            num_tokens = modified.shape[0]
            start, end = zero_range_slider.value
            # Clamp to the actual token count so the reported counts stay
            # correct even if the slider bounds are out of sync with the array.
            end = min(end, num_tokens)
            start = min(start, end)
            zero_range = (start, end)  # Store for visualization

            # Zero out positions OUTSIDE the half-open keep range [start, end).
            # Assigning to an empty slice is a no-op, so no guards are needed.
            modified[:start] = 0.0
            modified[end:] = 0.0

            num_kept = end - start
            num_zeroed = num_tokens - num_kept

            print(f"✓ Zeroed token positions outside range [{start}, {end})")
            if num_kept > 0:
                print(f"  Kept positions: {start} to {end-1} ({num_kept} tokens)")
            else:
                # start == end: nothing is kept, avoid a reversed "5 to 4" span
                print("  Kept positions: none (0 tokens)")
            print(f"  Zeroed positions: {num_zeroed} tokens")
            if start > 0:
                print(f"    - Start: positions 0 to {start-1}")
            if end < num_tokens:
                print(f"    - End: positions {end} to {num_tokens-1}")
            print(f"  Original non-zero: {np.count_nonzero(current_embedding):,}")
            print(f"  Modified non-zero: {np.count_nonzero(modified):,}")

        modified_embedding = modified
        print(f"\n  Shape: {modified_embedding.shape}")
        print(f"  Size: {modified_embedding.nbytes / 1024:.2f} KB")

def add_zero_range_lines(ax, zero_range, num_tokens):
    """Draw dashed boundary lines marking the keep range on a heatmap axis.

    Args:
        ax: Axes-like object providing ``axvline`` and ``legend``.
        zero_range: ``(start, end)`` keep range, or ``None`` to draw nothing.
        num_tokens: Number of token positions on the x axis.

    A line is drawn at ``start - 0.5`` when ``start > 0`` and at ``end - 0.5``
    when ``end < num_tokens``. The first line that is actually drawn carries
    the 'Keep range' label, and the legend is only added when at least one
    line exists, so no empty legend is requested.
    """
    if zero_range is None:
        return

    start, end = zero_range
    boundaries = []
    if start > 0:
        boundaries.append(start - 0.5)
    if end < num_tokens:
        boundaries.append(end - 0.5)

    if not boundaries:
        # Whole range kept: nothing to mark and nothing to put in a legend
        return

    for index, x_pos in enumerate(boundaries):
        # Label only one line so the legend shows a single entry
        label_kwargs = {'label': 'Keep range'} if index == 0 else {}
        ax.axvline(x=x_pos, color='lime', linewidth=2, linestyle='--', **label_kwargs)
    ax.legend(loc='upper right')

def _draw_embedding_heatmap(ax, data, title, vmin, vmax, fontsize, num_tokens):
    """Render one (tokens, dims) array as a heatmap panel with colorbar and keep-range lines."""
    image = ax.imshow(data.T, aspect='auto', cmap='RdBu_r', vmin=vmin, vmax=vmax)
    ax.set_title(title, fontweight='bold', fontsize=fontsize)
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Dimension')
    plt.colorbar(image, ax=ax, label='Value', shrink=0.8)
    add_zero_range_lines(ax, zero_range, num_tokens)


def visualize_comparison(b):
    """Visualize before/after comparison of embeddings.

    Button callback (``b`` is the clicked button and is not used). Draws
    heatmaps of the loaded and the modified embedding into
    ``visualization_output`` and prints a table of summary statistics.

    Embeddings with more than 1000 dimensions get a 2x2 grid (first 512
    dimensions plus all dimensions, for original and modified) with a
    narrower color range; smaller ones get two stacked panels. Only 2-D
    embeddings can be drawn; any other shape is reported and skipped.
    Cosine similarity is reported as undefined when either embedding has
    zero norm, instead of dividing by zero.
    """
    with visualization_output:
        visualization_output.clear_output()

        if current_embedding is None:
            print("❌ No embedding loaded! Load an embedding first.")
            return

        if modified_embedding is None:
            print("❌ No modified embedding! Apply a manipulation first.")
            return

        original = current_embedding
        modified = modified_embedding

        if original.ndim != 2:
            print(f"❌ Cannot visualize embedding with shape {original.shape}; expected 2-D (tokens, dimensions).")
            return

        num_tokens, num_dims = original.shape
        manipulation_name = manipulation_dropdown.value
        # The stored prompt may be null in the JSON file; keep len()/slicing safe
        prompt_text = str(current_prompt)

        # Check if this is a large embedding (like T5)
        is_large_embedding = num_dims > 1000

        if is_large_embedding:
            # T5: Use tighter bounds to make colors darker/more visible
            vmin, vmax = -0.3, 0.3
            truncate_dims = 512
            title_size = 10

            fig_height = max(12, min(25, num_dims / 200))
            fig, axes = plt.subplots(2, 2, figsize=(16, fig_height), constrained_layout=True)

            prompt_display = f'"{prompt_text[:40]}..."' if len(prompt_text) > 40 else f'"{prompt_text}"'
            # Inclusive index of the last dimension shown in the truncated panels
            last_dim = truncate_dims - 1

            # Row 1: original (truncated + full); row 2: modified (truncated + full)
            panels = [
                (axes[0, 0], original[:, :truncate_dims],
                 f'Original: {prompt_display}\n(dims 0-{last_dim})'),
                (axes[0, 1], original,
                 f'Original: {prompt_display}\n(all {num_dims} dims)'),
                (axes[1, 0], modified[:, :truncate_dims],
                 f'Modified ({manipulation_name})\n(dims 0-{last_dim})'),
                (axes[1, 1], modified,
                 f'Modified ({manipulation_name})\n(all {num_dims} dims)'),
            ]
        else:
            # CLIP: Standard visualization with 2 plots
            vmin, vmax = -1, 1
            title_size = 11
            fig, axes = plt.subplots(2, 1, figsize=(14, 10), constrained_layout=True)

            prompt_display = f'"{prompt_text[:50]}..."' if len(prompt_text) > 50 else f'"{prompt_text}"'

            panels = [
                (axes[0], original, f'Original Embedding: {prompt_display}'),
                (axes[1], modified, f'Modified Embedding ({manipulation_name})'),
            ]

        for panel_ax, panel_data, panel_title in panels:
            _draw_embedding_heatmap(panel_ax, panel_data, panel_title, vmin, vmax, title_size, num_tokens)

        plt.show()

        # Print statistics
        norm_original = np.linalg.norm(original)
        norm_modified = np.linalg.norm(modified)
        float_rows = [
            ('Min value', original.min(), modified.min()),
            ('Max value', original.max(), modified.max()),
            ('Mean', original.mean(), modified.mean()),
            ('Std deviation', original.std(), modified.std()),
            ('Total L2 norm', norm_original, norm_modified),
        ]
        nonzero_original = np.count_nonzero(original)
        nonzero_modified = np.count_nonzero(modified)

        print("\n" + "="*60)
        print("COMPARISON STATISTICS")
        print("="*60)
        print(f"\n{'Metric':<25} {'Original':<15} {'Modified':<15} {'Change':<15}")
        print("-"*70)
        for metric, before, after in float_rows:
            print(f"{metric:<25} {before:<15.4f} {after:<15.4f} {after - before:<+15.4f}")
        print(f"{'Non-zero values':<25} {nonzero_original:<15,} {nonzero_modified:<15,} {nonzero_modified - nonzero_original:<+15,}")

        # Cosine similarity between original and modified
        denominator = norm_original * norm_modified
        if denominator > 0:
            cos_sim = np.dot(original.flatten(), modified.flatten()) / denominator
            print(f"\n{'Cosine similarity':<25} {cos_sim:.6f}")
        else:
            # An all-zero embedding (e.g. an empty keep range) has no direction
            print(f"\n{'Cosine similarity':<25} undefined (zero-norm embedding)")
        print("="*60)

def save_modified(b):
    """Write the modified embedding to a JSON file next to the source embeddings.

    Button callback (``b`` is the clicked button and is not used). The file
    is written to the root of the CLIP or T5 embeddings folder (chosen from
    the model type recorded at load time) and is named after the stem of the
    file selected in the dropdown plus a suffix describing the operation
    selected in the operation dropdown. Problems are reported in
    ``save_output`` rather than raised.

    For 'Zero Range' the suffix uses the keep range that was actually applied
    (``zero_range``) when one is recorded, so the filename always agrees with
    the ``zero_range`` metadata stored inside the file; the slider value is
    only a fallback.
    """
    with save_output:
        save_output.clear_output()

        if modified_embedding is None:
            print("❌ No modified embedding to save! Apply a manipulation first.")
            return

        if current_model_type is None:
            print("❌ No model type detected!")
            return

        # Get original filename and add manipulation suffix
        original_filename = embedding_file_dropdown.value
        if not original_filename:
            # The dropdown is empty (e.g. after switching to a folder without
            # files); without a name there is nothing to derive the output from
            print("❌ No source file selected! Select the original embedding file first.")
            return

        # Determine save directory
        save_dir = CLIP_EMBEDDINGS_DIR if current_model_type == 'CLIP' else T5_EMBEDDINGS_DIR
        save_dir.mkdir(parents=True, exist_ok=True)

        # Handle subdirectory paths - get just the filename part
        base_name = Path(original_filename).stem

        # Create suffix based on manipulation type
        manipulation = manipulation_dropdown.value
        if manipulation == 'Scale':
            # Round away float noise from the slider (step is 0.1) so names
            # like "_scaled_0.30000000000000004x" cannot occur
            suffix = f"_scaled_{round(scale_slider.value, 2)}x"
        elif manipulation == 'Invert':
            suffix = "_inverted"
        elif manipulation == 'Zero Range':
            if zero_range is not None:
                start, end = zero_range
            else:
                start, end = zero_range_slider.value
            suffix = f"_keep_{start}_to_{end}"
        else:
            print(f"❌ Unknown manipulation: {manipulation}")
            return

        new_filename = f"{base_name}{suffix}.json"
        filepath = save_dir / new_filename

        # Save embedding
        embedding_data = {
            "prompt": current_prompt,
            "embedding": modified_embedding.tolist(),
            "shape": list(modified_embedding.shape),
            "manipulation": manipulation,
            "original_file": original_filename
        }

        # Add zero range info if applicable
        if manipulation == 'Zero Range' and zero_range is not None:
            embedding_data["zero_range"] = {"keep_start": zero_range[0], "keep_end": zero_range[1]}

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(embedding_data, f)

        print(f"✓ Modified embedding saved!")
        print(f"  Directory: {save_dir}")
        print(f"  Filename: {new_filename}")
        print(f"  Size: {os.path.getsize(filepath) / 1024:.2f} KB")

# Confirm that this cell ran and the callback functions are defined
print("✓ Functions loaded!")

In [None]:
# Build every widget used by the interface and hook up its callback

# --- Load section ---
model_type_dropdown = widgets.Dropdown(
    description='Model:',
    options=['CLIP', 'T5'],
    value='CLIP',
    style=dict(description_width='initial'),
)

embedding_file_dropdown = widgets.Dropdown(
    description='File:',
    options=[],
    layout=widgets.Layout(width='500px'),
    style=dict(description_width='initial'),
)

load_button = widgets.Button(button_style='success', description='Load Embedding')
load_output = widgets.Output()

# --- Manipulation section ---
manipulation_dropdown = widgets.Dropdown(
    description='Operation:',
    options=['Scale', 'Invert', 'Zero Range'],
    value='Scale',
    style=dict(description_width='initial'),
)

scale_slider = widgets.FloatSlider(
    description='Scale Factor:',
    value=2.0,
    min=0.1,
    max=5.0,
    step=0.1,
    style=dict(description_width='initial'),
)

# Range slider selecting the token positions to KEEP.
# Starts at 0-77 (the CLIP token count); resized when an embedding is loaded.
zero_range_slider = widgets.IntRangeSlider(
    description='Keep Range:',
    value=(0, 77),
    min=0,
    max=77,
    step=1,
    continuous_update=True,
    layout=widgets.Layout(width='500px'),
    style=dict(description_width='initial'),
)

# Text summary of what the current slider position will zero out
zero_range_label = widgets.HTML(
    layout=widgets.Layout(margin='5px 0'),
    value="<b>Keeping all positions</b> (no zeroing)",
)

# Holds the parameter controls for the selected operation ('Scale' initially)
params_box = widgets.VBox([scale_slider])

apply_button = widgets.Button(button_style='warning', description='Apply Manipulation')
manipulation_output = widgets.Output()

# --- Visualization section ---
visualize_button = widgets.Button(
    icon='bar-chart',
    button_style='info',
    description='Visualize Before/After',
)
visualization_output = widgets.Output()

# --- Save section ---
save_button = widgets.Button(button_style='primary', description='Save Modified Embedding')
save_output = widgets.Output()

# --- Callback wiring ---
# Value observers
model_type_dropdown.observe(update_file_list, names='value')
manipulation_dropdown.observe(update_params, names='value')
zero_range_slider.observe(update_range_label, names='value')
# Button clicks
load_button.on_click(load_embedding)
apply_button.on_click(apply_manipulation)
visualize_button.on_click(visualize_comparison)
save_button.on_click(save_modified)

# Populate the file list for the initially selected model type
update_file_list(dict(new=model_type_dropdown.value))

print("✓ Widgets created!")

In [None]:
# Assemble the four numbered sections and render the interface

# 1. Loading
load_section = widgets.VBox(children=[
    widgets.HTML(value="<h3>1. Load Embedding</h3>"),
    model_type_dropdown,
    embedding_file_dropdown,
    load_button,
    load_output,
])

# 2. Manipulation
manipulation_section = widgets.VBox(children=[
    widgets.HTML(value="<h3>2. Apply Manipulation</h3>"),
    manipulation_dropdown,
    params_box,
    apply_button,
    manipulation_output,
])

# 3. Visualization (with a short explanatory subtitle)
visualization_section = widgets.VBox(children=[
    widgets.HTML(value="<h3>3. Visualize Changes</h3>"),
    widgets.HTML(value="<p style='color: #666; margin-top: -10px;'>Compare original vs modified embedding weights</p>"),
    visualize_button,
    visualization_output,
])

# 4. Saving
save_section = widgets.VBox(children=[
    widgets.HTML(value="<h3>4. Save Modified Embedding</h3>"),
    save_button,
    save_output,
])

# Stack the sections, separated by horizontal rules
interface = widgets.VBox(children=[
    load_section,
    widgets.HTML(value="<hr>"),
    manipulation_section,
    widgets.HTML(value="<hr>"),
    visualization_section,
    widgets.HTML(value="<hr>"),
    save_section,
])

display(interface)