# T5 Embedding Manipulation for FLUX

This notebook lets you:
1. Generate T5 embeddings from text prompts
2. Save/load embeddings as JSON
3. Manually edit embedding values
4. Apply custom attention masks
5. Generate images with FLUX using modified embeddings

## Installation

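Install the required packages before running the notebook: `torch`, `transformers`, `sentencepiece` (needed by `T5Tokenizer`), `diffusers`, `ipywidgets`, `numpy`, `pillow`, and `matplotlib`.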

## Setup and Imports

In [None]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from PIL import Image
import os

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Load T5 Model (T5-XXL for FLUX)

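We'll use `google/t5-v1_1-xxl`, the T5 encoder that FLUX uses as its second text encoder (`text_encoder_2`, alongside CLIP). Its hidden size (`d_model`) is 4096, so every token is encoded as a 4096-dimensional vector.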

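**Note:** This is a very large model. The `google/t5-v1_1-xxl` checkpoint holds the full encoder-decoder in full precision, so the download runs to tens of gigabytes even though only the encoder (roughly 9.5 GB in half precision) is loaded here. Make sure you have enough disk space and RAM.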
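To make the sizes concrete: each prompt is padded to 512 tokens, so one embedding holds 512 × 4096 values. The sketch below is plain arithmetic (nothing model-specific) showing how much memory one embedding takes at the three floating-point precisions that come up in this workflow.

```python
import numpy as np

SEQ_LEN = 512   # tokens per prompt after padding/truncation (max_length=512)
D_MODEL = 4096  # T5-v1.1-XXL hidden size

def embedding_nbytes(seq_len, d_model, dtype):
    """Bytes needed to hold one [seq_len, d_model] embedding in `dtype`."""
    return seq_len * d_model * np.dtype(dtype).itemsize

num_values = SEQ_LEN * D_MODEL
print(f"values per embedding: {num_values:,}")
for dtype in (np.float16, np.float32, np.float64):
    mib = embedding_nbytes(SEQ_LEN, D_MODEL, dtype) / 2**20
    print(f"{np.dtype(dtype).name:>8}: {mib:.0f} MiB")
```

That comes to 2,097,152 values, or 4, 8, and 16 MiB respectively. JSON stores each value as decimal text, so a saved file is considerably larger than these in-memory figures.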

In [None]:
# Load T5-XXL (what FLUX actually uses)
t5_model_name = "google/t5-v1_1-xxl"  # embedding_dim (d_model) = 4096

print(f"Loading T5 model: {t5_model_name}...")
print("This is a large model and will take several minutes to download on first run.")
print("Please be patient...\n")

tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_model = T5EncoderModel.from_pretrained(
    t5_model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
t5_model.eval()  # Set to evaluation mode

print(f"\n✓ T5-XXL loaded successfully!")
print(f"  Embedding dimension: {t5_model.config.d_model}")
print(f"  Max sequence length: {tokenizer.model_max_length}")

## 2. Text Input Widget and Embedding Generation

In [None]:
# Create text input widget
prompt_input = widgets.Textarea(
    value='a red cat sitting on a blue table',
    placeholder='Enter your prompt here',
    description='Prompt:',
    layout=widgets.Layout(width='80%', height='80px')
)

generate_button = widgets.Button(
    description='Generate Embedding',
    button_style='success'
)

output_area = widgets.Output()

# Global variable to store current embedding
current_embedding = None
current_tokens = None

def generate_embedding(b):
    global current_embedding, current_tokens
    
    with output_area:
        output_area.clear_output()
        
        prompt = prompt_input.value
        print(f"Generating embedding for: '{prompt}'\n")
        
        # Tokenize
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt"
        )
        
        # Get token strings for display
        token_ids = tokens['input_ids'][0].tolist()
        token_strings = [tokenizer.decode([tid]) for tid in token_ids]
        
        # Find how many real tokens (non-padding)
        num_real_tokens = (tokens['input_ids'][0] != tokenizer.pad_token_id).sum().item()
        
        print(f"Tokenized into {num_real_tokens} real tokens (+ {512 - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()
        
        # Generate embedding
        with torch.no_grad():
            tokens = {k: v.to(device) for k, v in tokens.items()}
            outputs = t5_model(**tokens)
            embedding = outputs.last_hidden_state  # Shape: [1, 512, embedding_dim]
        
        # Store as float32: float16 sums over ~2M values (e.g. in Section 10) can overflow
        current_embedding = embedding.float().cpu().numpy()[0]  # Shape: [512, embedding_dim]
        current_tokens = token_strings
        
        embedding_dim = current_embedding.shape[1]
        total_numbers = current_embedding.shape[0] * current_embedding.shape[1]
        
        print(f"✓ Embedding generated!")
        print(f"  Shape: {current_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_embedding[0, :10])

generate_button.on_click(generate_embedding)

display(prompt_input, generate_button, output_area)

## 3. Save Embedding to JSON

In [None]:
def save_embedding(filename="embedding.json"):
    if current_embedding is None:
        print("❌ No embedding to save! Generate one first.")
        return
    
    data = {
        "embedding": current_embedding.tolist(),
        "tokens": current_tokens,
        "shape": list(current_embedding.shape),
        "prompt": prompt_input.value
    }
    
    with open(filename, 'w') as f:
        json.dump(data, f)
    
    file_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"✓ Embedding saved to '{filename}' ({file_size:.2f} MB)")

# Save button
save_button = widgets.Button(description='Save Embedding', button_style='info')
save_output = widgets.Output()

def on_save_click(b):
    with save_output:
        save_output.clear_output()
        save_embedding()

save_button.on_click(on_save_click)
display(save_button, save_output)
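The JSON layout used by `save_embedding()` and `load_embedding()` round-trips float32 values exactly: `tolist()` turns them into Python floats, and `json` writes floats with enough digits to read them back unchanged. A minimal sketch that checks this on a small random array, with an explicit float32 cast on load and a check against the stored `"shape"` field. The helper names `save_embedding_json` and `load_embedding_json` are hypothetical, not part of the notebook.

```python
import json
import os
import tempfile

import numpy as np

def save_embedding_json(path, embedding, tokens, prompt):
    """Write an embedding plus metadata in the same layout as save_embedding()."""
    data = {
        "embedding": embedding.tolist(),  # nested lists of Python floats
        "tokens": tokens,
        "shape": list(embedding.shape),
        "prompt": prompt,
    }
    with open(path, "w") as f:
        json.dump(data, f)

def load_embedding_json(path):
    """Read the file back as float32 and verify it against the stored shape."""
    with open(path) as f:
        data = json.load(f)
    embedding = np.array(data["embedding"], dtype=np.float32)
    if list(embedding.shape) != data["shape"]:
        raise ValueError(f"shape mismatch: {embedding.shape} vs {data['shape']}")
    return embedding, data["tokens"], data["prompt"]

# Round trip on a small random stand-in for a real [512, 4096] embedding
rng = np.random.default_rng(0)
original = rng.standard_normal((8, 16)).astype(np.float32)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embedding.json")
    save_embedding_json(path, original, ["tok"] * 8, "a test prompt")
    restored, tokens, prompt = load_embedding_json(path)
```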

## 4. Load Embedding from JSON

In [None]:
def load_embedding(filename="embedding.json"):
    global current_embedding, current_tokens
    
    with open(filename, 'r') as f:
        data = json.load(f)
    
    current_embedding = np.array(data['embedding'], dtype=np.float32)
    current_tokens = data['tokens']
    
    print(f"✓ Embedding loaded from '{filename}'")
    print(f"  Original prompt: {data['prompt']}")
    print(f"  Shape: {current_embedding.shape}")
    print(f"  First token: '{current_tokens[0]}'")

# Load button
load_button = widgets.Button(description='Load Embedding', button_style='warning')
load_output = widgets.Output()

def on_load_click(b):
    with load_output:
        load_output.clear_output()
        try:
            load_embedding()
        except FileNotFoundError:
            print("❌ File 'embedding.json' not found. Save an embedding first.")

load_button.on_click(on_load_click)
display(load_button, load_output)

## 5. Manual Embedding Manipulation

Example: Zero out a percentage of values

In [None]:
def manipulate_embedding(zero_percentage=0.3):
    if current_embedding is None:
        print("❌ No embedding loaded!")
        return None
    
    # Work on a copy so the original embedding stays untouched
    modified = current_embedding.copy()
    
    # Randomly choose which values to zero out
    total_values = modified.size
    num_zeros = int(total_values * zero_percentage)
    zero_indices = np.random.choice(total_values, num_zeros, replace=False)
    
    # .flat indexes the array as if it were 1-D and writes into `modified` in place
    modified.flat[zero_indices] = 0.0
    
    print(f"✓ Zeroed out {num_zeros:,} values ({zero_percentage:.0%})")
    print(f"  Original non-zero: {np.count_nonzero(current_embedding):,}")
    print(f"  Modified non-zero: {np.count_nonzero(modified):,}")
    
    return modified

# Manipulation widget
zero_slider = widgets.FloatSlider(
    value=0.3,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Zero %:',
    readout_format='.0%'
)

manipulate_button = widgets.Button(description='Apply Manipulation', button_style='danger')
manipulate_output = widgets.Output()

modified_embedding = None

def on_manipulate_click(b):
    global modified_embedding
    with manipulate_output:
        manipulate_output.clear_output()
        modified_embedding = manipulate_embedding(zero_slider.value)

manipulate_button.on_click(on_manipulate_click)
display(zero_slider, manipulate_button, manipulate_output)
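`manipulate_embedding()` draws a fresh random set of positions on every click, so two runs with the same percentage zero different values. A minimal sketch of a reproducible variant, assuming you want the same positions for a given seed: it uses NumPy's `Generator` API and returns a new array without changing its input. The function name `zero_fraction` is hypothetical.

```python
import numpy as np

def zero_fraction(embedding, fraction, seed=0):
    """Return a copy of `embedding` with `fraction` of its values set to zero.

    The same `seed` always selects the same positions.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    rng = np.random.default_rng(seed)
    modified = embedding.copy()
    num_zeros = int(modified.size * fraction)
    # Sample flat indices without replacement, then write through the flat view
    idx = rng.choice(modified.size, size=num_zeros, replace=False)
    modified.flat[idx] = 0.0
    return modified

# Usage on a small stand-in array (a real embedding is [512, 4096])
emb = np.arange(1, 101, dtype=np.float32).reshape(10, 10)  # no zeros to start
a = zero_fraction(emb, 0.3, seed=123)
b = zero_fraction(emb, 0.3, seed=123)
```

Calling it with the slider value, e.g. `zero_fraction(current_embedding, zero_slider.value, seed=42)`, would let you regenerate exactly the same modified embedding later.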

## 6. Manual Attention Masking

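Assign a weight to each of the first few tokens. The sliders are built from the tokens that exist when this cell runs, so generate an embedding in Section 2 first. Section 7 collects the weights into a mask array; nothing in this notebook passes that mask to FLUX yet.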

In [None]:
# Display tokens and create sliders
def create_attention_sliders(num_tokens=10):
    if current_tokens is None:
        print("❌ Generate an embedding first!")
        return None
    
    # Find non-padding tokens
    real_tokens = []
    for i, token in enumerate(current_tokens):
        if token.strip() and token != '<pad>' and i < num_tokens:
            real_tokens.append((i, token))
    
    print(f"Creating sliders for first {len(real_tokens)} tokens:")
    print(real_tokens)
    print()
    
    sliders = []
    for idx, token in real_tokens:
        slider = widgets.FloatSlider(
            value=1.0 / len(real_tokens),  # Equal distribution
            min=0.0,
            max=1.0,
            step=0.01,
            description=f'{idx}: {token[:15]}',
            readout_format='.0%',
            layout=widgets.Layout(width='400px')
        )
        sliders.append(slider)
    
    # Normalize button
    normalize_button = widgets.Button(
        description='Normalize to 100%',
        button_style='info'
    )
    
    total_label = widgets.Label(value='Total: 100%')
    
    def update_total(*args):
        total = sum(s.value for s in sliders)
        total_label.value = f'Total: {total*100:.1f}%'
        if abs(total - 1.0) > 0.01:
            total_label.value += ' ⚠️ Should sum to 100%'
    
    def normalize(*args):
        total = sum(s.value for s in sliders)
        if total > 0:
            for s in sliders:
                s.value = s.value / total
        update_total()
    
    for slider in sliders:
        slider.observe(update_total, 'value')
    
    normalize_button.on_click(normalize)
    
    update_total()
    
    return sliders, normalize_button, total_label, real_tokens

attention_controls = create_attention_sliders(num_tokens=10)

if attention_controls:
    sliders, normalize_btn, total_lbl, real_tokens = attention_controls
    display(widgets.VBox(sliders + [normalize_btn, total_lbl]))

## 7. Create Attention Mask Array

In [None]:
def create_attention_mask():
    if attention_controls is None:
        print("❌ Create attention sliders first!")
        return None
    
    sliders, _, _, real_tokens = attention_controls
    
    # Create mask array (512 tokens)
    mask = np.zeros(512)
    
    # Set weights from sliders
    for slider, (idx, token) in zip(sliders, real_tokens):
        mask[idx] = slider.value
    
    print("Attention Mask created:")
    print(f"Non-zero weights: {np.count_nonzero(mask)}")
    for slider, (idx, token) in zip(sliders, real_tokens):
        print(f"  Token {idx} '{token}': {slider.value*100:.1f}%")
    
    return mask

mask_button = widgets.Button(description='Create Mask', button_style='success')
mask_output = widgets.Output()

attention_mask = None

def on_mask_click(b):
    global attention_mask
    with mask_output:
        mask_output.clear_output()
        attention_mask = create_attention_mask()

mask_button.on_click(on_mask_click)
display(mask_button, mask_output)
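Since the mask is not consumed by FLUX here, one way to make the slider weights affect generation without touching the model is to scale each token's embedding row by its weight. This is a different technique from attention masking (it changes the vectors rather than which tokens are attended to), and the helper below is a hypothetical sketch, not part of the notebook. Because the sliders are normalized to sum to 1, the weights are multiplied by the number of weighted tokens so that an equal split leaves those rows unchanged; tokens without a slider are left as-is (both are assumptions you may want to change).

```python
import numpy as np

def scale_token_rows(embedding, weights_by_index):
    """Return a float32 copy of `embedding` with selected token rows rescaled.

    `weights_by_index` maps token position -> slider weight. Each weight is
    multiplied by the number of weighted tokens, so an equal split (1/n each)
    leaves those rows unchanged. Rows not in the mapping are not modified.
    """
    scaled = embedding.astype(np.float32, copy=True)
    n = len(weights_by_index)
    for idx, weight in weights_by_index.items():
        scaled[idx] *= weight * n
    return scaled

# Usage: give token 1 twice the weight of tokens 0 and 2 (weights sum to 1)
emb = np.ones((5, 4), dtype=np.float32)
out = scale_token_rows(emb, {0: 0.25, 1: 0.5, 2: 0.25})
```

With the notebook's sliders, the mapping would be `{idx: s.value for s, (idx, _) in zip(sliders, real_tokens)}`, and the result can be passed to `generate_image_from_embedding()` like any other embedding.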

## 8. Load FLUX Model

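**Warning:** FLUX is large. The FLUX transformer alone (about 12B parameters) takes roughly 24 GB in bfloat16, before counting the text encoders and the VAE, so loading takes time and needs significant VRAM/RAM. The cell below loads whatever checkpoint `local_model_path` points to; the example path is FLUX.1-dev, so update it for your setup.

If you don't have enough resources, skip this section; the embedding and mask experiments above work without FLUX.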
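The ~24 GB figure follows from parameter count times bytes per parameter. A back-of-the-envelope sketch, assuming roughly 12 billion parameters for the FLUX transformer and roughly 4.7 billion for the T5-XXL encoder (rounded figures, not read from the checkpoint), at 2 bytes per parameter in bfloat16. It ignores the CLIP encoder, the VAE, and activation memory, so real usage is higher.

```python
def weights_gb(num_params, bytes_per_param):
    """Approximate weight memory in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Approximate, rounded parameter counts (assumptions)
components = {
    "FLUX transformer": 12e9,
    "T5-XXL encoder": 4.7e9,
}
BF16_BYTES = 2  # bytes per parameter in bfloat16

for name, params in components.items():
    print(f"{name}: ~{weights_gb(params, BF16_BYTES):.1f} GB in bfloat16")
```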

In [None]:
# Load FLUX from a local folder (the example path below points to FLUX.1-dev)
print("Loading FLUX from local folder...")
try:
    # Replace this path with your actual local model path
    local_model_path = "../phase_02/data/models/black-forest-labs__FLUX.1-dev/"  # UPDATE THIS PATH
    
    flux_pipe = FluxPipeline.from_pretrained(
        local_model_path,
        torch_dtype=torch.bfloat16,
        local_files_only=True  # Ensures it only loads from local files
    )
    flux_pipe = flux_pipe.to(device)
    print("✓ FLUX loaded successfully from local folder!")
except Exception as e:
    print(f"❌ Error loading FLUX: {e}")
    print("Please check that:")
    print(f"1. The path '{local_model_path}' is correct")
    print("2. The folder contains all necessary model files")
    print("3. You have enough RAM/VRAM to load the model")

## 9. Generate Image from Loaded Embedding

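This passes an embedding (the current one, whether generated or loaded from JSON, or the modified one from Section 5) to FLUX through the `prompt_embeds` argument. `FluxPipeline` raises an error if `prompt_embeds` is given without `pooled_prompt_embeds` (the pooled CLIP text embedding), so a pooled embedding for the prompt text is needed as well. The sampling settings used here (4 steps, `guidance_scale=0.0`) are the usual FLUX.1-schnell settings; FLUX.1-dev normally needs more steps (the pipeline default is 28) and a guidance scale around 3.5.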
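Before handing an array to the pipeline it can help to check its shape and values up front, since a mismatch otherwise surfaces as an error deep inside the transformer. A minimal NumPy-only sketch, assuming the `[512, 4096]` layout used throughout this notebook; the helper name `prepare_prompt_embeds` is hypothetical, and the conversion to a torch tensor in the pipeline's dtype still happens afterwards.

```python
import numpy as np

EXPECTED_SHAPE = (512, 4096)  # [tokens, T5-XXL hidden size] as used in this notebook

def prepare_prompt_embeds(embedding, expected_shape=EXPECTED_SHAPE):
    """Validate a [tokens, dim] embedding and return a float32 [1, tokens, dim] array."""
    arr = np.asarray(embedding, dtype=np.float32)
    if arr.shape != expected_shape:
        raise ValueError(f"expected shape {expected_shape}, got {arr.shape}")
    if not np.isfinite(arr).all():
        # NaN/inf values (e.g. from float16 overflow) would propagate into the image
        raise ValueError("embedding contains NaN or infinite values")
    return arr[np.newaxis, ...]  # add the batch dimension

batch = prepare_prompt_embeds(np.zeros((512, 4096)))
```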

In [None]:
def generate_image_from_embedding(embedding_array, output_filename="generated_from_embedding.png", seed=42, prompt=None):
    """
    Generate an image from a custom T5 embedding by passing it to FLUX as `prompt_embeds`.

    FluxPipeline also requires `pooled_prompt_embeds` (the pooled CLIP embedding) whenever
    `prompt_embeds` is given, so we compute it from `prompt` (default: the prompt widget text).
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None
    
    if embedding_array is None:
        print("❌ No embedding provided! Load an embedding first.")
        return None
    
    if prompt is None:
        prompt = prompt_input.value
    
    print(f"Generating image from custom embedding...")
    print(f"  Embedding shape: {embedding_array.shape}")
    print(f"  Expected shape: (512, 4096)")
    
    # Convert numpy to a torch tensor in the transformer's dtype (bfloat16 here);
    # a float16 tensor would not match the bfloat16 weights
    embedding_tensor = torch.from_numpy(embedding_array).to(
        device=device,
        dtype=flux_pipe.transformer.dtype
    )
    
    # Add batch dimension: [512, 4096] -> [1, 512, 4096]
    embedding_tensor = embedding_tensor.unsqueeze(0)
    
    print(f"  Tensor shape: {embedding_tensor.shape}")
    print(f"  Device: {embedding_tensor.device}")
    print(f"  Dtype: {embedding_tensor.dtype}")
    print()
    
    try:
        # Pooled CLIP embedding for the prompt text; the T5 output of encode_prompt is discarded
        with torch.no_grad():
            _, pooled_prompt_embeds, _ = flux_pipe.encode_prompt(
                prompt=prompt, prompt_2=None, device=device
            )
        
        # Inject the custom T5 embedding through prompt_embeds
        print("Running diffusion process (4 steps)...")
        image = flux_pipe(
            prompt_embeds=embedding_tensor,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=4,
            guidance_scale=0.0,
            height=1024,
            width=1024,
            generator=torch.manual_seed(seed)
        ).images[0]
        
        image.save(output_filename)
        print(f"\n✓ Image generated and saved to '{output_filename}'")
        
        return image
        
    except Exception as e:
        print(f"❌ Error generating image: {e}")
        print("\nThis might happen if:")
        print("  - Embedding dimensions don't match (should be [1, 512, 4096])")
        print("  - The installed diffusers version doesn't support prompt_embeds/pooled_prompt_embeds")
        print("  - Insufficient VRAM/RAM")
        return None

# Generate from current embedding button
generate_from_embedding_button = widgets.Button(
    description='Generate from Current Embedding',
    button_style='primary',
    layout=widgets.Layout(width='300px')
)

# Generate from modified embedding button
generate_from_modified_button = widgets.Button(
    description='Generate from Modified Embedding',
    button_style='warning',
    layout=widgets.Layout(width='300px')
)

generation_output = widgets.Output()

def on_generate_from_embedding(b):
    with generation_output:
        generation_output.clear_output(wait=True)
        if current_embedding is not None:
            image = generate_image_from_embedding(
                current_embedding,
                output_filename="image_original_embedding.png"
            )
            if image:
                display(image)
        else:
            print("❌ No embedding loaded! Generate or load an embedding first.")

def on_generate_from_modified(b):
    with generation_output:
        generation_output.clear_output(wait=True)
        if modified_embedding is not None:
            image = generate_image_from_embedding(
                modified_embedding,
                output_filename="image_modified_embedding.png"
            )
            if image:
                display(image)
        else:
            print("❌ No modified embedding! Apply a manipulation first.")

generate_from_embedding_button.on_click(on_generate_from_embedding)
generate_from_modified_button.on_click(on_generate_from_modified)

display(widgets.VBox([
    widgets.HTML("<h3>Generate Images from Embeddings</h3>"),
    generate_from_embedding_button,
    generate_from_modified_button
]), generation_output)

## 10. Compare Original vs Modified Embeddings

Generate both images and view them side-by-side to see the effect of your modifications.
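The statistics printed by `compare_embeddings()` can also be computed by a pure function, which makes them easy to reuse or log for several manipulations. A sketch assuming two arrays of the same `[tokens, dim]` shape; it works in float64 so sums over ~2 million values cannot overflow, and it adds the mean per-token cosine similarity as one extra measure that the original cell does not print. The function name is hypothetical.

```python
import numpy as np

def embedding_diff_stats(original, modified):
    """Summarize how `modified` differs from `original` (same shape, [tokens, dim])."""
    a = np.asarray(original, dtype=np.float64)
    b = np.asarray(modified, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")

    # Per-token cosine similarity; rows where either norm is zero get 0
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    dots = np.sum(a * b, axis=1)
    cos = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    return {
        "total_abs_diff": float(np.abs(a - b).sum()),
        "fraction_changed": float(np.count_nonzero(a != b) / a.size),
        "mean_token_cosine": float(cos.mean()),
    }

orig = np.array([[1.0, 2.0], [3.0, 4.0]])
mod = np.array([[1.0, 0.0], [3.0, 4.0]])
stats = embedding_diff_stats(orig, mod)
```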

In [None]:
def compare_embeddings(seed=42):
    """
    Generate images from both original and modified embeddings and display side-by-side.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return
    
    if current_embedding is None:
        print("❌ No original embedding! Generate or load an embedding first.")
        return
    
    if modified_embedding is None:
        print("❌ No modified embedding! Apply a manipulation first.")
        return
    
    print("="*60)
    print("EMBEDDING COMPARISON")
    print("="*60)
    
    # Show embedding statistics
    print("\nORIGINAL EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(current_embedding):,}")
    print(f"  Mean: {current_embedding.mean():.4f}")
    print(f"  Std: {current_embedding.std():.4f}")
    print(f"  Min: {current_embedding.min():.4f}")
    print(f"  Max: {current_embedding.max():.4f}")
    
    print("\nMODIFIED EMBEDDING:")
    print(f"  Non-zero values: {np.count_nonzero(modified_embedding):,}")
    print(f"  Mean: {modified_embedding.mean():.4f}")
    print(f"  Std: {modified_embedding.std():.4f}")
    print(f"  Min: {modified_embedding.min():.4f}")
    print(f"  Max: {modified_embedding.max():.4f}")
    
    diff = np.abs(current_embedding - modified_embedding).sum()
    print(f"\nTOTAL ABSOLUTE DIFFERENCE: {diff:,.2f}")
    print(f"Percentage changed: {(np.count_nonzero(current_embedding != modified_embedding) / current_embedding.size * 100):.2f}%")
    
    print("\n" + "="*60)
    print("GENERATING IMAGES...")
    print("="*60 + "\n")
    
    # Generate from original
    print("[1/2] Generating from ORIGINAL embedding...")
    img1 = generate_image_from_embedding(
        current_embedding,
        output_filename="comparison_original.png",
        seed=seed
    )
    
    if img1 is None:
        return
    
    print("\n[2/2] Generating from MODIFIED embedding...")
    img2 = generate_image_from_embedding(
        modified_embedding,
        output_filename="comparison_modified.png",
        seed=seed
    )
    
    if img2 is None:
        return
    
    # Display side by side
    print("\n" + "="*60)
    print("COMPARISON RESULTS")
    print("="*60 + "\n")
    
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    
    axes[0].imshow(img1)
    axes[0].set_title('Original Embedding', fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    axes[1].imshow(img2)
    axes[1].set_title('Modified Embedding', fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    plt.suptitle(f'Prompt: "{prompt_input.value}"', fontsize=12, y=0.98)
    plt.tight_layout()
    plt.savefig("comparison_sidebyside.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print("✓ Comparison complete!")
    print("\nFiles saved:")
    print("  - comparison_original.png")
    print("  - comparison_modified.png")
    print("  - comparison_sidebyside.png")

# Comparison button
compare_button = widgets.Button(
    description='Compare Original vs Modified',
    button_style='success',
    layout=widgets.Layout(width='300px', height='50px')
)

compare_output = widgets.Output()

def on_compare_click(b):
    with compare_output:
        compare_output.clear_output(wait=True)
        compare_embeddings()

compare_button.on_click(on_compare_click)
display(widgets.VBox([
    widgets.HTML("<h3>⚡ Generate Comparison</h3><p>This will generate images from both embeddings with the same random seed.</p>"),
    compare_button
]), compare_output)

## 11. Summary

What you've learned:

1. ✓ Generate T5-XXL embeddings from text (512 tokens × 4096 dimensions = 2,097,152 numbers)
2. ✓ Save/load embeddings as JSON files
3. ✓ Manually manipulate embedding values
4. ✓ Create custom attention masks with sliders
5. ✓ Generate images directly from custom embeddings using FLUX
6. ✓ Compare original vs modified embeddings side-by-side
7. ✓ Understand the complete pipeline: text → T5-XXL embeddings (plus a pooled CLIP embedding) → FLUX → image

## Workflow:

1. **Generate embedding** from your text prompt
2. **Save to JSON** for backup
3. **Apply manipulations** (zero out values, etc.)
4. **Create attention masks** (optional; the mask is not yet applied during generation)
5. **Generate images** from both original and modified embeddings
6. **Compare results** to see what your changes did!

## Next Steps:

- Experiment with different zero percentages (10%, 30%, 50%, 90%)
- Try zeroing specific token embeddings instead of random values
- Compare results with different prompts
- Implement attention mask injection (requires changes to FLUX's joint text-image attention blocks)
- Explore interpolating between two different embeddings
- Try adding noise to embeddings and see the effect
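Three of these next steps can be sketched as small NumPy helpers that take and return `[512, 4096]` arrays, so their output can go straight into `generate_image_from_embedding()`. These are minimal sketches under our own assumptions (whole-row zeroing, linear interpolation, Gaussian noise); the function names are hypothetical.

```python
import numpy as np

def zero_tokens(embedding, token_indices):
    """Return a copy with whole token rows (e.g. one word's tokens) set to zero."""
    out = embedding.copy()
    out[list(token_indices)] = 0.0
    return out

def interpolate(emb_a, emb_b, t):
    """Linear interpolation: t=0 gives emb_a, t=1 gives emb_b."""
    if emb_a.shape != emb_b.shape:
        raise ValueError("embeddings must have the same shape")
    return (1.0 - t) * emb_a + t * emb_b

def add_noise(embedding, scale, seed=0):
    """Add Gaussian noise with standard deviation `scale`, reproducible via `seed`."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, scale, size=embedding.shape)
    return embedding + noise.astype(embedding.dtype)
```

For example, `zero_tokens(current_embedding, [1, 2])` blanks the second and third token positions, and `interpolate(emb_a, emb_b, 0.5)` blends two prompts' embeddings halfway. Linear interpolation is the simplest choice; spherical interpolation (slerp) is a common alternative when you want intermediate vectors to stay closer to the original norms.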