# T5 Embedding Manipulation for FLUX

This notebook lets you:
1. Generate T5 embeddings from text prompts
2. Save/load embeddings as JSON
3. Manually edit embedding values
4. Apply custom attention masks
5. Generate images with FLUX using modified embeddings

## Installation

Run this cell first to install required packages:

In [13]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from PIL import Image
import os
from pathlib import Path

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Create models directory
current_dir = Path.cwd()
MODELS_DIR = current_dir.parent / "data/models"
T5_MODEL_PATH = os.path.join(MODELS_DIR, "t5-v1_1-xxl")
CLIP_MODEL_PATH = os.path.join(MODELS_DIR, "clip")
FLUX_MODEL_PATH = os.path.join(MODELS_DIR, "FLUX.1-schnell")

os.makedirs(MODELS_DIR, exist_ok=True)
print(f"Models directory: {os.path.abspath(MODELS_DIR)}")
print(f"T5 path: {os.path.abspath(T5_MODEL_PATH)}")
print(f"FLUX path: {os.path.abspath(FLUX_MODEL_PATH)}")

Using device: cuda
Models directory: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/models
T5 path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/models/t5-v1_1-xxl
FLUX path: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/models/FLUX.1-schnell


In [14]:
# Load Hugging Face token from file
from pathlib import Path

# Get the token file path
current_dir = Path.cwd()
token_file = current_dir.parent / "misc/credentials/hf.txt"

print(f"Looking for HF token at: {token_file}")

if token_file.exists():
    with open(token_file, 'r') as f:
        hf_token = f.read().strip()
    
    # Set the token as an environment variable
    os.environ['HF_TOKEN'] = hf_token
    
    # Also login using huggingface_hub
    from huggingface_hub import login
    login(token=hf_token)
    
    print("✓ Hugging Face token loaded and authenticated successfully!")
else:
    print(f"❌ Token file not found at: {token_file}")
    print("Please create the file or update the path.")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Looking for HF token at: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/misc/credentials/hf.txt
✓ Hugging Face token loaded and authenticated successfully!


In [15]:
MODELS_DIR = current_dir.parent / "data/models"

In [16]:
# Load FLUX
try:
    if not os.path.exists(FLUX_MODEL_PATH):
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16
        )
        flux_pipe.save_pretrained(FLUX_MODEL_PATH)
    else:
        flux_pipe = FluxPipeline.from_pretrained(
            FLUX_MODEL_PATH,
            torch_dtype=torch.bfloat16,
            local_files_only=True
        )
    
    flux_pipe = flux_pipe.to(device)
    
except Exception as e:
    print(f"Error loading FLUX: {e}")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## 6. Manual Attention Masking

Control which tokens get how much attention

In [21]:
# Display tokens and create sliders
def create_attention_sliders(num_tokens=10):
    if current_tokens is None:
        print("❌ Generate an embedding first!")
        return None
    
    # Find non-padding tokens
    real_tokens = []
    for i, token in enumerate(current_tokens):
        if token.strip() and token != '<pad>' and i < num_tokens:
            real_tokens.append((i, token))
    
    print(f"Creating sliders for first {len(real_tokens)} tokens:")
    print(real_tokens)
    print()
    
    sliders = []
    for idx, token in real_tokens:
        slider = widgets.FloatSlider(
            value=1.0 / len(real_tokens),  # Equal distribution
            min=0.0,
            max=1.0,
            step=0.01,
            description=f'{idx}: {token[:15]}',
            readout_format='.0%',
            layout=widgets.Layout(width='400px')
        )
        sliders.append(slider)
    
    # Normalize button
    normalize_button = widgets.Button(
        description='Normalize to 100%',
        button_style='info'
    )
    
    total_label = widgets.Label(value='Total: 100%')
    
    def update_total(*args):
        total = sum(s.value for s in sliders)
        total_label.value = f'Total: {total*100:.1f}%'
        if abs(total - 1.0) > 0.01:
            total_label.value += ' ⚠️ Should sum to 100%'
    
    def normalize(*args):
        total = sum(s.value for s in sliders)
        if total > 0:
            for s in sliders:
                s.value = s.value / total
        update_total()
    
    for slider in sliders:
        slider.observe(update_total, 'value')
    
    normalize_button.on_click(normalize)
    
    update_total()
    
    return sliders, normalize_button, total_label, real_tokens

attention_controls = create_attention_sliders(num_tokens=10)

if attention_controls:
    sliders, normalize_btn, total_lbl, real_tokens = attention_controls
    display(widgets.VBox(sliders + [normalize_btn, total_lbl]))

NameError: name 'current_tokens' is not defined

## 7. Create Attention Mask Array

In [None]:
def create_attention_mask():
    if attention_controls is None:
        print("❌ Create attention sliders first!")
        return None
    
    sliders, _, _, real_tokens = attention_controls
    
    # Create mask array (512 tokens)
    mask = np.zeros(512)
    
    # Set weights from sliders
    for slider, (idx, token) in zip(sliders, real_tokens):
        mask[idx] = slider.value
    
    print("Attention Mask created:")
    print(f"Non-zero weights: {np.count_nonzero(mask)}")
    for slider, (idx, token) in zip(sliders, real_tokens):
        print(f"  Token {idx} '{token}': {slider.value*100:.1f}%")
    
    return mask

mask_button = widgets.Button(description='Create Mask', button_style='success')
mask_output = widgets.Output()

attention_mask = None

def on_mask_click(b):
    global attention_mask
    with mask_output:
        mask_output.clear_output()
        attention_mask = create_attention_mask()

mask_button.on_click(on_mask_click)
display(mask_button, mask_output)

## Generate Image from embedding

In [17]:
# Load embeddings interface

# Setup directories
T5_EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/T5"
CLIP_EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/CLIP"

# Global variables for loaded embeddings
loaded_t5_embedding = None
loaded_clip_embedding = None
loaded_t5_prompt = None
loaded_clip_prompt = None

# T5 embedding selection
t5_file_dropdown = widgets.Dropdown(
    options=[],
    description='T5 Embedding:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px')
)

load_t5_button = widgets.Button(
    description='Load T5',
    button_style='success'
)

t5_load_output = widgets.Output()

# CLIP embedding selection
clip_file_dropdown = widgets.Dropdown(
    options=[],
    description='CLIP Embedding:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px')
)

load_clip_button = widgets.Button(
    description='Load CLIP',
    button_style='success'
)

clip_load_output = widgets.Output()

# Populate dropdown options
if T5_EMBEDDINGS_DIR.exists():
    t5_files = sorted([f.name for f in T5_EMBEDDINGS_DIR.glob('*.json')])
    t5_file_dropdown.options = t5_files

if CLIP_EMBEDDINGS_DIR.exists():
    clip_files = sorted([f.name for f in CLIP_EMBEDDINGS_DIR.glob('*.json')])
    clip_file_dropdown.options = clip_files

def load_t5_embedding_file(b):
    global loaded_t5_embedding, loaded_t5_prompt
    
    with t5_load_output:
        t5_load_output.clear_output()
        
        filename = t5_file_dropdown.value
        if not filename:
            print("❌ No file selected!")
            return
        
        filepath = T5_EMBEDDINGS_DIR / filename
        
        try:
            with open(filepath, 'r') as f:
                data = json.load(f)
            
            loaded_t5_embedding = np.array(data['embedding'])
            loaded_t5_prompt = data.get('prompt', 'Unknown')
            
            print(f"✓ Loaded T5 embedding!")
            print(f"  File: {filename}")
            print(f"  Prompt: '{loaded_t5_prompt}'")
            print(f"  Shape: {loaded_t5_embedding.shape}")
            print(f"  Expected: [512, 4096]")
            
        except Exception as e:
            print(f"❌ Error loading T5 embedding: {e}")

def load_clip_embedding_file(b):
    global loaded_clip_embedding, loaded_clip_prompt
    
    with clip_load_output:
        clip_load_output.clear_output()
        
        filename = clip_file_dropdown.value
        if not filename:
            print("❌ No file selected!")
            return
        
        filepath = CLIP_EMBEDDINGS_DIR / filename
        
        try:
            with open(filepath, 'r') as f:
                data = json.load(f)
            
            loaded_clip_embedding = np.array(data['embedding'])
            loaded_clip_prompt = data.get('prompt', 'Unknown')
            
            print(f"✓ Loaded CLIP embedding!")
            print(f"  File: {filename}")
            print(f"  Prompt: '{loaded_clip_prompt}'")
            print(f"  Shape: {loaded_clip_embedding.shape}")
            print(f"  Expected: [77, 768]")
            
        except Exception as e:
            print(f"❌ Error loading CLIP embedding: {e}")

load_t5_button.on_click(load_t5_embedding_file)
load_clip_button.on_click(load_clip_embedding_file)

# Display interface
display(widgets.VBox([
    widgets.HTML("<h3>Select Embeddings for Image Generation</h3>"),
    widgets.HTML("<hr>"),
    widgets.HTML("<h4>1. T5 Embedding (Text Encoding)</h4>"),
    t5_file_dropdown,
    load_t5_button,
    t5_load_output,
    widgets.HTML("<hr>"),
    widgets.HTML("<h4>2. CLIP Embedding (Pooled Encoding)</h4>"),
    clip_file_dropdown,
    load_clip_button,
    clip_load_output
]))

VBox(children=(HTML(value='<h3>Select Embeddings for Image Generation</h3>'), HTML(value='<hr>'), HTML(value='…

In [None]:
# Generate image from loaded embeddings

# Setup output directory
OUTPUT_IMAGES_DIR = current_dir.parent / "output/images"
os.makedirs(OUTPUT_IMAGES_DIR, exist_ok=True)

def parse_embedding_filename(filename):
    """
    Parse embedding filename to extract tokens and manipulation.
    Returns (tokens_string, manipulation_string)
    
    Examples:
    - 'a_red_cat_sitting.json' -> ('aredcatsitting', '')
    - 'a_red_cat_sitting_scaled_2.0x.json' -> ('aredcatsitting', 'scaled')
    - 'an_elephant_inverted.json' -> ('anelephant', 'inverted')
    - 'an_elephant_zeroed_30pct.json' -> ('anelephant', 'zeroed')
    """
    if not filename:
        return ('unknown', '')
    
    # Remove .json extension
    base_name = filename.rsplit('.', 1)[0]
    
    # Split by underscore
    parts = base_name.split('_')
    
    # First 4 parts are the tokens - join without underscores
    if len(parts) <= 4:
        tokens = ''.join(parts)
        return (tokens, '')
    
    tokens = ''.join(parts[:4])
    
    # Get manipulation type (simplified)
    # Just take the first part after the tokens
    manipulation_parts = parts[4:]
    if manipulation_parts:
        # Get the manipulation type (first word after tokens)
        manipulation = manipulation_parts[0]
    else:
        manipulation = ''
    
    return (tokens, manipulation)

def generate_from_loaded_embeddings(seed=42):
    """
    Generate image using loaded T5 and CLIP embeddings.
    """
    if 'flux_pipe' not in globals():
        print("❌ FLUX not loaded!")
        return None
    
    if loaded_t5_embedding is None:
        print("❌ No T5 embedding loaded! Load a T5 embedding first.")
        return None
    
    print(f"Generating image from loaded embeddings...")
    print(f"  T5 embedding shape: {loaded_t5_embedding.shape}")
    print(f"  T5 prompt: '{loaded_t5_prompt}'")
    
    # Convert T5 numpy to torch tensor with bfloat16
    t5_tensor = torch.from_numpy(loaded_t5_embedding.astype(np.float32)).to(
        device=device,
        dtype=torch.bfloat16
    )
    
    # Add batch dimension: [512, 4096] -> [1, 512, 4096]
    t5_tensor = t5_tensor.unsqueeze(0)
    
    print(f"  T5 tensor shape: {t5_tensor.shape}")
    
    # Process CLIP embeddings
    print("\nProcessing CLIP pooled embeddings...")
    
    if loaded_clip_embedding is not None:
        # Use loaded CLIP embedding
        print(f"  Using loaded CLIP embedding")
        print(f"  CLIP embedding shape: {loaded_clip_embedding.shape}")
        print(f"  CLIP prompt: '{loaded_clip_prompt}'")
        
        # Use last token embedding as pooled embedding (EOS token)
        pooled_embeds = torch.from_numpy(
            loaded_clip_embedding[-1:].astype(np.float32)
        ).to(device=device, dtype=torch.bfloat16)
        
    else:
        # Fall back to generating from T5 prompt using CLIP
        print(f"  No CLIP embedding loaded, generating from T5 prompt using CLIP model...")
        
        if loaded_t5_prompt:
            with torch.no_grad():
                (
                    _,
                    pooled_embeds,
                    _,
                ) = flux_pipe.encode_prompt(
                    prompt=loaded_t5_prompt,
                    prompt_2=None,
                    device=device,
                    num_images_per_prompt=1,
                    prompt_embeds=None,
                    pooled_prompt_embeds=None,
                    max_sequence_length=512,
                )
        else:
            print("❌ No CLIP embedding and no prompt available!")
            return None
    
    print(f"  Pooled embeddings shape: {pooled_embeds.shape}")
    
    # Construct filename from embeddings
    t5_tokens, t5_manip = parse_embedding_filename(t5_file_dropdown.value)
    clip_tokens, clip_manip = parse_embedding_filename(clip_file_dropdown.value)
    
    # Build filename: t5tokens_t5manip_cliptokens_clipmanip.png
    filename_parts = [t5_tokens]
    if t5_manip:
        filename_parts.append(t5_manip)
    filename_parts.append(clip_tokens)
    if clip_manip:
        filename_parts.append(clip_manip)
    
    filename = '_'.join(filename_parts) + '.png'
    output_filepath = OUTPUT_IMAGES_DIR / filename
    
    # Generate image
    try:
        print("\nRunning FLUX diffusion (4 steps)...")
        image = flux_pipe(
            prompt_embeds=t5_tensor,
            pooled_prompt_embeds=pooled_embeds,
            num_inference_steps=4,
            guidance_scale=0.0,
            height=1024,
            width=1024,
            generator=torch.manual_seed(seed)
        ).images[0]
        
        image.save(output_filepath)
        print(f"\n✓ Image generated and saved!")
        print(f"  Path: {output_filepath}")
        print(f"  Filename: {filename}")
        
        return image
        
    except Exception as e:
        print(f"❌ Error generating image: {e}")
        import traceback
        traceback.print_exc()
        return None

# Generation button
generate_button = widgets.Button(
    description='Generate Image',
    button_style='primary',
    layout=widgets.Layout(width='300px', height='50px')
)

seed_input = widgets.IntText(
    value=42,
    description='Seed:',
    style={'description_width': 'initial'}
)

generation_output = widgets.Output()

def on_generate_click(b):
    with generation_output:
        generation_output.clear_output(wait=True)
        
        image = generate_from_loaded_embeddings(seed=seed_input.value)
        
        if image:
            display(image)

generate_button.on_click(on_generate_click)

display(widgets.VBox([
    widgets.HTML("<hr>"),
    widgets.HTML("<h3>3. Generate Image</h3>"),
    seed_input,
    generate_button
]), generation_output)

## 11. Summary

What you've learned:

1. ✓ Generate T5-XXL embeddings from text (512 tokens × 4096 dimensions = 2,097,152 numbers)
2. ✓ Save/load embeddings as JSON files
3. ✓ Manually manipulate embedding values
4. ✓ Create custom attention masks with sliders
5. ✓ Generate images directly from custom embeddings using FLUX
6. ✓ Compare original vs modified embeddings side-by-side
7. ✓ Understand the complete pipeline: text → T5-XXL → embeddings → FLUX → image

## Workflow:

1. **Generate embedding** from your text prompt
2. **Save to JSON** for backup
3. **Apply manipulations** (zero out values, etc.)
4. **Create attention masks** (optional - for future use)
5. **Generate images** from both original and modified embeddings
6. **Compare results** to see what your changes did!

## Next Steps:

- Experiment with different zero percentages (10%, 30%, 50%, 90%)
- Try zeroing specific token embeddings instead of random values
- Compare results with different prompts
- Implement attention mask injection (requires modifying cross-attention layers)
- Explore interpolating between two different embeddings
- Try adding noise to embeddings and see the effect

In [None]:
# Clean up cache to free home directory space
import shutil
import os
from pathlib import Path

def cleanup_cache():
    """Delete cache directories to free up home quota"""
    cache_dirs = [
        Path.home() / "data" / ".cache",
        Path.home() / ".cache" / "huggingface",
        Path.home() / ".cache" / "torch",
        Path.home() / ".cache" / "uv",
    ]
    
    total_freed = 0
    for cache_dir in cache_dirs:
        if cache_dir.exists():
            try:
                # Get size before deletion (approximate)
                size = sum(f.stat().st_size for f in cache_dir.rglob('*') if f.is_file())
                shutil.rmtree(cache_dir)
                total_freed += size
                print(f"\u2713 Deleted {cache_dir}: {size / 1024**3:.2f} GB")
            except Exception as e:
                print(f"\u2717 Could not delete {cache_dir}: {e}")
    
    print(f"\nTotal freed: {total_freed / 1024**3:.2f} GB")

# Run cleanup
cleanup_cache()