# T5 Embedding Manipulation for FLUX

This notebook lets you:
1. Generate T5 embeddings from text prompts
2. Save/load embeddings as JSON
3. Manually edit embedding values
4. Apply custom attention masks
5. Generate images with FLUX using modified embeddings

## Setup

Run this cell first to import the required packages and configure the model paths:

In [1]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from PIL import Image
import os
from pathlib import Path

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# All model checkpoints are kept under <parent of cwd>/data/models.
current_dir = Path.cwd()
MODELS_DIR = current_dir.parent / "data/models"
T5_MODEL_PATH, CLIP_MODEL_PATH, FLUX_MODEL_PATH = (
    os.path.join(MODELS_DIR, model_name)
    for model_name in ("t5-v1_1-xxl", "clip", "FLUX.1-schnell")
)

# Make sure the models directory exists, then report the resolved locations.
os.makedirs(MODELS_DIR, exist_ok=True)
print(f"Models directory: {os.path.abspath(MODELS_DIR)}")
print(f"T5 path: {os.path.abspath(T5_MODEL_PATH)}")
print(f"FLUX path: {os.path.abspath(FLUX_MODEL_PATH)}")

Using device: cuda
Models directory: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models
T5 path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/t5-v1_1-xxl
FLUX path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/data/models/FLUX.1-schnell


In [2]:
# Authenticate with the Hugging Face Hub. The token is read from a credentials file
# outside the notebook so it never appears in the code or in cell outputs.
from pathlib import Path

current_dir = Path.cwd()
token_file = current_dir.parent / "misc/credentials/hf.txt"
print(f"Looking for HF token at: {token_file}")

if not token_file.exists():
    print(f"❌ Token file not found at: {token_file}")
    print("Please create the file or update the path.")
else:
    hf_token = token_file.read_text().strip()

    # Expose the token to libraries that read HF_TOKEN, then register it with huggingface_hub too.
    os.environ['HF_TOKEN'] = hf_token
    from huggingface_hub import login
    login(token=hf_token)

    print("✓ Hugging Face token loaded and authenticated successfully!")

Looking for HF token at: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshopt/misc/credentials/hf.txt


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Hugging Face token loaded and authenticated successfully!


In [3]:
# Re-derive the shared models directory (the same location the setup cell uses).
MODELS_DIR = current_dir.parent / "data" / "models"

In [4]:
# Load FLUX.1-schnell in bfloat16. Reuse the local copy under FLUX_MODEL_PATH when it exists;
# otherwise download it from the Hub once and save it there for next time.
# Failures are printed, not raised; later cells check for `flux_pipe` in globals().
try:
    if os.path.exists(FLUX_MODEL_PATH):
        flux_pipe = FluxPipeline.from_pretrained(
            FLUX_MODEL_PATH,
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
    else:
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16,
        )
        flux_pipe.save_pretrained(FLUX_MODEL_PATH)

    flux_pipe = flux_pipe.to(device)
except Exception as e:
    print(f"Error loading FLUX: {e}")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [34]:
# Save modified embedding to JSON
def save_modified_embedding(filename="modified_embedding.json"):
    """Write the current ``modified_embedding`` to ``filename`` as JSON.

    The file holds the embedding values, the token strings (``current_tokens``), the array
    shape and the prompt text currently in ``prompt_input``. When there is no modified
    embedding yet, a hint is printed and nothing is written.
    """
    if modified_embedding is not None:
        payload = dict(
            embedding=modified_embedding.tolist(),
            tokens=current_tokens,
            shape=list(modified_embedding.shape),
            prompt=prompt_input.value,
        )
        with open(filename, 'w') as out_file:
            json.dump(payload, out_file)

        size_mb = os.path.getsize(filename) / (1024 * 1024)
        print(f"✓ Modified embedding saved to '{filename}' ({size_mb:.2f} MB)")
    else:
        print("❌ No modified embedding to save! Apply a manipulation first.")

# "Save Modified Embedding" button: writes modified_embedding.json; messages appear below it.
save_modified_button = widgets.Button(
    description='Save Modified Embedding',
    button_style='info',
)
save_modified_output = widgets.Output()


def on_save_modified_click(b):
    """Click handler: replace the previous message with the result of a fresh save."""
    with save_modified_output:
        save_modified_output.clear_output()
        save_modified_embedding()


save_modified_button.on_click(on_save_modified_click)
display(save_modified_button, save_modified_output)

Button(button_style='info', description='Save Modified Embedding', style=ButtonStyle())

Output()

In [30]:
# Load modified embedding from JSON
def load_modified_embedding(filename="modified_embedding.json"):
    """Load an embedding JSON file into the notebook globals.

    Sets the globals ``modified_embedding`` (NumPy array built from the ``embedding`` entry)
    and ``current_tokens`` (the ``tokens`` entry), then prints a short summary.

    Raises:
        FileNotFoundError: if ``filename`` does not exist.
        KeyError: if the file has no ``embedding`` or ``tokens`` entry.
    """
    global modified_embedding, current_tokens

    with open(filename, 'r') as f:
        data = json.load(f)

    modified_embedding = np.array(data['embedding'])
    current_tokens = data['tokens']

    print(f"✓ Modified embedding loaded from '{filename}'")
    # The prompt is informational only. A file without one must not fail here, after the
    # globals were already replaced; this matches the T5/CLIP file loaders.
    print(f"  Original prompt: {data.get('prompt', 'Unknown')}")
    print(f"  Shape: {modified_embedding.shape}")
    print(f"  Non-zero values: {np.count_nonzero(modified_embedding):,}")

# "Load Modified Embedding" button: reads modified_embedding.json and reports the outcome below it.
load_modified_button = widgets.Button(
    description='Load Modified Embedding',
    button_style='warning',
)
load_modified_output = widgets.Output()


def on_load_modified_click(b):
    """Click handler: load the saved file, explaining a missing file or any other failure."""
    with load_modified_output:
        load_modified_output.clear_output()
        try:
            load_modified_embedding()
        except FileNotFoundError:
            print("❌ File 'modified_embedding.json' not found.")
        except Exception as err:
            print(f"❌ Error loading: {err}")


load_modified_button.on_click(on_load_modified_click)
display(load_modified_button, load_modified_output)



Output()

## 6. Manual Attention Masking

Use the sliders to control how much attention weight each token receives; the weights should sum to 100%.

In [None]:
# Display tokens and create sliders
def create_attention_sliders(num_tokens=10):
    """Create one weight slider per real token among the first ``num_tokens`` positions.

    Reads the notebook global ``current_tokens`` (a list of token strings). Positions whose
    token is blank or ``'<pad>'`` are skipped. Each slider starts at an equal share
    ``1 / len(real_tokens)``. A "Normalize to 100%" button rescales the sliders so they sum
    to 1, and a label shows the running total.

    Returns:
        ``(sliders, normalize_button, total_label, real_tokens)``, where ``real_tokens`` is a
        list of ``(position, token)`` pairs aligned with ``sliders``. Returns None (after
        printing a hint) when no tokens are available yet.
    """
    # globals().get instead of the bare name: on a fresh kernel `current_tokens` may not exist
    # yet, and the user should see the hint below rather than a NameError.
    tokens = globals().get('current_tokens')
    if tokens is None:
        print("❌ Generate an embedding first!")
        return None

    # Only the first `num_tokens` positions are eligible, so don't scan the rest of the sequence.
    real_tokens = [
        (i, token)
        for i, token in enumerate(tokens[:max(num_tokens, 0)])
        if token.strip() and token != '<pad>'
    ]

    print(f"Creating sliders for first {len(real_tokens)} tokens:")
    print(real_tokens)
    print()

    sliders = []
    for idx, token in real_tokens:
        slider = widgets.FloatSlider(
            value=1.0 / len(real_tokens),  # Equal distribution
            min=0.0,
            max=1.0,
            step=0.01,
            description=f'{idx}: {token[:15]}',
            readout_format='.0%',
            layout=widgets.Layout(width='400px')
        )
        sliders.append(slider)

    # Normalize button
    normalize_button = widgets.Button(
        description='Normalize to 100%',
        button_style='info'
    )

    total_label = widgets.Label(value='Total: 100%')

    def update_total(*args):
        """Refresh the total label, warning when the weights are more than 1% away from 100%."""
        total = sum(s.value for s in sliders)
        total_label.value = f'Total: {total*100:.1f}%'
        if abs(total - 1.0) > 0.01:
            total_label.value += ' ⚠️ Should sum to 100%'

    def normalize(*args):
        """Scale all sliders proportionally so they sum to 1 (no-op if all are zero)."""
        total = sum(s.value for s in sliders)
        if total > 0:
            for s in sliders:
                s.value = s.value / total
        update_total()

    for slider in sliders:
        slider.observe(update_total, 'value')

    normalize_button.on_click(normalize)

    update_total()

    return sliders, normalize_button, total_label, real_tokens

# Build the slider panel for the first 10 token positions.
attention_controls = create_attention_sliders(num_tokens=10)

# A falsy result (None) means there were no tokens to build sliders for, so nothing is shown.
if attention_controls:
    sliders, normalize_btn, total_lbl, real_tokens = attention_controls
    display(widgets.VBox([*sliders, normalize_btn, total_lbl]))

## 7. Create Attention Mask Array

In [None]:
def create_attention_mask(seq_len=512):
    """Turn the attention-slider values into a per-token weight vector.

    Reads the notebook global ``attention_controls``, as returned by
    ``create_attention_sliders``. Every position starts at 0, and each slider's value is
    written at its token's position.

    Args:
        seq_len: length of the mask. The default of 512 matches the T5 sequence length used
            elsewhere in this notebook.

    Returns:
        A float NumPy array of length ``seq_len``. Returns None (after printing a hint) when
        the sliders have not been created.
    """
    # globals().get: on a fresh kernel `attention_controls` may not exist yet.
    controls = globals().get('attention_controls')
    if controls is None:
        print("❌ Create attention sliders first!")
        return None

    sliders, _, _, real_tokens = controls

    mask = np.zeros(seq_len)

    # Set weights from sliders
    for slider, (idx, _token) in zip(sliders, real_tokens):
        mask[idx] = slider.value

    print("Attention Mask created:")
    print(f"Non-zero weights: {np.count_nonzero(mask)}")
    for slider, (idx, token) in zip(sliders, real_tokens):
        print(f"  Token {idx} '{token}': {slider.value*100:.1f}%")

    return mask

# "Create Mask" button: snapshots the current slider values into the global `attention_mask`.
mask_button = widgets.Button(
    description='Create Mask',
    button_style='success',
)
mask_output = widgets.Output()

# Latest mask built from the sliders; stays None until the button is clicked.
attention_mask = None


def on_mask_click(b):
    """Click handler: rebuild `attention_mask`, printing the summary into `mask_output`."""
    global attention_mask
    with mask_output:
        mask_output.clear_output()
        attention_mask = create_attention_mask()


mask_button.on_click(on_mask_click)
display(mask_button, mask_output)

### Load CLIP Model

In [19]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import os

# Define CLIP model path
# NOTE(review): this relative path differs from the CLIP_MODEL_PATH built from MODELS_DIR in the
# setup cell, and overrides it for every later cell. It is kept unchanged so the existing local
# copy keeps being used; consider moving the model under MODELS_DIR.
CLIP_MODEL_PATH = "./models/clip-vit-large-patch14"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading CLIP model from: {CLIP_MODEL_PATH}...")
if not os.path.exists(CLIP_MODEL_PATH):
    print("\n⚠️  Model not found locally. Downloading from Hugging Face...")
    print("This model is ~1.7GB and will take a few minutes.")
    print("Please be patient...\n")

    # Download once, then save to the local folder. The freshly downloaded objects are used
    # directly instead of being loaded a second time from disk.
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=torch.bfloat16  # Use bfloat16 to match FLUX
    )

    print(f"Saving model to {CLIP_MODEL_PATH}...")
    clip_tokenizer.save_pretrained(CLIP_MODEL_PATH)
    clip_model.save_pretrained(CLIP_MODEL_PATH)
    print("✓ Model downloaded and saved locally!\n")
else:
    print("✓ Loading from local folder...\n")
    clip_tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_PATH, local_files_only=True)
    clip_model = CLIPTextModel.from_pretrained(
        CLIP_MODEL_PATH,
        torch_dtype=torch.bfloat16,  # Use bfloat16 to match FLUX
        local_files_only=True
    )

clip_model = clip_model.to(device)
clip_model.eval()  # Inference only

print(f"✓ CLIP loaded successfully!")
print(f"  Embedding dimension: {clip_model.config.hidden_size}")
print(f"  Max sequence length: {clip_tokenizer.model_max_length}")
print(f"  Loaded from: {CLIP_MODEL_PATH}")
print(f"  Model dtype: {next(clip_model.parameters()).dtype}")

Loading CLIP model from: ./models/clip-vit-large-patch14...
✓ Loading from local folder...

✓ CLIP loaded successfully!
  Embedding dimension: 768
  Max sequence length: 77
  Loaded from: ./models/clip-vit-large-patch14
  Model dtype: torch.bfloat16


In [23]:
import json

def save_clip_embedding(filename="clip_embedding.json"):
    """Serialize the current CLIP embedding to ``filename`` as JSON.

    Stores the embedding values, the CLIP token strings, the array shape and the text of
    ``clip_prompt_input``. If no CLIP embedding exists yet, prints a hint and writes nothing.
    """
    if current_clip_embedding is not None:
        record = {}
        record["embedding"] = current_clip_embedding.tolist()
        record["tokens"] = current_clip_tokens
        record["shape"] = list(current_clip_embedding.shape)
        record["prompt"] = clip_prompt_input.value

        with open(filename, 'w') as handle:
            json.dump(record, handle)

        megabytes = os.path.getsize(filename) / (1024 * 1024)
        print(f"✓ CLIP embedding saved to '{filename}' ({megabytes:.2f} MB)")
        return

    print("❌ No CLIP embedding to save! Generate one first.")

# "Save CLIP Embedding" button; its messages go to the output area shown beneath it.
clip_save_button = widgets.Button(
    description='Save CLIP Embedding',
    button_style='info',
)
clip_save_output = widgets.Output()


def on_clip_save_click(b):
    """Click handler: clear old messages, then save the current CLIP embedding."""
    with clip_save_output:
        clip_save_output.clear_output()
        save_clip_embedding()


clip_save_button.on_click(on_clip_save_click)
display(clip_save_button, clip_save_output)

Button(button_style='info', description='Save CLIP Embedding', style=ButtonStyle())

Output()

## Generate Image from embedding

In [None]:
def generate_image_from_embedding(embedding_array, output_filename="generated_from_embedding.png", seed=42,
                                  num_inference_steps=4, height=1024, width=1024):
    """
    Generate an image with FLUX from a custom T5 embedding.

    The T5 sequence embedding is injected as ``prompt_embeds``. The pooled embedding FLUX also
    needs is recomputed from the prompt currently in ``prompt_input`` via
    ``flux_pipe.encode_prompt``.

    Args:
        embedding_array: T5 embedding as a NumPy array or torch tensor, shaped
            ``[seq_len, 4096]`` or already batched as ``[1, seq_len, 4096]``.
        output_filename: path the generated PNG is saved to.
        seed: seed for a dedicated CPU ``torch.Generator``. Results are reproducible, and
            torch's global RNG is not reseeded.
        num_inference_steps, height, width: forwarded to ``flux_pipe``. The defaults are the
            values this notebook has always used.

    Returns:
        The generated PIL image, or None if FLUX is not loaded, no embedding was given, or
        any step failed (errors are printed, not raised).
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None

    if embedding_array is None:
        print("❌ No embedding provided! Load an embedding first.")
        return None

    print(f"Generating image from custom embedding...")
    print(f"  Embedding shape: {embedding_array.shape}")
    print(f"  Expected shape: [512, 4096]")

    # Convert to a bfloat16 tensor on the FLUX device. NumPy has no bfloat16, so NumPy input
    # goes through float32 first; torch tensors are cast directly.
    if isinstance(embedding_array, torch.Tensor):
        embedding_tensor = embedding_array.detach().to(device=device, dtype=torch.bfloat16)
    else:
        embedding_tensor = torch.from_numpy(np.asarray(embedding_array, dtype=np.float32)).to(
            device=device,
            dtype=torch.bfloat16
        )

    # Add a batch dimension only for unbatched input: [seq_len, 4096] -> [1, seq_len, 4096].
    # Unconditionally unsqueezing would turn already-batched input into an invalid 4-D tensor.
    if embedding_tensor.dim() == 2:
        embedding_tensor = embedding_tensor.unsqueeze(0)

    print(f"  Tensor shape: {embedding_tensor.shape}")
    print(f"  Device: {embedding_tensor.device}")
    print(f"  Dtype: {embedding_tensor.dtype}")

    # Generate pooled embeddings using FLUX's own text-encoding path, so they match what FLUX expects.
    print("\nGenerating pooled embeddings from CLIP...")
    try:
        prompt = prompt_input.value

        with torch.no_grad():
            (
                _,  # T5 embeddings: replaced by the custom embedding
                pooled_embeds,
                _,  # text_ids
            ) = flux_pipe.encode_prompt(
                prompt=prompt,
                prompt_2=None,
                device=device,
                num_images_per_prompt=1,
                prompt_embeds=None,
                pooled_prompt_embeds=None,
                max_sequence_length=512,
            )

        print(f"  Pooled embeddings shape: {pooled_embeds.shape}")
        print(f"  Pooled embeddings dtype: {pooled_embeds.dtype}")

    except Exception as e:
        print(f"❌ Error generating pooled embeddings: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Generate image using both embeddings
    try:
        print(f"\nRunning diffusion process ({num_inference_steps} steps)...")
        image = flux_pipe(
            prompt_embeds=embedding_tensor,
            pooled_prompt_embeds=pooled_embeds,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            height=height,
            width=width,
            # A dedicated CPU generator makes the result reproducible per seed without mutating
            # torch's global RNG state (which torch.manual_seed would do).
            generator=torch.Generator(device="cpu").manual_seed(seed)
        ).images[0]

        image.save(output_filename)
        print(f"\n✓ Image generated and saved to '{output_filename}'")

        return image

    except Exception as e:
        print(f"❌ Error generating image: {e}")
        print("\nThis might happen if:")
        print("  - Embedding dimensions don't match")
        print("  - Insufficient VRAM/RAM")
        print("  - dtype mismatch (check that tensors are bfloat16)")
        import traceback
        traceback.print_exc()
        return None

# Buttons that run FLUX on the original (`current_embedding`) or modified
# (`modified_embedding`) T5 embedding and show the result below them.
generate_from_embedding_button = widgets.Button(
    description='Generate from Current Embedding',
    button_style='primary',
    layout=widgets.Layout(width='300px')
)

generate_from_modified_button = widgets.Button(
    description='Generate from Modified Embedding',
    button_style='warning',
    layout=widgets.Layout(width='300px')
)

# Dedicated output area with its own name. A later cell binds `generation_output` to a
# different widget, and the handlers look the name up at click time, so sharing the name
# would send these buttons' output to that other area.
embedding_generation_output = widgets.Output()


def _generate_and_show(embedding, output_filename, missing_message):
    """Generate an image from ``embedding`` into the output area and display it.

    Prints ``missing_message`` instead when ``embedding`` is None.
    """
    with embedding_generation_output:
        embedding_generation_output.clear_output(wait=True)
        if embedding is None:
            print(missing_message)
            return
        image = generate_image_from_embedding(embedding, output_filename=output_filename)
        # Explicit None check: generate_image_from_embedding returns None on failure.
        if image is not None:
            display(image)


def on_generate_from_embedding(b):
    """Click handler: generate from the original embedding."""
    # globals().get: on a fresh kernel the embedding may not exist yet.
    _generate_and_show(
        globals().get('current_embedding'),
        "image_original_embedding.png",
        "❌ No embedding loaded! Generate or load an embedding first.",
    )


def on_generate_from_modified(b):
    """Click handler: generate from the modified embedding."""
    _generate_and_show(
        globals().get('modified_embedding'),
        "image_modified_embedding.png",
        "❌ No modified embedding! Apply a manipulation first or load one.",
    )


generate_from_embedding_button.on_click(on_generate_from_embedding)
generate_from_modified_button.on_click(on_generate_from_modified)

display(widgets.VBox([
    widgets.HTML("<h3>Generate Images from Embeddings</h3>"),
    generate_from_embedding_button,
    generate_from_modified_button
]), embedding_generation_output)


In [24]:
def load_clip_embedding(filename="clip_embedding.json"):
    """Load a CLIP embedding JSON file into the notebook globals.

    Sets ``current_clip_embedding`` (NumPy array built from the ``embedding`` entry) and
    ``current_clip_tokens`` (the ``tokens`` entry), then prints a short summary.

    Raises:
        FileNotFoundError: if ``filename`` does not exist.
        KeyError: if the file has no ``embedding`` or ``tokens`` entry.
    """
    global current_clip_embedding, current_clip_tokens

    with open(filename, 'r') as f:
        data = json.load(f)

    current_clip_embedding = np.array(data['embedding'])
    current_clip_tokens = data['tokens']

    print(f"✓ CLIP embedding loaded from '{filename}'")
    # The summary must not fail after the globals have been replaced: tolerate a missing
    # prompt (as the file loaders do) and an empty token list.
    print(f"  Original prompt: {data.get('prompt', 'Unknown')}")
    print(f"  Shape: {current_clip_embedding.shape}")
    if current_clip_tokens:
        print(f"  First token: '{current_clip_tokens[0]}'")
    else:
        print("  First token: (none)")

# "Load CLIP Embedding" button: reads clip_embedding.json via load_clip_embedding().
load_clip_button = widgets.Button(
    description='Load CLIP Embedding',
    button_style='warning',
)
load_clip_output = widgets.Output()


def on_load_clip_click(b):
    """Click handler: load the saved CLIP embedding, explaining a missing file or other errors."""
    with load_clip_output:
        load_clip_output.clear_output()
        try:
            load_clip_embedding()
        except FileNotFoundError:
            print("❌ File 'clip_embedding.json' not found. Save a CLIP embedding first.")
        except Exception as err:
            print(f"❌ Error loading CLIP embedding: {err}")


load_clip_button.on_click(on_load_clip_click)
display(load_clip_button, load_clip_output)



Output()

# CLIP + T5 EMBEDDING

In [None]:
# Load embeddings interface: pick saved T5 / CLIP embedding files (JSON) for image generation.

# Saved embeddings live under <parent of cwd>/data/embeddings/{T5,CLIP}.
T5_EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/T5"
CLIP_EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/CLIP"

# Filled in by the load buttons; None means "nothing loaded yet".
loaded_t5_embedding = loaded_clip_embedding = None
loaded_t5_prompt = loaded_clip_prompt = None

# --- T5 picker: file dropdown, load button, message area ---
t5_file_dropdown = widgets.Dropdown(
    options=[],
    description='T5 Embedding:',
    style=dict(description_width='initial'),
    layout=widgets.Layout(width='600px'),
)
load_t5_button = widgets.Button(description='Load T5', button_style='success')
t5_load_output = widgets.Output()

# --- CLIP picker: file dropdown, load button, message area ---
# NOTE(review): this rebinds `load_clip_button`, a name an earlier cell uses for its
# "Load CLIP Embedding" button. That button keeps working (its handler is already attached),
# but the global name now refers to this one.
clip_file_dropdown = widgets.Dropdown(
    options=[],
    description='CLIP Embedding:',
    style=dict(description_width='initial'),
    layout=widgets.Layout(width='600px'),
)
load_clip_button = widgets.Button(description='Load CLIP', button_style='success')
clip_load_output = widgets.Output()

# Offer every *.json file in each directory, sorted by name. A missing directory leaves its
# dropdown empty. The lists are read once, when this cell runs.
if T5_EMBEDDINGS_DIR.exists():
    t5_files = sorted(path.name for path in T5_EMBEDDINGS_DIR.glob('*.json'))
    t5_file_dropdown.options = t5_files

if CLIP_EMBEDDINGS_DIR.exists():
    clip_files = sorted(path.name for path in CLIP_EMBEDDINGS_DIR.glob('*.json'))
    clip_file_dropdown.options = clip_files

def load_t5_embedding_file(b):
    """Button handler: load the T5 embedding file selected in ``t5_file_dropdown``.

    On success, sets the globals ``loaded_t5_embedding`` (NumPy array) and ``loaded_t5_prompt``
    (``'Unknown'`` when the file has no prompt) and prints a summary into ``t5_load_output``.
    Load errors are printed there rather than raised.
    """
    global loaded_t5_embedding, loaded_t5_prompt

    with t5_load_output:
        t5_load_output.clear_output()

        selected = t5_file_dropdown.value
        if not selected:
            print("❌ No file selected!")
            return

        selected_path = T5_EMBEDDINGS_DIR / selected

        try:
            with open(selected_path, 'r') as fh:
                payload = json.load(fh)

            loaded_t5_embedding = np.array(payload['embedding'])
            loaded_t5_prompt = payload.get('prompt', 'Unknown')

            print("✓ Loaded T5 embedding!")
            print(f"  File: {selected}")
            print(f"  Prompt: '{loaded_t5_prompt}'")
            print(f"  Shape: {loaded_t5_embedding.shape}")
            print("  Expected: [512, 4096]")
        except Exception as exc:
            print(f"❌ Error loading T5 embedding: {exc}")

def load_clip_embedding_file(b):
    """Button handler: load the CLIP embedding file selected in ``clip_file_dropdown``.

    On success, sets the globals ``loaded_clip_embedding`` (NumPy array) and
    ``loaded_clip_prompt`` (``'Unknown'`` when the file has no prompt) and prints a summary
    into ``clip_load_output``. Load errors are printed there rather than raised.
    """
    global loaded_clip_embedding, loaded_clip_prompt

    with clip_load_output:
        clip_load_output.clear_output()

        selected = clip_file_dropdown.value
        if not selected:
            print("❌ No file selected!")
            return

        selected_path = CLIP_EMBEDDINGS_DIR / selected

        try:
            with open(selected_path, 'r') as fh:
                payload = json.load(fh)

            loaded_clip_embedding = np.array(payload['embedding'])
            loaded_clip_prompt = payload.get('prompt', 'Unknown')

            print("✓ Loaded CLIP embedding!")
            print(f"  File: {selected}")
            print(f"  Prompt: '{loaded_clip_prompt}'")
            print(f"  Shape: {loaded_clip_embedding.shape}")
            print("  Expected: [77, 768]")
        except Exception as exc:
            print(f"❌ Error loading CLIP embedding: {exc}")

# Wire the load buttons to their handlers.
load_t5_button.on_click(load_t5_embedding_file)
load_clip_button.on_click(load_clip_embedding_file)

# Picker panel: a title, then the T5 section and the CLIP section, separated by rules.
display(widgets.VBox(children=[
    widgets.HTML("<h3>Select Embeddings for Image Generation</h3>"),
    widgets.HTML("<hr>"),
    widgets.HTML("<h4>1. T5 Embedding (Text Encoding)</h4>"),
    t5_file_dropdown,
    load_t5_button,
    t5_load_output,
    widgets.HTML("<hr>"),
    widgets.HTML("<h4>2. CLIP Embedding (Pooled Encoding)</h4>"),
    clip_file_dropdown,
    load_clip_button,
    clip_load_output,
]))

In [None]:
# Generate image from loaded embeddings

def generate_from_loaded_embeddings(output_filename="generated_image.png", seed=42,
                                    num_inference_steps=4, height=1024, width=1024):
    """
    Generate an image with FLUX from the T5 (and optionally CLIP) embeddings loaded above.

    Reads the notebook globals ``loaded_t5_embedding`` and ``loaded_t5_prompt`` (required)
    and ``loaded_clip_embedding`` / ``loaded_clip_prompt`` (optional). When no CLIP embedding
    is loaded, the pooled embedding is computed from ``loaded_t5_prompt`` with
    ``flux_pipe.encode_prompt``.

    Args:
        output_filename: path the generated PNG is saved to.
        seed: seed for a dedicated CPU ``torch.Generator``. Results are reproducible, and
            torch's global RNG is not reseeded.
        num_inference_steps, height, width: forwarded to ``flux_pipe``. The defaults are the
            values this notebook has always used.

    Returns:
        The PIL image, or None if FLUX or the T5 embedding is missing or any step fails
        (errors are printed, not raised).
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None

    if loaded_t5_embedding is None:
        print("❌ No T5 embedding loaded! Load a T5 embedding first.")
        return None

    print(f"Generating image from loaded embeddings...")
    print(f"  T5 embedding shape: {loaded_t5_embedding.shape}")
    print(f"  T5 prompt: '{loaded_t5_prompt}'")

    # NumPy has no bfloat16, so go through float32 before casting on the FLUX device.
    t5_tensor = torch.from_numpy(np.asarray(loaded_t5_embedding, dtype=np.float32)).to(
        device=device,
        dtype=torch.bfloat16
    )

    # Add a batch dimension only for unbatched input: [seq_len, 4096] -> [1, seq_len, 4096].
    if t5_tensor.dim() == 2:
        t5_tensor = t5_tensor.unsqueeze(0)

    print(f"  T5 tensor shape: {t5_tensor.shape}")

    # Process CLIP embeddings
    print("\nProcessing CLIP pooled embeddings...")

    if loaded_clip_embedding is not None:
        print(f"  Using loaded CLIP embedding")
        print(f"  CLIP embedding shape: {loaded_clip_embedding.shape}")
        print(f"  CLIP prompt: '{loaded_clip_prompt}'")

        clip_array = np.asarray(loaded_clip_embedding, dtype=np.float32)
        if clip_array.ndim == 3:
            # Drop a leading batch axis: [1, seq_len, 768] -> [seq_len, 768].
            clip_array = clip_array[0]
        if clip_array.ndim == 1:
            # Already a pooled vector: [768] -> [1, 768].
            clip_array = clip_array[np.newaxis, :]
        else:
            # Per-token embeddings: use the last row as the pooled embedding -> [1, 768].
            # NOTE(review): Hugging Face's CLIPTextModel pools the hidden state at the EOS
            # position, not necessarily the last row. If these files store max-length-padded
            # sequences, the last row is a padding position. Confirm how they were saved.
            clip_array = clip_array[-1:]
        pooled_embeds = torch.from_numpy(clip_array).to(device=device, dtype=torch.bfloat16)

    else:
        # Fall back to computing the pooled embedding from the T5 prompt.
        print(f"  No CLIP embedding loaded, generating from T5 prompt using CLIP model...")

        if not loaded_t5_prompt:
            print("❌ No CLIP embedding and no prompt available!")
            return None

        # Guarded like the image-generation step below, so a text-encoder failure is reported
        # instead of escaping from the button callback.
        try:
            with torch.no_grad():
                (
                    _,
                    pooled_embeds,
                    _,
                ) = flux_pipe.encode_prompt(
                    prompt=loaded_t5_prompt,
                    prompt_2=None,
                    device=device,
                    num_images_per_prompt=1,
                    prompt_embeds=None,
                    pooled_prompt_embeds=None,
                    max_sequence_length=512,
                )
        except Exception as e:
            print(f"❌ Error generating pooled embeddings: {e}")
            import traceback
            traceback.print_exc()
            return None

    print(f"  Pooled embeddings shape: {pooled_embeds.shape}")

    # Generate image
    try:
        print(f"\nRunning FLUX diffusion ({num_inference_steps} steps)...")
        image = flux_pipe(
            prompt_embeds=t5_tensor,
            pooled_prompt_embeds=pooled_embeds,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            height=height,
            width=width,
            # Dedicated generator: reproducible per seed without reseeding torch's global RNG.
            generator=torch.Generator(device="cpu").manual_seed(seed)
        ).images[0]

        image.save(output_filename)
        print(f"\n✓ Image generated and saved to '{output_filename}'")

        return image

    except Exception as e:
        print(f"❌ Error generating image: {e}")
        import traceback
        traceback.print_exc()
        return None

# Generation button
generate_button = widgets.Button(
    description='Generate Image',
    button_style='primary',
    layout=widgets.Layout(width='300px', height='50px')
)

seed_input = widgets.IntText(
    value=42,
    description='Seed:',
    style={'description_width': 'initial'}
)

# Distinct name on purpose. An earlier cell's button handlers look up `generation_output`
# at click time, so rebinding that global here would redirect their output into this area.
loaded_generation_output = widgets.Output()


def on_generate_click(b):
    """Click handler: generate from the loaded embeddings with the chosen seed and show the image."""
    with loaded_generation_output:
        loaded_generation_output.clear_output(wait=True)

        # Name the output after the selected files: generated_<t5 stem>_<clip stem>.png.
        # NOTE(review): this uses the current dropdown selection, which may differ from the
        # file that was actually loaded if the selection changed after loading.
        t5_name = t5_file_dropdown.value.rsplit('.', 1)[0] if t5_file_dropdown.value else "no_t5"
        clip_name = clip_file_dropdown.value.rsplit('.', 1)[0] if clip_file_dropdown.value else "no_clip"
        output_filename = f"generated_{t5_name}_{clip_name}.png"

        image = generate_from_loaded_embeddings(
            output_filename=output_filename,
            seed=seed_input.value
        )

        # Explicit None check: generate_from_loaded_embeddings returns None on failure.
        if image is not None:
            display(image)


generate_button.on_click(on_generate_click)

display(widgets.VBox([
    widgets.HTML("<hr>"),
    widgets.HTML("<h3>3. Generate Image</h3>"),
    seed_input,
    generate_button
]), loaded_generation_output)

## 10. Compare Original vs Modified Embeddings

Generate both images and view them side-by-side to see the effect of your modifications.

In [32]:
def compare_embeddings(seed=42):
    """
    Generate images from the original and modified T5 embeddings and show them side by side.

    Uses the notebook globals ``current_embedding`` (original) and ``modified_embedding``.
    Both images use the same ``seed``, so differences come from the embedding edits. Summary
    statistics are printed first, then comparison_original.png, comparison_modified.png and
    comparison_sidebyside.png are written.

    Returns None in all cases; problems are printed rather than raised.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return

    # globals().get: on a fresh kernel these may not exist yet; show the hint, not a NameError.
    original = globals().get('current_embedding')
    modified = globals().get('modified_embedding')

    if original is None:
        print("❌ No original embedding! Generate or load an embedding first.")
        return

    if modified is None:
        print("❌ No modified embedding! Apply a manipulation first.")
        return

    # The element-wise difference statistics below need matching shapes. Report a mismatch
    # clearly instead of failing with a NumPy broadcasting error.
    if np.shape(original) != np.shape(modified):
        print(f"❌ Shape mismatch: original {np.shape(original)} vs modified {np.shape(modified)}")
        return

    print("="*60)
    print("EMBEDDING COMPARISON")
    print("="*60)

    # Show embedding statistics
    print("\nORIGINAL EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(original):,}")
    print(f"  Mean: {original.mean():.4f}")
    print(f"  Std: {original.std():.4f}")
    print(f"  Min: {original.min():.4f}")
    print(f"  Max: {original.max():.4f}")

    print("\nMODIFIED EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(modified):,}")
    print(f"  Mean: {modified.mean():.4f}")
    print(f"  Std: {modified.std():.4f}")
    print(f"  Min: {modified.min():.4f}")
    print(f"  Max: {modified.max():.4f}")

    diff = np.abs(original - modified).sum()
    print(f"\nTOTAL ABSOLUTE DIFFERENCE: {diff:,.2f}")
    print(f"Percentage changed: {(np.count_nonzero(original != modified) / original.size * 100):.2f}%")

    print("\n" + "="*60)
    print("GENERATING IMAGES...")
    print("="*60 + "\n")

    print("[1/2] Generating from ORIGINAL embedding...")
    img1 = generate_image_from_embedding(
        original,
        output_filename="comparison_original.png",
        seed=seed
    )

    if img1 is None:
        return

    print("\n[2/2] Generating from MODIFIED embedding...")
    img2 = generate_image_from_embedding(
        modified,
        output_filename="comparison_modified.png",
        seed=seed
    )

    if img2 is None:
        return

    # Display side by side
    print("\n" + "="*60)
    print("COMPARISON RESULTS")
    print("="*60 + "\n")

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    for ax, img, title in zip(axes, (img1, img2), ('Original Embedding', 'Modified Embedding')):
        ax.imshow(img)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis('off')

    fig.suptitle(f'Prompt: "{prompt_input.value}"', fontsize=12, y=0.98)
    fig.tight_layout()
    fig.savefig("comparison_sidebyside.png", dpi=150, bbox_inches='tight')
    plt.show()
    # Close the figure explicitly so repeated comparisons don't accumulate open figures.
    plt.close(fig)

    print("✓ Comparison complete!")
    print("\nFiles saved:")
    print("  - comparison_original.png")
    print("  - comparison_modified.png")
    print("  - comparison_sidebyside.png")

# "Compare Original vs Modified" button: runs compare_embeddings() with its default seed.
compare_button = widgets.Button(
    description='Compare Original vs Modified',
    button_style='success',
    layout=widgets.Layout(width='300px', height='50px')
)
compare_output = widgets.Output()


def on_compare_click(b):
    """Click handler: run the side-by-side comparison inside `compare_output`."""
    with compare_output:
        compare_output.clear_output(wait=True)
        compare_embeddings()


compare_button.on_click(on_compare_click)

display(
    widgets.VBox([
        widgets.HTML(
            "<h3>⚡ Generate Comparison</h3>"
            "<p>This will generate images from both embeddings with the same random seed.</p>"
        ),
        compare_button,
    ]),
    compare_output,
)

VBox(children=(HTML(value='<h3>⚡ Generate Comparison</h3><p>This will generate images from both embeddings wit…

Output()

## 11. Summary

What you've learned:

1. ✓ Generate T5-XXL embeddings from text (512 tokens × 4096 dimensions = 2,097,152 numbers)
2. ✓ Save/load embeddings as JSON files
3. ✓ Manually manipulate embedding values
4. ✓ Create custom attention masks with sliders
5. ✓ Generate images directly from custom embeddings using FLUX
6. ✓ Compare original vs modified embeddings side-by-side
7. ✓ Understand the complete pipeline: text → T5-XXL → embeddings → FLUX → image

## Workflow:

1. **Generate embedding** from your text prompt
2. **Save to JSON** for backup
3. **Apply manipulations** (zero out values, etc.)
4. **Create attention masks** (optional - for future use)
5. **Generate images** from both original and modified embeddings
6. **Compare results** to see what your changes did!

## Next Steps:

- Experiment with different zero percentages (10%, 30%, 50%, 90%)
- Try zeroing specific token embeddings instead of random values
- Compare results with different prompts
- Implement attention mask injection (requires modifying cross-attention layers)
- Explore interpolating between two different embeddings
- Try adding noise to embeddings and see the effect

In [None]:
# Clean up cache to free home directory space
import shutil
import os
from pathlib import Path

def cleanup_cache(cache_dirs=None):
    """Delete cache directories to free up home quota.

    Args:
        cache_dirs: directories to delete. Defaults to ``~/data/.cache`` and the
            ``huggingface``, ``torch`` and ``uv`` folders under ``~/.cache``.

    Prints the size freed per directory and in total. Missing directories are skipped. A
    directory that cannot be measured or deleted is reported, and the remaining ones are
    still processed.

    NOTE(review): deleting ~/.cache/huggingface presumably also removes any token that
    ``huggingface_hub.login`` stored there; confirm HF_HOME before relying on it.
    """
    if cache_dirs is None:
        home = Path.home()
        cache_dirs = [
            home / "data" / ".cache",
            home / ".cache" / "huggingface",
            home / ".cache" / "torch",
            home / ".cache" / "uv",
        ]

    total_freed = 0
    for cache_dir in map(Path, cache_dirs):
        if not cache_dir.exists():
            continue
        try:
            # Count regular files only. Symlinks are skipped because caches (e.g. the Hugging
            # Face hub cache) may link several entries to one shared file, and following them
            # would count the same bytes more than once.
            size = sum(
                f.lstat().st_size
                for f in cache_dir.rglob('*')
                if not f.is_symlink() and f.is_file()
            )
            shutil.rmtree(cache_dir)
            total_freed += size
            print(f"\u2713 Deleted {cache_dir}: {size / 1024**3:.2f} GB")
        except Exception as e:
            print(f"\u2717 Could not delete {cache_dir}: {e}")

    print(f"\nTotal freed: {total_freed / 1024**3:.2f} GB")

# Run cleanup
# WARNING: this call runs unconditionally, so "Restart & Run All" deletes the cache
# directories listed in cleanup_cache() every time. Anything stored only in those caches
# will have to be downloaded again.
cleanup_cache()