# T5 Embedding Manipulation for FLUX

This notebook lets you:
1. Generate T5 embeddings from text prompts
2. Save/load embeddings as JSON
3. Manually edit embedding values
4. Build custom attention masks (per-token weights)
5. Generate images with FLUX and see how modified embeddings could be fed into it

## Installation

Run this cell first to install required packages:

In [None]:
%pip install transformers torch diffusers accelerate sentencepiece protobuf ipywidgets numpy pillow --break-system-packages

## Setup and Imports

In [None]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from PIL import Image
import os

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Load T5 Model (Small Version)

We'll use `google/t5-v1_1-small`, which is much smaller than T5-XXL and convenient for experimentation.

**Note:** FLUX officially uses T5-XXL. T5-small is fine for learning how embeddings work, but its embedding dimension is 512 instead of 4096, so the embeddings generated here cannot be fed to FLUX directly.
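
The difference in embedding width matters for both memory and compatibility. A quick size calculation, assuming the 512-token padding used in this notebook and 4-byte float32 values (`embedding_size` is an illustrative helper, not part of any library):

```python
# Rough size of one T5 prompt embedding: tokens x hidden size, stored as float32.
# Assumptions: 512-token padding (as in this notebook) and 4 bytes per float32 value.
SEQ_LEN = 512
BYTES_PER_FLOAT32 = 4


def embedding_size(seq_len: int, hidden_dim: int) -> tuple:
    """Return (number of values, size in MiB) for a [seq_len, hidden_dim] float32 array."""
    num_values = seq_len * hidden_dim
    size_mib = num_values * BYTES_PER_FLOAT32 / (1024 * 1024)
    return num_values, size_mib


for name, dim in [("t5-v1_1-small", 512), ("t5-v1_1-xxl", 4096)]:
    n, mib = embedding_size(SEQ_LEN, dim)
    print(f"{name}: {n:,} values, {mib:.1f} MiB as float32")
```

The 8x wider T5-XXL embedding is what FLUX expects, which is why the smaller one cannot simply be swapped in.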

In [None]:
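# Load T5-small (for experimentation - faster and smaller)
# FLUX itself uses the much larger T5-XXL encoder ("google/t5-v1_1-xxl", embedding_dim=4096),
# a multi-GB download; the FLUX.1 repos also ship it in their "text_encoder_2" subfolder
t5_model_name = "google/t5-v1_1-small"  # embedding_dim=512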

print(f"Loading T5 model: {t5_model_name}...")
tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_model = T5EncoderModel.from_pretrained(t5_model_name).to(device)
t5_model.eval()  # Set to evaluation mode

print(f"T5 loaded! Embedding dimension: {t5_model.config.d_model}")
print(f"Max sequence length: {tokenizer.model_max_length}")

## 2. Text Input Widget and Embedding Generation

In [None]:
# Create text input widget
prompt_input = widgets.Textarea(
    value='a red cat sitting on a blue table',
    placeholder='Enter your prompt here',
    description='Prompt:',
    layout=widgets.Layout(width='80%', height='80px')
)

generate_button = widgets.Button(
    description='Generate Embedding',
    button_style='success'
)

output_area = widgets.Output()

# Global variable to store current embedding
current_embedding = None
current_tokens = None

def generate_embedding(b):
    global current_embedding, current_tokens
    
    with output_area:
        output_area.clear_output()
        
        prompt = prompt_input.value
        print(f"Generating embedding for: '{prompt}'\n")
        
        # Tokenize
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt"
        )
        
        # Get token strings for display
        token_ids = tokens['input_ids'][0].tolist()
        token_strings = [tokenizer.decode([tid]) for tid in token_ids]
        
        # Count real tokens (non-padding; this includes the </s> end-of-sequence token)
        num_real_tokens = (tokens['input_ids'][0] != tokenizer.pad_token_id).sum().item()
        
        print(f"Tokenized into {num_real_tokens} real tokens (+ {512 - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()
        
        # Generate embedding
        with torch.no_grad():
            tokens = {k: v.to(device) for k, v in tokens.items()}
            outputs = t5_model(**tokens)
            embedding = outputs.last_hidden_state  # Shape: [1, 512, embedding_dim]
        
        current_embedding = embedding.cpu().numpy()[0]  # Shape: [512, embedding_dim]
        current_tokens = token_strings
        
        embedding_dim = current_embedding.shape[1]
        total_numbers = current_embedding.shape[0] * current_embedding.shape[1]
        
        print(f"✓ Embedding generated!")
        print(f"  Shape: {current_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_embedding[0, :10])

generate_button.on_click(generate_embedding)

display(prompt_input, generate_button, output_area)
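
The `num_real_tokens` count above treats every non-pad id as real, so it also counts the closing `</s>` token. The same logic in plain NumPy, with made-up token ids and assuming T5's pad id of 0, also shows the form of the tokenizer's own `attention_mask` (which the cell passes to `t5_model` via `**tokens`): a binary array with 1 for real tokens and 0 for padding.

```python
import numpy as np

# Assumption: T5 tokenizers use pad id 0 and EOS (</s>) id 1.
# The ids below are hypothetical: four word pieces, </s>, then three pads.
PAD_ID = 0
token_ids = np.array([3, 9, 1131, 1712, 1, 0, 0, 0])

# Binary mask: 1 = real token (including </s>), 0 = padding
attention_mask = (token_ids != PAD_ID).astype(np.int64)
num_real_tokens = int(attention_mask.sum())

print("mask:", attention_mask.tolist())
print("real tokens:", num_real_tokens)
```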

## 3. Save Embedding to JSON

In [None]:
def save_embedding(filename="embedding.json"):
    if current_embedding is None:
        print("❌ No embedding to save! Generate one first.")
        return
    
    data = {
        "embedding": current_embedding.tolist(),
        "tokens": current_tokens,
        "shape": list(current_embedding.shape),
        "prompt": prompt_input.value
    }
    
    with open(filename, 'w') as f:
        json.dump(data, f)
    
    file_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"✓ Embedding saved to '{filename}' ({file_size:.2f} MB)")

# Save button
save_button = widgets.Button(description='Save Embedding', button_style='info')
save_output = widgets.Output()

def on_save_click(b):
    with save_output:
        save_output.clear_output()
        save_embedding()

save_button.on_click(on_save_click)
display(save_button, save_output)
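
JSON keeps the file human-readable, but it stores every value as decimal text and drops the dtype. A sketch comparing it with NumPy's binary `.npy` format on random stand-in data (not a real T5 embedding), and checking that casting back to float32 on load reproduces the values exactly:

```python
import io
import json

import numpy as np

# Random stand-in for current_embedding (same shape and dtype as the T5-small output)
rng = np.random.default_rng(0)
embedding = rng.standard_normal((512, 512)).astype(np.float32)

# JSON: float32 values become Python floats; casting back to float32 recovers them exactly
json_text = json.dumps({"embedding": embedding.tolist(), "shape": list(embedding.shape)})
restored = np.array(json.loads(json_text)["embedding"], dtype=np.float32)

# Binary .npy for comparison: stores dtype and shape alongside the raw bytes
buf = io.BytesIO()
np.save(buf, embedding)

print(f"JSON: {len(json_text) / 1024:.0f} KiB, .npy: {buf.getbuffer().nbytes / 1024:.0f} KiB")
print("exact round trip:", np.array_equal(restored, embedding))
```

JSON is convenient for hand editing; `.npy` is smaller and faster when you only need to reload the array.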

## 4. Load Embedding from JSON

In [None]:
def load_embedding(filename="embedding.json"):
    global current_embedding, current_tokens
    
    with open(filename, 'r') as f:
        data = json.load(f)
    
    current_embedding = np.array(data['embedding'], dtype=np.float32)  # JSON loads as float64; restore float32
    current_tokens = data['tokens']
    
    print(f"✓ Embedding loaded from '{filename}'")
    print(f"  Original prompt: {data['prompt']}")
    print(f"  Shape: {current_embedding.shape}")
    print(f"  First token: '{current_tokens[0]}'")

# Load button
load_button = widgets.Button(description='Load Embedding', button_style='warning')
load_output = widgets.Output()

def on_load_click(b):
    with load_output:
        load_output.clear_output()
        try:
            load_embedding()
        except FileNotFoundError:
            print("❌ File 'embedding.json' not found. Save an embedding first.")

load_button.on_click(on_load_click)
display(load_button, load_output)
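
`load_embedding()` trusts the file's contents. A small consistency check can catch a hand-edited JSON whose values no longer match its stored shape or token list. `validate_embedding_data` is a hypothetical helper that assumes the key layout written by `save_embedding()`:

```python
import numpy as np


def validate_embedding_data(data: dict) -> np.ndarray:
    """Check a loaded embedding dict for consistency and return it as a float32 array.

    Assumes the keys written by save_embedding(): "embedding", "tokens", "shape".
    """
    embedding = np.asarray(data["embedding"], dtype=np.float32)
    if list(embedding.shape) != list(data["shape"]):
        raise ValueError(f"stored shape {data['shape']} != actual shape {list(embedding.shape)}")
    if len(data["tokens"]) != embedding.shape[0]:
        raise ValueError(f"{len(data['tokens'])} tokens but {embedding.shape[0]} embedding rows")
    return embedding


# Tiny example dict (2 tokens x 2 dims)
example = {"embedding": [[0.0, 1.0], [2.0, 3.0]], "tokens": ["a", "</s>"], "shape": [2, 2]}
print(validate_embedding_data(example).shape)
```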

## 5. Manual Embedding Manipulation

Example: randomly zero out a fraction of the values. The result is stored in `modified_embedding`; `current_embedding` is left unchanged.

In [None]:
def manipulate_embedding(zero_percentage=0.3):
    """Return a copy of current_embedding with a random fraction of its values set to 0."""
    
    if current_embedding is None:
        print("❌ No embedding loaded!")
        return
    
    # Create a copy
    modified = current_embedding.copy()
    
    # Randomly zero out a percentage of values
    total_values = modified.size
    num_zeros = int(total_values * zero_percentage)
    
    # Random indices
    flat_modified = modified.flatten()
    zero_indices = np.random.choice(total_values, num_zeros, replace=False)
    flat_modified[zero_indices] = 0.0
    
    modified = flat_modified.reshape(current_embedding.shape)
    
    print(f"✓ Zeroed out {num_zeros:,} values ({zero_percentage:.0%})")
    print(f"  Original non-zero: {np.count_nonzero(current_embedding):,}")
    print(f"  Modified non-zero: {np.count_nonzero(modified):,}")
    
    return modified

# Manipulation widget
zero_slider = widgets.FloatSlider(
    value=0.3,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Zero %:',
    readout_format='.0%'
)

manipulate_button = widgets.Button(description='Apply Manipulation', button_style='danger')
manipulate_output = widgets.Output()

modified_embedding = None

def on_manipulate_click(b):
    global modified_embedding
    with manipulate_output:
        manipulate_output.clear_output()
        modified_embedding = manipulate_embedding(zero_slider.value)

manipulate_button.on_click(on_manipulate_click)
display(zero_slider, manipulate_button, manipulate_output)
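
`manipulate_embedding()` draws new random positions on every click, so two runs with the same slider value zero different entries. A sketch of a seedable variant (the `seed` parameter and both helper names are illustrative), plus a structured alternative that zeroes whole token rows, for example to remove a single word's vector:

```python
from typing import Optional

import numpy as np


def zero_random_values(embedding: np.ndarray, fraction: float, seed: Optional[int] = None) -> np.ndarray:
    """Return a copy of `embedding` with `fraction` of its entries set to 0 (reproducible via seed)."""
    rng = np.random.default_rng(seed)
    modified = embedding.copy()
    num_zeros = int(modified.size * fraction)
    flat_idx = rng.choice(modified.size, size=num_zeros, replace=False)
    modified.flat[flat_idx] = 0.0  # index into the flattened view of the copy
    return modified


def zero_token_rows(embedding: np.ndarray, token_indices) -> np.ndarray:
    """Return a copy with entire token vectors (rows) set to 0."""
    modified = embedding.copy()
    modified[list(token_indices), :] = 0.0
    return modified


emb = np.ones((8, 4), dtype=np.float32)
print(np.count_nonzero(zero_random_values(emb, 0.25, seed=1) == 0))
print(zero_token_rows(emb, [0, 2])[:3])
```

Zeroing rows keeps the other tokens intact, which makes its effect easier to attribute than scattered element-wise zeros.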

## 6. Manual Attention Masking

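Assign a weight to each of the first few tokens. These are per-token weights meant to sum to 100%; they are not the binary 0/1 attention mask T5 uses to ignore padding. Run this cell after generating or loading an embedding, since it reads `current_tokens` when it runs.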

In [None]:
# Display tokens and create sliders
def create_attention_sliders(num_tokens=10):
    if current_tokens is None:
        print("❌ Generate an embedding first!")
        return None
    
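    # Keep real tokens only: skip padding, </s>, and other special tokens
    real_tokens = [
        (i, token)
        for i, token in enumerate(current_tokens[:num_tokens])
        if token.strip() and token not in tokenizer.all_special_tokens
    ]
    if not real_tokens:
        print(f"❌ No real tokens among the first {num_tokens} positions!")
        return None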
    
    print(f"Creating sliders for first {len(real_tokens)} tokens:")
    print(real_tokens)
    print()
    
    sliders = []
    for idx, token in real_tokens:
        slider = widgets.FloatSlider(
            value=1.0 / len(real_tokens),  # Equal distribution
            min=0.0,
            max=1.0,
            step=0.01,
            description=f'{idx}: {token[:15]}',
            readout_format='.0%',
            layout=widgets.Layout(width='400px')
        )
        sliders.append(slider)
    
    # Normalize button
    normalize_button = widgets.Button(
        description='Normalize to 100%',
        button_style='info'
    )
    
    total_label = widgets.Label(value='Total: 100%')
    
    def update_total(*args):
        total = sum(s.value for s in sliders)
        total_label.value = f'Total: {total*100:.1f}%'
        if abs(total - 1.0) > 0.01:
            total_label.value += ' ⚠️ Should sum to 100%'
    
    def normalize(*args):
        total = sum(s.value for s in sliders)
        if total > 0:
            for s in sliders:
                s.value = s.value / total
        update_total()
    
    for slider in sliders:
        slider.observe(update_total, 'value')
    
    normalize_button.on_click(normalize)
    
    update_total()
    
    return sliders, normalize_button, total_label, real_tokens

attention_controls = create_attention_sliders(num_tokens=10)

if attention_controls:
    sliders, normalize_btn, total_lbl, real_tokens = attention_controls
    display(widgets.VBox(sliders + [normalize_btn, total_lbl]))
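
The **Normalize to 100%** button divides each slider by the current total. The same arithmetic as a plain function (a hypothetical helper, useful for setting weights from code instead of sliders):

```python
def normalize_weights(weights):
    """Rescale non-negative weights so they sum to 1, like the 'Normalize to 100%' button."""
    total = sum(weights)
    if total <= 0:
        return list(weights)  # nothing to rescale; the button also leaves values unchanged
    return [w / total for w in weights]


# Three sliders at 50%, 30%, 40% (sum 120%) -> proportions kept, sum becomes 100%
print(normalize_weights([0.5, 0.3, 0.4]))
```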

## 7. Create Attention Mask Array

In [None]:
def create_attention_mask():
    if attention_controls is None:
        print("❌ Create attention sliders first!")
        return None
    
    sliders, _, _, real_tokens = attention_controls
    
    # One weight per token position (512 with the padding used in section 2)
    mask = np.zeros(len(current_tokens))
    
    # Set weights from sliders
    for slider, (idx, token) in zip(sliders, real_tokens):
        mask[idx] = slider.value
    
    print("Attention Mask created:")
    print(f"Non-zero weights: {np.count_nonzero(mask)}")
    for slider, (idx, token) in zip(sliders, real_tokens):
        print(f"  Token {idx} '{token}': {slider.value*100:.1f}%")
    
    return mask

mask_button = widgets.Button(description='Create Mask', button_style='success')
mask_output = widgets.Output()

attention_mask = None

def on_mask_click(b):
    global attention_mask
    with mask_output:
        mask_output.clear_output()
        attention_mask = create_attention_mask()

mask_button.on_click(on_mask_click)
display(mask_button, mask_output)
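
The array built above holds per-token weights, whereas the `attention_mask` the tokenizer returns (and T5 consumes) is binary: 1 for real tokens, 0 for padding. One hypothetical way to use the weights without touching FLUX internals is to scale each token's embedding row, normalized so that an equal split leaves the selected tokens unchanged. This is an illustrative stand-in, not real attention masking:

```python
import numpy as np


def apply_token_weights(embedding: np.ndarray, weights) -> np.ndarray:
    """Scale each row (token) of a [seq_len, dim] embedding by its weight.

    Weights are multiplied by the number of non-zero entries, so an equal split
    (1/N each) gives a scale of 1.0, and unselected tokens (weight 0) are zeroed.
    """
    weights = np.asarray(weights, dtype=embedding.dtype)
    if weights.shape != (embedding.shape[0],):
        raise ValueError(f"need one weight per token: {weights.shape} vs {embedding.shape[0]} rows")
    scale = weights * int(np.count_nonzero(weights))
    return embedding * scale[:, None]


def binary_mask(weights) -> np.ndarray:
    """The standard 0/1 form: attend to a token (1) or ignore it (0)."""
    return (np.asarray(weights) > 0).astype(np.int64)


emb = np.ones((4, 3), dtype=np.float32)
print(apply_token_weights(emb, [0.75, 0.25, 0.0, 0.0])[:, 0])
print(binary_mask([0.75, 0.25, 0.0, 0.0]))
```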

## 8. Load FLUX Model

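**Warning:** FLUX.1-schnell is large: its 12B-parameter transformer alone is about 24 GB in bfloat16, and the pipeline also downloads the T5-XXL and CLIP text encoders. Loading it takes time and requires significant VRAM/RAM.

If you don't have enough resources, skip this section; the embedding and mask experiments above work without FLUX.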

In [None]:
# Load FLUX.1-schnell (the 4-step distilled variant: same 12B size as FLUX.1-dev, but much faster to sample)
print("Loading FLUX.1-schnell... This will take several minutes and a download of well over 24 GB.")
print("If you don't have enough resources, you can skip this cell.")

try:
    flux_pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16
    )
    # With limited VRAM, flux_pipe.enable_model_cpu_offload() can be used instead of .to(device)
    flux_pipe = flux_pipe.to(device)
    print("✓ FLUX loaded successfully!")
except Exception as e:
    print(f"❌ Error loading FLUX: {e}")
    print("You may need to:")
    print("1. Accept the license at https://huggingface.co/black-forest-labs/FLUX.1-schnell")
    print("2. Login with: huggingface-cli login")
    print("3. Have enough disk space (tens of GB) and RAM/VRAM")
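
The 24 GB figure follows from parameter count times bytes per parameter. A back-of-the-envelope helper, assuming the ~12B-parameter FLUX.1 transformer and counting weights only (activations, the text encoders, and the VAE come on top):

```python
# Weights-only memory estimate: parameters x bytes per parameter.
# Assumption: ~12e9 parameters for the FLUX.1 transformer; other components are extra.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate size of the weights in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"12B parameters in {dtype}: ~{weight_memory_gb(12e9, dtype):.0f} GB")
```

This is why the cell above loads with `torch_dtype=torch.bfloat16`: float32 would double the footprint.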

## 9. Generate Image with Custom Embedding and Mask

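**Note:** This cell generates from the text prompt only; the modified embeddings and attention weights above are not used. `FluxPipeline` can take precomputed embeddings via its `prompt_embeds` and `pooled_prompt_embeds` arguments, but they must come from FLUX's own T5-XXL and CLIP encoders, so the 512-dim T5-small embeddings won't fit.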

In [None]:
def generate_image_simple(prompt_text, use_mask=False):
    """
    Simple generation using text prompt.
    For custom embeddings, you'd need to modify the pipeline.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None
    
    print(f"Generating image for: '{prompt_text}'")
    
    if use_mask and attention_mask is not None:
        print("Note: Custom masking requires pipeline modification.")
        print("Generating with standard pipeline for now...")
    
    image = flux_pipe(
        prompt=prompt_text,
        num_inference_steps=4,  # Schnell is optimized for 4 steps
        guidance_scale=0.0,  # Schnell doesn't use guidance
        max_sequence_length=256,  # value used in the FLUX.1-schnell model card example
    ).images[0]
    
    output_path = "generated_image.png"
    image.save(output_path)
    print(f"✓ Image saved to '{output_path}'")
    
    return image

# Generate button
generate_image_button = widgets.Button(
    description='Generate Image',
    button_style='primary'
)
image_output = widgets.Output()

def on_generate_image_click(b):
    with image_output:
        image_output.clear_output(wait=True)
        image = generate_image_simple(prompt_input.value, use_mask=False)
        if image is not None:
            display(image)

generate_image_button.on_click(on_generate_image_click)
display(generate_image_button, image_output)

## 10. Advanced: Custom Pipeline Integration (Template)

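To use custom embeddings, you can bypass FLUX's own text encoding by passing precomputed tensors via `prompt_embeds` and `pooled_prompt_embeds`. They must match FLUX's encoders (T5-XXL: 4096 features per token), so they have to be produced by `flux_pipe` itself rather than by T5-small. True per-token attention masking is harder: it requires custom attention processors on `flux_pipe.transformer`. Here's a template: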

In [None]:
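# Template: generate with FLUX from (optionally edited) embeddings instead of a raw prompt.
# Assumes flux_pipe is loaded (section 8). Edits operate on FLUX's own T5-XXL embeddings
# (4096 features per token); T5-small embeddings (512 features) will not fit.

@torch.no_grad()
def generate_with_custom_embedding(prompt, edit_fn=None, token_weights=None, max_len=256):
    # 1. Encode the prompt with FLUX's own text encoders (T5-XXL + CLIP)
    prompt_embeds, pooled_prompt_embeds, _ = flux_pipe.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=max_len
    )
    # prompt_embeds: [1, max_len, 4096]; pooled_prompt_embeds: [1, 768]

    # 2. Optionally edit the T5 embedding with a NumPy function of a [max_len, 4096] array
    if edit_fn is not None:
        edited = edit_fn(prompt_embeds[0].float().cpu().numpy())
        prompt_embeds = torch.from_numpy(edited).unsqueeze(0).to(
            device=prompt_embeds.device, dtype=prompt_embeds.dtype
        )

    # 3. Crude stand-in for per-token attention: scale each token's row by its weight
    #    (token_weights needs max_len entries). Real masking would need custom attention
    #    processors, since FLUX uses joint text-image attention rather than cross-attention.
    if token_weights is not None:
        w = torch.as_tensor(token_weights, device=prompt_embeds.device, dtype=prompt_embeds.dtype)
        prompt_embeds = prompt_embeds * w[None, :, None]

    # 4. Generate; passing prompt_embeds skips the pipeline's own text encoding
    return flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=max_len,
    ).images[0]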

print("Custom T5 embeddings can be passed via prompt_embeds;")
print("true attention masking requires custom attention processors. Consider exploring:")
print("1. diffusers library source code (FluxPipeline, FluxTransformer2DModel)")
print("2. Custom pipeline and attention processor examples in the diffusers documentation")
print("3. Attention manipulation techniques (Prompt-to-Prompt, etc.)")
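
Before passing an edited array as `prompt_embeds`, it helps to check its shape against what FLUX.1 expects. A sketch, assuming FLUX.1's transformer takes T5-XXL hidden states with 4096 features per token (its `joint_attention_dim`); `check_flux_prompt_embeds` is a hypothetical helper for the 2-D `[seq_len, hidden_dim]` arrays used in this notebook:

```python
import numpy as np

# Assumption: FLUX.1 expects T5-XXL hidden states with 4096 features per token.
FLUX_T5_DIM = 4096


def check_flux_prompt_embeds(embedding: np.ndarray) -> None:
    """Raise ValueError if a [seq_len, hidden_dim] array can't serve as FLUX prompt_embeds."""
    if embedding.ndim != 2:
        raise ValueError(f"expected a 2-D [seq_len, hidden_dim] array, got shape {embedding.shape}")
    hidden_dim = embedding.shape[1]
    if hidden_dim != FLUX_T5_DIM:
        raise ValueError(
            f"hidden size {hidden_dim} != {FLUX_T5_DIM}: this looks like a different T5 "
            "(t5-v1_1-small gives 512), so FLUX cannot use it"
        )


check_flux_prompt_embeds(np.zeros((256, 4096), dtype=np.float32))  # passes silently
try:
    check_flux_prompt_embeds(np.zeros((512, 512), dtype=np.float32))  # T5-small shape
except ValueError as err:
    print(err)
```

Adding the batch dimension and converting to a torch tensor would come after this check.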

## Summary

What you've learned:

1. ✓ Generate T5 embeddings from text (512 tokens × 512 dimensions = 262,144 numbers)
2. ✓ Save/load embeddings as JSON files
3. ✓ Manually manipulate embedding values
4. ✓ Create custom per-token attention weights with sliders
5. ✓ Understand the pipeline: text → T5 → embeddings → FLUX → image

Next steps:
- Experiment with zeroing different percentages of values
- Try extreme token weightings (90% on one token)
- Explore the diffusers library to implement custom pipeline integration
- Compare images with/without modifications
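
To compare modifications numerically as well as visually, a per-token cosine similarity between the original and modified embeddings is a simple measure. A sketch (`per_token_cosine` is an illustrative helper):

```python
import numpy as np


def per_token_cosine(original: np.ndarray, modified: np.ndarray) -> np.ndarray:
    """Cosine similarity between matching token rows of two [seq_len, dim] embeddings."""
    dot = np.sum(original * modified, axis=1)
    norms = np.linalg.norm(original, axis=1) * np.linalg.norm(modified, axis=1)
    # Rows that are entirely zero (e.g. zeroed tokens) get similarity 0 instead of NaN
    return np.divide(dot, norms, out=np.zeros_like(dot), where=norms > 0)


a = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
b = np.array([[2.0, 0.0], [0.0, 0.0], [-3.0, -4.0]])
print(per_token_cosine(a, b))
```

For example, `per_token_cosine(current_embedding, modified_embedding).mean()` summarizes how far a manipulation moved the embedding overall.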