# T5 Embedding Manipulation for FLUX

This notebook lets you:
1. Generate T5 embeddings from text prompts
2. Save/load embeddings as JSON
3. Manually edit embedding values
4. Apply custom attention masks
5. Generate images with FLUX using modified embeddings

# 🎨 FLUX Dual Embedding Manipulation

This notebook now supports manipulating **BOTH** types of embeddings:

## 1. **T5 Embeddings** (512 x 4096)
- Detailed token-by-token representation
- Controls WHAT objects appear and WHERE
- Used in cross-attention

## 2. **CLIP Pooled Embeddings** (768)
- Global semantic summary
- Controls HOW things look (style, mood, aesthetic)
- Used in timestep conditioning

## Workflow:
1. Generate BOTH embeddings from a prompt
2. Save BOTH to JSON (they're now stored together)
3. Manipulate T5 embeddings (zero values, scale, etc.)
4. Manipulate CLIP embeddings (scale, add noise, etc.)
5. Generate images with different combinations:
   - Original T5 + Original CLIP
   - Modified T5 + Original CLIP
   - Original T5 + Modified CLIP  
   - Modified T5 + Modified CLIP

This lets you see the different effects of each embedding type!
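The four pairings listed in the workflow can be enumerated programmatically. The following is a minimal sketch, assuming string labels stand in for the actual arrays; `output_name` and its file naming scheme are hypothetical and not part of this notebook.

```python
# Labels standing in for the real T5 and CLIP arrays (illustrative only).
t5_variants = ["original", "modified"]
clip_variants = ["original", "modified"]

# Build the pairs in the same order as the workflow list:
# T5 varies fastest, CLIP varies slowest.
combinations = [(t5, clip) for clip in clip_variants for t5 in t5_variants]

def output_name(t5_label, clip_label):
    """Hypothetical file name for one T5/CLIP pairing."""
    return f"t5-{t5_label}_clip-{clip_label}.png"

filenames = [output_name(t5, clip) for t5, clip in combinations]
for name in filenames:
    print(name)
```

Generating all four images with the same seed keeps the comparison between pairings fair.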

## Installation

Run this cell first to import the required packages and set the model paths (the packages must already be installed in your environment):

In [1]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from PIL import Image
import os
from pathlib import Path

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Create models directory
current_dir = Path.cwd()
MODELS_DIR = current_dir.parent / "data/models"
T5_MODEL_PATH = os.path.join(MODELS_DIR, "t5-v1_1-xxl")
FLUX_MODEL_PATH = os.path.join(MODELS_DIR, "FLUX.1-schnell")

os.makedirs(MODELS_DIR, exist_ok=True)
print(f"Models directory: {os.path.abspath(MODELS_DIR)}")
print(f"T5 path: {os.path.abspath(T5_MODEL_PATH)}")
print(f"FLUX path: {os.path.abspath(FLUX_MODEL_PATH)}")

Using device: cuda
Models directory: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models
T5 path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/t5-v1_1-xxl
FLUX path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/FLUX.1-schnell


In [2]:
# Load Hugging Face token from file
from pathlib import Path

# Get the token file path
current_dir = Path.cwd()
token_file = current_dir.parent / "misc/credentials/hf.txt"

print(f"Looking for HF token at: {token_file}")

if token_file.exists():
    with open(token_file, 'r') as f:
        hf_token = f.read().strip()
    
    # Set the token as an environment variable
    os.environ['HF_TOKEN'] = hf_token
    
    # Also login using huggingface_hub
    from huggingface_hub import login
    login(token=hf_token)
    
    print("✓ Hugging Face token loaded and authenticated successfully!")
else:
    print(f"❌ Token file not found at: {token_file}")
    print("Please create the file or update the path.")

Looking for HF token at: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/misc/credentials/hf.txt


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Hugging Face token loaded and authenticated successfully!


In [3]:
MODELS_DIR = current_dir.parent / "data/models"

In [4]:
# Load FLUX
try:
    if not os.path.exists(FLUX_MODEL_PATH):
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16
        )
        flux_pipe.save_pretrained(FLUX_MODEL_PATH)
    else:
        flux_pipe = FluxPipeline.from_pretrained(
            FLUX_MODEL_PATH,
            torch_dtype=torch.bfloat16,
            local_files_only=True
        )
    
    flux_pipe = flux_pipe.to(device)
    
except Exception as e:
    print(f"Error loading FLUX: {e}")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


## 1. Load T5 Model (T5-XXL for FLUX)

We'll use `google/t5-v1_1-xxl`, the T5 encoder that FLUX uses. It produces one 4096-dimensional embedding per token.

**Note:** This is a large model (~11GB download). Make sure you have enough disk space and RAM.

In [5]:
# Load T5-XXL from local folder
print(f"Loading T5 model from: {T5_MODEL_PATH}...")

if not os.path.exists(T5_MODEL_PATH):
    print("\n⚠️  Model not found locally. Downloading from Hugging Face...")
    print("This is a large model (~11GB) and will take several minutes.")
    print("Please be patient...\n")
    
    # Download and save to local folder
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    t5_model = T5EncoderModel.from_pretrained(
        "google/t5-v1_1-xxl",
        torch_dtype=torch.bfloat16  # Use bfloat16 to match FLUX
    )
    
    # Save to local folder
    print(f"Saving model to {T5_MODEL_PATH}...")
    tokenizer.save_pretrained(T5_MODEL_PATH)
    t5_model.save_pretrained(T5_MODEL_PATH)
    print("✓ Model downloaded and saved locally!\n")
else:
    print("✓ Loading from local folder...\n")

# Load from local folder
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_PATH, local_files_only=True)
t5_model = T5EncoderModel.from_pretrained(
    T5_MODEL_PATH,
    torch_dtype=torch.bfloat16,  # Use bfloat16 to match FLUX
    local_files_only=True
).to(device)

t5_model.eval()  # Set to evaluation mode

print(f"✓ T5-XXL loaded successfully!")
print(f"  Embedding dimension: {t5_model.config.d_model}")
print(f"  Max sequence length: {tokenizer.model_max_length}")
print(f"  Loaded from: {T5_MODEL_PATH}")
print(f"  Model dtype: {next(t5_model.parameters()).dtype}")

Loading T5 model from: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/t5-v1_1-xxl...
✓ Loading from local folder...



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ T5-XXL loaded successfully!
  Embedding dimension: 4096
  Max sequence length: 512
  Loaded from: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/t5-v1_1-xxl
  Model dtype: torch.bfloat16


## 2. Text Input Widget and Embedding Generation

In [7]:
# Create text input widget
prompt_input = widgets.Textarea(
    value='a red cat sitting on a blue table',
    placeholder='Enter your prompt here',
    description='Prompt:',
    layout=widgets.Layout(width='80%', height='80px')
)

generate_button = widgets.Button(
    description='Generate Embedding',
    button_style='success'
)

output_area = widgets.Output()

# Global variable to store current embedding
current_embedding = None
current_tokens = None

def generate_embedding(b):
    global current_embedding, current_tokens
    
    with output_area:
        output_area.clear_output()
        
        prompt = prompt_input.value
        print(f"Generating embedding for: '{prompt}'\n")
        
        # Tokenize
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt"
        )
        
        # Get token strings for display
        token_ids = tokens['input_ids'][0].tolist()
        token_strings = [tokenizer.decode([tid]) for tid in token_ids]
        
        # Find how many real tokens (non-padding)
        num_real_tokens = (tokens['input_ids'][0] != tokenizer.pad_token_id).sum().item()
        
        print(f"Tokenized into {num_real_tokens} real tokens (+ {512 - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()
        
        # Generate embedding
        with torch.no_grad():
            tokens = {k: v.to(device) for k, v in tokens.items()}
            outputs = t5_model(**tokens)
            embedding = outputs.last_hidden_state  # Shape: [1, 512, embedding_dim]
        
        current_embedding = embedding.float().cpu().numpy()[0]  # Shape: [512, embedding_dim]
        current_tokens = token_strings
        
        embedding_dim = current_embedding.shape[1]
        total_numbers = current_embedding.shape[0] * current_embedding.shape[1]
        
        print(f"✓ Embedding generated!")
        print(f"  Shape: {current_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_embedding[0, :10])

generate_button.on_click(generate_embedding)

display(prompt_input, generate_button, output_area)

Textarea(value='a red cat sitting on a blue table', description='Prompt:', layout=Layout(height='80px', width=…

Button(button_style='success', description='Generate Embedding', style=ButtonStyle())

Output()
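The size figures printed by the cell above follow directly from the embedding shape. As a rough check, assuming the padded length of 512 tokens and the 4096-dimensional T5-XXL hidden size used in this notebook, the arithmetic can be sketched as follows (`embedding_footprint` is a hypothetical helper name):

```python
SEQ_LEN = 512     # padded token count (max_length in the tokenizer call)
EMBED_DIM = 4096  # T5-XXL hidden size (t5_model.config.d_model)

def embedding_footprint(seq_len=SEQ_LEN, embed_dim=EMBED_DIM):
    """Return the number of values and the in-memory size of one embedding."""
    n_values = seq_len * embed_dim
    return {
        "values": n_values,
        "float32_mb": n_values * 4 / 1024**2,   # 4 bytes per float32 value
        "two_byte_mb": n_values * 2 / 1024**2,  # bfloat16 and float16 use 2 bytes
    }

footprint = embedding_footprint()
print(footprint)
```

The JSON file written later is considerably larger than the float32 figure, because every value is stored as decimal text.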

## 3. Save Embedding to JSON

In [8]:
def save_embedding(filename="embedding.json"):
    if current_embedding is None:
        print("❌ No embedding to save! Generate one first.")
        return
    
    data = {
        "embedding": current_embedding.tolist(),
        "tokens": current_tokens,
        "shape": list(current_embedding.shape),
        "prompt": prompt_input.value
    }
    
    with open(filename, 'w') as f:
        json.dump(data, f)
    
    file_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"✓ Embedding saved to '{filename}' ({file_size:.2f} MB)")

# Save button
save_button = widgets.Button(description='Save Embedding', button_style='info')
save_output = widgets.Output()

def on_save_click(b):
    with save_output:
        save_output.clear_output()
        save_embedding()

save_button.on_click(on_save_click)
display(save_button, save_output)

Button(button_style='info', description='Save Embedding', style=ButtonStyle())

Output()
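The overview mentions storing the T5 and CLIP embeddings together, while the save function above writes only the T5 array. A minimal sketch of a combined round trip might look like this; the `clip_pooled` key and the function names `save_embeddings` and `load_embeddings` are assumptions for illustration, not part of the notebook's file format.

```python
import json
import os
import tempfile

import numpy as np

def save_embeddings(path, t5_embedding, prompt, clip_pooled=None):
    """Write the T5 embedding (and optionally a pooled CLIP vector) to JSON."""
    data = {
        "embedding": t5_embedding.tolist(),
        "shape": list(t5_embedding.shape),
        "prompt": prompt,
    }
    if clip_pooled is not None:
        data["clip_pooled"] = clip_pooled.tolist()  # hypothetical key name
    with open(path, "w") as f:
        json.dump(data, f)

def load_embeddings(path):
    """Read the file back and verify the stored shape."""
    with open(path) as f:
        data = json.load(f)
    t5 = np.array(data["embedding"], dtype=np.float32)
    if list(t5.shape) != data["shape"]:
        raise ValueError(f"shape mismatch: {t5.shape} vs {data['shape']}")
    pooled = data.get("clip_pooled")
    pooled = None if pooled is None else np.array(pooled, dtype=np.float32)
    return t5, pooled, data["prompt"]

# Small random arrays keep the demo fast; real shapes are (512, 4096) and (768,).
rng = np.random.default_rng(0)
demo_t5 = rng.standard_normal((4, 8)).astype(np.float32)
demo_pooled = rng.standard_normal(8).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "embeddings.json")
    save_embeddings(demo_path, demo_t5, "a red cat", demo_pooled)
    loaded_t5, loaded_pooled, loaded_prompt = load_embeddings(demo_path)

print(np.array_equal(loaded_t5, demo_t5))
```

Loading with an explicit `dtype=np.float32` avoids the float64 arrays that `np.array` would otherwise create from Python floats.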

## 4. Load Embedding from JSON

In [9]:
def load_embedding(filename="embedding.json"):
    global current_embedding, current_tokens
    
    with open(filename, 'r') as f:
        data = json.load(f)
    
    current_embedding = np.array(data['embedding'])
    current_tokens = data['tokens']
    
    print(f"✓ Embedding loaded from '{filename}'")
    print(f"  Original prompt: {data['prompt']}")
    print(f"  Shape: {current_embedding.shape}")
    print(f"  First token: '{current_tokens[0]}'")

# Load button
load_button = widgets.Button(description='Load Embedding', button_style='warning')
load_output = widgets.Output()

def on_load_click(b):
    with load_output:
        load_output.clear_output()
        try:
            load_embedding()
        except FileNotFoundError:
            print("❌ File 'embedding.json' not found. Save an embedding first.")

load_button.on_click(on_load_click)
display(load_button, load_output)



Output()

## 5. Manual Embedding Manipulation

Example: randomly zero out a percentage of the embedding values.

In [10]:
def manipulate_embedding(zero_percentage=0.3):
    global current_embedding
    
    if current_embedding is None:
        print("❌ No embedding loaded!")
        return
    
    # Create a copy
    modified = current_embedding.copy()
    
    # Randomly zero out a percentage of values
    total_values = modified.size
    num_zeros = int(total_values * zero_percentage)
    
    # Random indices
    flat_modified = modified.flatten()
    zero_indices = np.random.choice(total_values, num_zeros, replace=False)
    flat_modified[zero_indices] = 0.0
    
    modified = flat_modified.reshape(current_embedding.shape)
    
    print(f"✓ Zeroed out {num_zeros:,} values ({zero_percentage:.0%})")
    print(f"  Original non-zero: {np.count_nonzero(current_embedding):,}")
    print(f"  Modified non-zero: {np.count_nonzero(modified):,}")
    
    return modified

# Manipulation widget
zero_slider = widgets.FloatSlider(
    value=0.3,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Zero %:',
    readout_format='.0%'
)

manipulate_button = widgets.Button(description='Apply Manipulation', button_style='danger')
manipulate_output = widgets.Output()

modified_embedding = None

def on_manipulate_click(b):
    global modified_embedding
    with manipulate_output:
        manipulate_output.clear_output()
        modified_embedding = manipulate_embedding(zero_slider.value)

manipulate_button.on_click(on_manipulate_click)
display(zero_slider, manipulate_button, manipulate_output)

FloatSlider(value=0.3, description='Zero %:', max=1.0, readout_format='.0%', step=0.05)

Button(button_style='danger', description='Apply Manipulation', style=ButtonStyle())

Output()
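Because the manipulation above draws random indices without a fixed seed, each click zeroes a different set of values. A reproducible variant could be sketched as follows, assuming NumPy's `default_rng`; the function name `zero_random_values` and its `seed` parameter are hypothetical additions.

```python
import numpy as np

def zero_random_values(embedding, zero_fraction=0.3, seed=0):
    """Return a copy with a fixed, seed-determined fraction of values set to 0."""
    rng = np.random.default_rng(seed)
    modified = embedding.copy()
    num_zeros = int(modified.size * zero_fraction)
    # Sample flat positions without replacement so exactly num_zeros are hit.
    zero_indices = rng.choice(modified.size, size=num_zeros, replace=False)
    modified.flat[zero_indices] = 0.0
    return modified

# Tiny all-ones demo array so the zero count is easy to verify.
demo = np.ones((4, 8), dtype=np.float32)
zeroed_a = zero_random_values(demo, zero_fraction=0.25, seed=1)
zeroed_b = zero_random_values(demo, zero_fraction=0.25, seed=1)
print(int((zeroed_a == 0).sum()), np.array_equal(zeroed_a, zeroed_b))
```

Fixing the seed makes it possible to attribute differences between two generated images to the zero fraction rather than to which positions happened to be chosen.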

In [11]:
# Save modified embedding to JSON
def save_modified_embedding(filename="modified_embedding.json"):
    if modified_embedding is None:
        print("❌ No modified embedding to save! Apply a manipulation first.")
        return
    
    data = {
        "embedding": modified_embedding.tolist(),
        "tokens": current_tokens,
        "shape": list(modified_embedding.shape),
        "prompt": prompt_input.value
    }
    
    with open(filename, 'w') as f:
        json.dump(data, f)
    
    file_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"✓ Modified embedding saved to '{filename}' ({file_size:.2f} MB)")

# Save modified embedding button
save_modified_button = widgets.Button(description='Save Modified Embedding', button_style='info')
save_modified_output = widgets.Output()

def on_save_modified_click(b):
    with save_modified_output:
        save_modified_output.clear_output()
        save_modified_embedding()

save_modified_button.on_click(on_save_modified_click)
display(save_modified_button, save_modified_output)

Button(button_style='info', description='Save Modified Embedding', style=ButtonStyle())

Output()

In [12]:
# Load modified embedding from JSON
def load_modified_embedding(filename="modified_embedding.json"):
    global modified_embedding, current_tokens
    
    with open(filename, 'r') as f:
        data = json.load(f)
    
    modified_embedding = np.array(data['embedding'])
    current_tokens = data['tokens']
    
    print(f"✓ Modified embedding loaded from '{filename}'")
    print(f"  Original prompt: {data['prompt']}")
    print(f"  Shape: {modified_embedding.shape}")
    print(f"  Non-zero values: {np.count_nonzero(modified_embedding):,}")

# Load modified embedding button
load_modified_button = widgets.Button(description='Load Modified Embedding', button_style='warning')
load_modified_output = widgets.Output()

def on_load_modified_click(b):
    with load_modified_output:
        load_modified_output.clear_output()
        try:
            load_modified_embedding()
        except FileNotFoundError:
            print("❌ File 'modified_embedding.json' not found.")
        except Exception as e:
            print(f"❌ Error loading: {e}")

load_modified_button.on_click(on_load_modified_click)
display(load_modified_button, load_modified_output)



Output()

## 6. Manual Attention Masking

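Control how much attention weight each token receives. The mask built here is not yet passed to FLUX; injecting it would require modifying the cross-attention layers.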

In [13]:
# Display tokens and create sliders
def create_attention_sliders(num_tokens=10):
    if current_tokens is None:
        print("❌ Generate an embedding first!")
        return None
    
    # Find non-padding tokens
    real_tokens = []
    for i, token in enumerate(current_tokens):
        if token.strip() and token != '<pad>' and i < num_tokens:
            real_tokens.append((i, token))
    
    print(f"Creating sliders for first {len(real_tokens)} tokens:")
    print(real_tokens)
    print()
    
    sliders = []
    for idx, token in real_tokens:
        slider = widgets.FloatSlider(
            value=1.0 / len(real_tokens),  # Equal distribution
            min=0.0,
            max=1.0,
            step=0.01,
            description=f'{idx}: {token[:15]}',
            readout_format='.0%',
            layout=widgets.Layout(width='400px')
        )
        sliders.append(slider)
    
    # Normalize button
    normalize_button = widgets.Button(
        description='Normalize to 100%',
        button_style='info'
    )
    
    total_label = widgets.Label(value='Total: 100%')
    
    def update_total(*args):
        total = sum(s.value for s in sliders)
        total_label.value = f'Total: {total*100:.1f}%'
        if abs(total - 1.0) > 0.01:
            total_label.value += ' ⚠️ Should sum to 100%'
    
    def normalize(*args):
        total = sum(s.value for s in sliders)
        if total > 0:
            for s in sliders:
                s.value = s.value / total
        update_total()
    
    for slider in sliders:
        slider.observe(update_total, 'value')
    
    normalize_button.on_click(normalize)
    
    update_total()
    
    return sliders, normalize_button, total_label, real_tokens

attention_controls = create_attention_sliders(num_tokens=10)

if attention_controls:
    sliders, normalize_btn, total_lbl, real_tokens = attention_controls
    display(widgets.VBox(sliders + [normalize_btn, total_lbl]))

Creating sliders for first 8 tokens:
[(1, 'a'), (2, 'red'), (3, 'cat'), (4, 'sitting'), (5, 'on'), (7, 'a'), (8, 'blue'), (9, 'table')]



VBox(children=(FloatSlider(value=0.125, description='1: a', layout=Layout(width='400px'), max=1.0, readout_for…

## 7. Create Attention Mask Array

In [14]:
def create_attention_mask():
    if attention_controls is None:
        print("❌ Create attention sliders first!")
        return None
    
    sliders, _, _, real_tokens = attention_controls
    
    # Create mask array (512 tokens)
    mask = np.zeros(512)
    
    # Set weights from sliders
    for slider, (idx, token) in zip(sliders, real_tokens):
        mask[idx] = slider.value
    
    print("Attention Mask created:")
    print(f"Non-zero weights: {np.count_nonzero(mask)}")
    for slider, (idx, token) in zip(sliders, real_tokens):
        print(f"  Token {idx} '{token}': {slider.value*100:.1f}%")
    
    return mask

mask_button = widgets.Button(description='Create Mask', button_style='success')
mask_output = widgets.Output()

attention_mask = None

def on_mask_click(b):
    global attention_mask
    with mask_output:
        mask_output.clear_output()
        attention_mask = create_attention_mask()

mask_button.on_click(on_mask_click)
display(mask_button, mask_output)

Button(button_style='success', description='Create Mask', style=ButtonStyle())

Output()
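The mask array is not consumed by the generation code in this notebook. One crude way to see an effect, assuming row-wise scaling is an acceptable stand-in for true attention masking (it is not the same operation), is to multiply each token's embedding row by its weight. `apply_token_weights` and its `rescale` flag are hypothetical names used only for this sketch.

```python
import numpy as np

def apply_token_weights(embedding, mask, rescale=True):
    """Scale each token row of `embedding` by the matching entry of `mask`.

    This only approximates "less attention" by shrinking a token's vector;
    it does not touch the model's cross-attention computation.
    """
    mask = np.asarray(mask, dtype=np.float32)
    if mask.shape[0] != embedding.shape[0]:
        raise ValueError("mask length must equal the number of token rows")
    weights = mask.copy()
    if rescale and weights.max() > 0:
        # Keep the strongest token at its original magnitude.
        weights = weights / weights.max()
    return embedding * weights[:, None]

# Four tokens, three dimensions; the last two tokens act like padding.
demo_embedding = np.ones((4, 3), dtype=np.float32)
demo_mask = np.array([0.5, 0.25, 0.0, 0.0])
weighted = apply_token_weights(demo_embedding, demo_mask)
print(weighted[:, 0])
```

Since the mask from the sliders is zero for every token without a slider, this approach also zeroes all padding rows, which by itself may change the generated image.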

## Generate Image from embedding

In [15]:
def generate_image_from_embedding(embedding_array, output_filename="generated_from_embedding.png", seed=42):
    """
    Generate image using custom T5 embedding by injecting it into FLUX pipeline.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None

    if embedding_array is None:
        print("❌ No embedding provided! Load an embedding first.")
        return None

    print(f"Generating image from custom embedding...")
    print(f"  Embedding shape: {embedding_array.shape}")
    print(f"  Expected shape: [512, 4096]")

    # Convert numpy to torch tensor with bfloat16 to match FLUX
    # First ensure numpy array is float32, then convert to bfloat16
    embedding_tensor = torch.from_numpy(embedding_array.astype(np.float32)).to(
        device=device,
        dtype=torch.bfloat16
    )

    # Add batch dimension: [512, 4096] -> [1, 512, 4096]
    embedding_tensor = embedding_tensor.unsqueeze(0)

    print(f"  Tensor shape: {embedding_tensor.shape}")
    print(f"  Device: {embedding_tensor.device}")
    print(f"  Dtype: {embedding_tensor.dtype}")

    # Generate pooled embeddings using CLIP text encoder
    print("\nGenerating pooled embeddings from CLIP...")
    try:
        # Use the original prompt to get pooled embeddings from CLIP
        prompt = prompt_input.value

        # Use FLUX's built-in encode_prompt method to get the correct pooled embeddings
        # This ensures we get exactly what FLUX expects
        with torch.no_grad():
            (
                _,  # We already have T5 embeddings
                pooled_embeds,  # This is what we need from CLIP
                _,  # text_ids
            ) = flux_pipe.encode_prompt(
                prompt=prompt,
                prompt_2=None,
                device=device,
                num_images_per_prompt=1,
                prompt_embeds=None,
                pooled_prompt_embeds=None,
                max_sequence_length=512,
            )

        print(f"  Pooled embeddings shape: {pooled_embeds.shape}")
        print(f"  Pooled embeddings dtype: {pooled_embeds.dtype}")

    except Exception as e:
        print(f"❌ Error generating pooled embeddings: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Generate image using both embeddings
    try:
        print("\nRunning diffusion process (4 steps)...")
        image = flux_pipe(
            prompt_embeds=embedding_tensor,
            pooled_prompt_embeds=pooled_embeds,
            num_inference_steps=4,
            guidance_scale=0.0,
            height=1024,
            width=1024,
            generator=torch.manual_seed(seed)
        ).images[0]

        image.save(output_filename)
        print(f"\n✓ Image generated and saved to '{output_filename}'")

        return image

    except Exception as e:
        print(f"❌ Error generating image: {e}")
        print("\nThis might happen if:")
        print("  - Embedding dimensions don't match")
        print("  - Insufficient VRAM/RAM")
        print("  - dtype mismatch (check that tensors are bfloat16)")
        import traceback
        traceback.print_exc()
        return None

# Generate from current embedding button
generate_from_embedding_button = widgets.Button(
    description='Generate from Current Embedding',
    button_style='primary',
    layout=widgets.Layout(width='300px')
)

# Generate from modified embedding button
generate_from_modified_button = widgets.Button(
    description='Generate from Modified Embedding',
    button_style='warning',
    layout=widgets.Layout(width='300px')
)

generation_output = widgets.Output()

def on_generate_from_embedding(b):
    with generation_output:
        generation_output.clear_output(wait=True)
        if current_embedding is not None:
            image = generate_image_from_embedding(
                current_embedding,
                output_filename="image_original_embedding.png"
            )
            if image:
                display(image)
        else:
            print("❌ No embedding loaded! Generate or load an embedding first.")

def on_generate_from_modified(b):
    with generation_output:
        generation_output.clear_output(wait=True)
        if modified_embedding is not None:
            image = generate_image_from_embedding(
                modified_embedding,
                output_filename="image_modified_embedding.png"
            )
            if image:
                display(image)
        else:
            print("❌ No modified embedding! Apply a manipulation first or load one.")

generate_from_embedding_button.on_click(on_generate_from_embedding)
generate_from_modified_button.on_click(on_generate_from_modified)

display(widgets.VBox([
    widgets.HTML("<h3>Generate Images from Embeddings</h3>"),
    generate_from_embedding_button,
    generate_from_modified_button
]), generation_output)

VBox(children=(HTML(value='<h3>Generate Images from Embeddings</h3>'), Button(button_style='primary', descript…

Output()
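Shape or value problems in a hand-edited embedding only surface once the pipeline call fails. A lightweight pre-check, sketched here under the assumption that the expected shape is the (512, 4096) array this notebook produces, could catch them earlier. `check_embedding` is a hypothetical helper, not part of diffusers.

```python
import numpy as np

EXPECTED_SHAPE = (512, 4096)  # tokens x T5-XXL hidden size, as used in this notebook

def check_embedding(embedding, expected_shape=EXPECTED_SHAPE):
    """Return a list of human-readable problems; an empty list means OK."""
    if not isinstance(embedding, np.ndarray):
        return ["embedding is not a NumPy array"]
    problems = []
    if embedding.shape != expected_shape:
        problems.append(f"shape {embedding.shape} != expected {expected_shape}")
    if not np.issubdtype(embedding.dtype, np.floating):
        problems.append(f"dtype {embedding.dtype} is not a float type")
    elif not np.isfinite(embedding).all():
        problems.append("embedding contains NaN or infinite values")
    return problems

# Small shapes keep the demo cheap; pass the real array without expected_shape.
good = np.zeros((2, 3), dtype=np.float32)
bad = good.copy()
bad[0, 0] = np.nan
print(check_embedding(good, expected_shape=(2, 3)))
print(check_embedding(bad, expected_shape=(2, 3)))
```

Calling such a check before converting to a bfloat16 tensor gives a clearer message than a stack trace from inside the diffusion loop.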

## 10. Compare Original vs Modified Embeddings

Generate both images and view them side-by-side to see the effect of your modifications.

In [None]:
def compare_embeddings(seed=42):
    """
    Generate images from both original and modified embeddings and display side-by-side.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return

    if current_embedding is None:
        print("❌ No original embedding! Generate or load an embedding first.")
        return

    if modified_embedding is None:
        print("❌ No modified embedding! Apply a manipulation first.")
        return
    
    print("="*60)
    print("EMBEDDING COMPARISON")
    print("="*60)
    
    # Show embedding statistics
    print("\nORIGINAL EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(current_embedding):,}")
    print(f"  Mean: {current_embedding.mean():.4f}")
    print(f"  Std: {current_embedding.std():.4f}")
    print(f"  Min: {current_embedding.min():.4f}")
    print(f"  Max: {current_embedding.max():.4f}")
    
    print("\nMODIFIED EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(modified_embedding):,}")
    print(f"  Mean: {modified_embedding.mean():.4f}")
    print(f"  Std: {modified_embedding.std():.4f}")
    print(f"  Min: {modified_embedding.min():.4f}")
    print(f"  Max: {modified_embedding.max():.4f}")
    
    diff = np.abs(current_embedding - modified_embedding).sum()
    print(f"\nTOTAL ABSOLUTE DIFFERENCE: {diff:,.2f}")
    print(f"Percentage changed: {(np.count_nonzero(current_embedding != modified_embedding) / current_embedding.size * 100):.2f}%")
    
    print("\n" + "="*60)
    print("GENERATING IMAGES...")
    print("="*60 + "\n")
    
    # Generate from original
    print("[1/2] Generating from ORIGINAL embedding...")
    img1 = generate_image_from_embedding(
        current_embedding,
        output_filename="comparison_original.png",
        seed=seed
    )
    
    if img1 is None:
        return
    
    print("\n[2/2] Generating from MODIFIED embedding...")
    img2 = generate_image_from_embedding(
        modified_embedding,
        output_filename="comparison_modified.png",
        seed=seed
    )
    
    if img2 is None:
        return
    
    # Display side by side
    print("\n" + "="*60)
    print("COMPARISON RESULTS")
    print("="*60 + "\n")
    
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    
    axes[0].imshow(img1)
    axes[0].set_title('Original Embedding', fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    axes[1].imshow(img2)
    axes[1].set_title('Modified Embedding', fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    plt.suptitle(f'Prompt: "{prompt_input.value}"', fontsize=12, y=0.98)
    plt.tight_layout()
    plt.savefig("comparison_sidebyside.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print("✓ Comparison complete!")
    print("\nFiles saved:")
    print("  - comparison_original.png")
    print("  - comparison_modified.png")
    print("  - comparison_sidebyside.png")

# Comparison button
compare_button = widgets.Button(
    description='Compare Original vs Modified',
    button_style='success',
    layout=widgets.Layout(width='300px', height='50px')
)

compare_output = widgets.Output()

def on_compare_click(b):
    with compare_output:
        compare_output.clear_output(wait=True)
        compare_embeddings()

compare_button.on_click(on_compare_click)
display(widgets.VBox([
    widgets.HTML("<h3>⚡ Generate Comparison</h3><p>This will generate images from both embeddings with the same random seed.</p>"),
    compare_button
]), compare_output)
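The summary statistics in the comparison treat the embedding as one flat array. To see which tokens were affected most, a per-token measure can help. The following is a sketch using cosine similarity per row; `per_token_cosine` is a hypothetical helper and the `eps` guard for all-zero rows is an assumption of this example.

```python
import numpy as np

def per_token_cosine(original, modified, eps=1e-8):
    """Cosine similarity between matching token rows of two embeddings."""
    dot = (original * modified).sum(axis=1)
    norms = np.linalg.norm(original, axis=1) * np.linalg.norm(modified, axis=1)
    # Rows that are entirely zero would divide by zero; clamp the denominator.
    return dot / np.maximum(norms, eps)

demo_original = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
demo_modified = np.array([[1.0, 0.0], [0.0, 0.0], [1.0, -1.0]])
similarity = per_token_cosine(demo_original, demo_modified)
print(similarity)
```

A value near 1 means the token's direction is unchanged, while values near 0 mark tokens whose representation was largely destroyed.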

## 11. Summary

What you've learned:

1. ✓ Generate T5-XXL embeddings from text (512 tokens × 4096 dimensions = 2,097,152 numbers)
2. ✓ Save/load embeddings as JSON files
3. ✓ Manually manipulate embedding values
4. ✓ Create custom attention masks with sliders
5. ✓ Generate images directly from custom embeddings using FLUX
6. ✓ Compare original vs modified embeddings side-by-side
7. ✓ Understand the complete pipeline: text → T5-XXL → embeddings → FLUX → image

## Workflow:

1. **Generate embedding** from your text prompt
2. **Save to JSON** for backup
3. **Apply manipulations** (zero out values, etc.)
4. **Create attention masks** (optional - for future use)
5. **Generate images** from both original and modified embeddings
6. **Compare results** to see what your changes did!

## Next Steps:

- Experiment with different zero percentages (10%, 30%, 50%, 90%)
- Try zeroing specific token embeddings instead of random values
- Compare results with different prompts
- Implement attention mask injection (requires modifying cross-attention layers)
- Explore interpolating between two different embeddings
- Try adding noise to embeddings and see the effect
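Several of these experiments need only a few lines of NumPy. The helpers below are a minimal sketch, with hypothetical names (`zero_tokens`, `interpolate`, `add_noise`) and arbitrary default values; each returns a new array that could be assigned to `modified_embedding`.

```python
import numpy as np

def zero_tokens(embedding, token_indices):
    """Zero the full embedding row of each listed token position."""
    modified = embedding.copy()
    modified[list(token_indices), :] = 0.0
    return modified

def interpolate(emb_a, emb_b, alpha):
    """Linear blend: alpha=0 returns emb_a, alpha=1 returns emb_b."""
    if emb_a.shape != emb_b.shape:
        raise ValueError("embeddings must have the same shape")
    return (1.0 - alpha) * emb_a + alpha * emb_b

def add_noise(embedding, sigma=0.1, seed=0):
    """Add seeded Gaussian noise with standard deviation sigma."""
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal(embedding.shape).astype(embedding.dtype)
    return embedding + sigma * noise

# Tiny demo arrays; real embeddings in this notebook are (512, 4096).
emb_a = np.ones((4, 3), dtype=np.float32)
emb_b = np.zeros((4, 3), dtype=np.float32)

without_token_2 = zero_tokens(emb_a, [2])
halfway = interpolate(emb_a, emb_b, 0.5)
noisy = add_noise(emb_a, sigma=0.05, seed=3)
print(without_token_2[2], halfway[0], noisy.shape)
```

For `zero_tokens`, the indices printed next to each slider label in the attention section correspond to row positions in the embedding.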

In [None]:
# Clean up cache to free home directory space
import shutil
import os
from pathlib import Path

def cleanup_cache():
    """Delete cache directories to free up home quota"""
    cache_dirs = [
        Path.home() / "data" / ".cache",
        Path.home() / ".cache" / "huggingface",
        Path.home() / ".cache" / "torch",
        Path.home() / ".cache" / "uv",
    ]
    
    total_freed = 0
    for cache_dir in cache_dirs:
        if cache_dir.exists():
            try:
                # Get size before deletion (approximate)
                size = sum(f.stat().st_size for f in cache_dir.rglob('*') if f.is_file())
                shutil.rmtree(cache_dir)
                total_freed += size
                print(f"\u2713 Deleted {cache_dir}: {size / 1024**3:.2f} GB")
            except Exception as e:
                print(f"\u2717 Could not delete {cache_dir}: {e}")
    
    print(f"\nTotal freed: {total_freed / 1024**3:.2f} GB")

# Run cleanup
cleanup_cache()
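Since `shutil.rmtree` is irreversible, it may be worth inspecting sizes before deleting anything. The following dry-run sketch only reports sizes; `cache_sizes` is a hypothetical helper, and skipping symlinks is an assumption made to avoid counting linked files twice.

```python
import tempfile
from pathlib import Path

def cache_sizes(cache_dirs):
    """Map each existing directory to its total file size in bytes (no deletion)."""
    sizes = {}
    for cache_dir in cache_dirs:
        cache_dir = Path(cache_dir)
        if not cache_dir.exists():
            continue  # silently skip directories that are not present
        sizes[str(cache_dir)] = sum(
            f.stat().st_size
            for f in cache_dir.rglob("*")
            if f.is_file() and not f.is_symlink()
        )
    return sizes

# Demo on a throwaway directory instead of the real cache locations.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "blob.bin").write_bytes(b"x" * 10)
    demo_sizes = cache_sizes([tmp, Path(tmp) / "missing"])

for directory, size in demo_sizes.items():
    print(f"{directory}: {size / 1024**3:.6f} GB")
```

Passing the same `cache_dirs` list used by `cleanup_cache` shows what would be freed without removing any files.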