# T5-XXL Embeddings for FLUX

This notebook generates **T5-XXL text embeddings** for use with FLUX image generation.

- **Model**: Google T5-v1.1-XXL encoder
- **Embedding dimension**: 4096
- **Sequence length**: 512 tokens
- **Output shape**: [512, 4096]

```mermaid
flowchart LR
    T[Text Prompt]
    
    TOK[T5 Tokenizer]
    ENC[T5-XXL Encoder]
    
    EMB[Text Embedding<br/>512 × 4096]
    
    FLUX[FLUX<br/>Diffusion Transformer]
    
    T --> TOK --> ENC --> EMB
    EMB -->|sequence conditioning| FLUX
```

In [None]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
import os
from pathlib import Path
import ipywidgets as widgets
#from diffusers import FluxPipeline
#from IPython.display import display, Image as IPImage
from PIL import Image
# Set device

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models path from config
current_dir = Path.cwd()
models_path_file = current_dir.parent / "misc/paths/models.txt"
with open(models_path_file, 'r') as f:
    models_path = f.read().strip()
MODELS_DIR = current_dir.parent / models_path
T5_MODEL_PATH = os.path.join(MODELS_DIR, "t5-v1_1-xxl")
# os.makedirs(MODELS_DIR, exist_ok=True)
# print(f"Models directory: {os.path.abspath(MODELS_DIR)}")
# print(f"T5 path: {os.path.abspath(T5_MODEL_PATH)}")
# print(f"FLUX path: {os.path.abspath(FLUX_MODEL_PATH)}")

In [2]:
# Load T5-XXL from local folder
print(f"Loading T5 model from: {T5_MODEL_PATH}...")

if not os.path.exists(T5_MODEL_PATH):
    print("\n⚠️  Model not found locally. Downloading from Hugging Face...")
    print("This is a large model (~11GB) and will take several minutes.")
    print("Please be patient...\n")
    
    # Download and save to local folder
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    t5_model = T5EncoderModel.from_pretrained(
        "google/t5-v1_1-xxl",
        torch_dtype=torch.bfloat16  # Use bfloat16 to match FLUX
    )
    
    # Save to local folder
    print(f"Saving model to {T5_MODEL_PATH}...")
    tokenizer.save_pretrained(T5_MODEL_PATH)
    t5_model.save_pretrained(T5_MODEL_PATH)
    print("✓ Model downloaded and saved locally!\n")
else:
    print("✓ Loading from local folder...\n")

# Load from local folder
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_PATH, local_files_only=True)
t5_model = T5EncoderModel.from_pretrained(
    T5_MODEL_PATH,
    torch_dtype=torch.bfloat16,  # Use bfloat16 to match FLUX
    local_files_only=True
).to(device)

t5_model.eval()  # Set to evaluation mode

`torch_dtype` is deprecated! Use `dtype` instead!


Loading T5 model from: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/models/t5-v1_1-xxl...
✓ Loading from local folder...



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

T5EncoderModel(
  (shared): Embedding(32128, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo

In [3]:
# Create text input widget
prompt_input = widgets.Textarea(
    value='a red cat sitting on a blue table',
    placeholder='Enter your prompt here',
    description='Prompt:',
    layout=widgets.Layout(width='80%', height='80px')
)

generate_button = widgets.Button(
    description='Generate Embedding',
    button_style='success'
)

output_area = widgets.Output()

# Global variable to store current embedding
current_embedding = None
current_tokens = None

def generate_embedding(b):
    global current_embedding, current_tokens
    
    with output_area:
        output_area.clear_output()
        
        prompt = prompt_input.value
        print(f"Generating embedding for: '{prompt}'\n")
        
        # Tokenize
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt"
        )
        
        # Get token strings for display
        token_ids = tokens['input_ids'][0].tolist()
        token_strings = [tokenizer.decode([tid]) for tid in token_ids]
        
        # Find how many real tokens (non-padding)
        num_real_tokens = (tokens['input_ids'][0] != tokenizer.pad_token_id).sum().item()
        
        print(f"Tokenized into {num_real_tokens} real tokens (+ {512 - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()
        
        # Generate embedding
        with torch.no_grad():
            tokens = {k: v.to(device) for k, v in tokens.items()}
            outputs = t5_model(**tokens)
            embedding = outputs.last_hidden_state  # Shape: [1, 512, embedding_dim]
        
        # Convert bfloat16 to float32 before converting to numpy
        current_embedding = embedding.float().cpu().numpy()[0]  # Shape: [512, embedding_dim]
        current_tokens = token_strings
        
        embedding_dim = current_embedding.shape[1]
        total_numbers = current_embedding.shape[0] * current_embedding.shape[1]
        
        print(f"✓ Embedding generated!")
        print(f"  Shape: {current_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_embedding[0, :10])

generate_button.on_click(generate_embedding)

display(prompt_input, generate_button, output_area)

Textarea(value='a red cat sitting on a blue table', description='Prompt:', layout=Layout(height='80px', width=…

Button(button_style='success', description='Generate Embedding', style=ButtonStyle())

Output()

In [4]:
# Save embedding to file
EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/T5"
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

save_button = widgets.Button(
    description='Save Embedding',
    button_style='primary'
)

save_output = widgets.Output()

def save_embedding(b):
    with save_output:
        save_output.clear_output()
        
        if current_embedding is None:
            print("❌ No embedding to save! Generate an embedding first.")
            return
        
        # Get first 4 non-padding tokens (excluding special tokens)
        filename_tokens = []
        for token in current_tokens:
            # Skip padding tokens and clean up special characters
            cleaned_token = token.strip().replace('▁', '').replace('</s>', '')
            if cleaned_token and cleaned_token != '<pad>':
                filename_tokens.append(cleaned_token)
            if len(filename_tokens) >= 4:
                break
        
        # Create filename from first 4 tokens
        filename = "_".join(filename_tokens) + ".json"
        filepath = EMBEDDINGS_DIR / filename
        
        # Save embedding
        embedding_data = {
            "prompt": prompt_input.value,
            "embedding": current_embedding.tolist(),
            "shape": list(current_embedding.shape)
        }
        
        with open(filepath, 'w') as f:
            json.dump(embedding_data, f)
        
        print(f"✓ Embedding saved to:")
        print(f"  {filepath}")
        print(f"  Size: {os.path.getsize(filepath) / 1024:.2f} KB")

save_button.on_click(save_embedding)

display(save_button, save_output)

Button(button_style='primary', description='Save Embedding', style=ButtonStyle())

Output()

---
<sub>Latent Vandalism Workshop • Laura Wagner, 2026 • [laurajul.github.io](https://laurajul.github.io/)</sub>