# T5-XXL Embeddings and Ornithology for FLUX ü¶öü¶úüê¶

This notebook generates **T5-XXL text embeddings** for use with FLUX image generation.

- **Model**: Google T5-v1.1-XXL encoder
- **Embedding dimension**: 4096
- **Sequence length**: 512 tokens
- **Output shape**: [512, 4096]

```mermaid
flowchart LR
    T[Text Prompt]
    
    TOK[T5 Tokenizer]
    ENC[T5-XXL Encoder]
    
    EMB[Text Embedding<br/>512 √ó 4096]
    
    FLUX[FLUX<br/>Diffusion Transformer]
    
    T --> TOK --> ENC --> EMB
    EMB -->|sequence conditioning| FLUX
```

In [1]:
import torch
import json
import numpy as np
from transformers import T5EncoderModel, T5Tokenizer
import os
from pathlib import Path
import ipywidgets as widgets
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models path from config
current_dir = Path.cwd()
models_path_file = current_dir.parent / "misc/paths/models.txt"
with open(models_path_file, 'r') as f:
    models_path = f.read().strip()
MODELS_DIR = current_dir.parent / models_path
T5_MODEL_PATH = os.path.join(MODELS_DIR, "t5-v1_1-xxl")
# os.makedirs(MODELS_DIR, exist_ok=True)
# print(f"Models directory: {os.path.abspath(MODELS_DIR)}")
# print(f"T5 path: {os.path.abspath(T5_MODEL_PATH)}")
# print(f"FLUX path: {os.path.abspath(FLUX_MODEL_PATH)}")

Using device: cuda


In [2]:
# Load T5-XXL from local folder
print(f"Loading T5 model from: {T5_MODEL_PATH}...")

if not os.path.exists(T5_MODEL_PATH):
    print("\n‚ö†Ô∏è  Model not found locally. Downloading from Hugging Face...")
    print("This is a large model (~11GB) and will take several minutes.")
    print("Please be patient...\n")
    
    # Download and save to local folder
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    t5_model = T5EncoderModel.from_pretrained(
        "google/t5-v1_1-xxl",
        torch_dtype=torch.bfloat16  # Use bfloat16 to match FLUX
    )
    
    # Save to local folder
    print(f"Saving model to {T5_MODEL_PATH}...")
    tokenizer.save_pretrained(T5_MODEL_PATH)
    t5_model.save_pretrained(T5_MODEL_PATH)
    print("‚úì Model downloaded and saved locally!\n")
else:
    print("‚úì Loading from local folder...\n")

# Load from local folder
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_PATH, local_files_only=True)
t5_model = T5EncoderModel.from_pretrained(
    T5_MODEL_PATH,
    torch_dtype=torch.bfloat16,  # Use bfloat16 to match FLUX
    local_files_only=True
).to(device)

t5_model.eval()  # Set to evaluation mode

`torch_dtype` is deprecated! Use `dtype` instead!


Loading T5 model from: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/models/t5-v1_1-xxl...
‚úì Loading from local folder...



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

T5EncoderModel(
  (shared): Embedding(32128, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo

## Generate Single Embedding

Enter a prompt to generate its T5 embedding.

In [7]:
# Create text input widget
prompt_input = widgets.Textarea(
    value='a puffy european robin sitting on a tree branch',
    placeholder='Enter your prompt here',
    description='Prompt:',
    layout=widgets.Layout(width='80%', height='80px')
)

generate_button = widgets.Button(
    description='Generate Embedding',
    button_style='success'
)

output_area = widgets.Output()

# Global variable to store current embedding
current_embedding = None
current_tokens = None

def generate_embedding(b):
    global current_embedding, current_tokens
    
    with output_area:
        output_area.clear_output()
        
        prompt = prompt_input.value
        print(f"Generating embedding for: '{prompt}'\n")
        
        # Tokenize
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt"
        )
        
        # Get token strings for display
        token_ids = tokens['input_ids'][0].tolist()
        token_strings = [tokenizer.decode([tid]) for tid in token_ids]
        
        # Find how many real tokens (non-padding)
        num_real_tokens = (tokens['input_ids'][0] != tokenizer.pad_token_id).sum().item()
        
        print(f"Tokenized into {num_real_tokens} real tokens (+ {512 - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()
        
        # Generate embedding
        with torch.no_grad():
            tokens = {k: v.to(device) for k, v in tokens.items()}
            outputs = t5_model(**tokens)
            embedding = outputs.last_hidden_state  # Shape: [1, 512, embedding_dim]
        
        # Convert bfloat16 to float32 before converting to numpy
        current_embedding = embedding.float().cpu().numpy()[0]  # Shape: [512, embedding_dim]
        current_tokens = token_strings
        
        embedding_dim = current_embedding.shape[1]
        total_numbers = current_embedding.shape[0] * current_embedding.shape[1]
        
        print(f"‚úì Embedding generated!")
        print(f"  Shape: {current_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_embedding[0, :10])



In [8]:
# Save embedding to file
EMBEDDINGS_DIR = current_dir.parent / "data/embeddings/T5"
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

save_button = widgets.Button(
    description='Save Embedding',
    button_style='primary'
)

save_output = widgets.Output()

def save_embedding(b):
    with save_output:
        save_output.clear_output()
        
        if current_embedding is None:
            print("‚ùå No embedding to save! Generate an embedding first.")
            return
        
        # Get first 4 non-padding tokens (excluding special tokens)
        filename_tokens = []
        for token in current_tokens:
            # Skip padding tokens and clean up special characters
            cleaned_token = token.strip().replace('‚ñÅ', '').replace('</s>', '')
            if cleaned_token and cleaned_token != '<pad>':
                filename_tokens.append(cleaned_token)
            if len(filename_tokens) >= 4:
                break
        
        # Create filename from first 4 tokens
        filename = "_".join(filename_tokens) + ".json"
        filepath = EMBEDDINGS_DIR / filename
        
        # Save embedding
        embedding_data = {
            "prompt": prompt_input.value,
            "embedding": current_embedding.tolist(),
            "shape": list(current_embedding.shape)
        }
        
        with open(filepath, 'w') as f:
            json.dump(embedding_data, f)
        
        print(f"‚úì Embedding saved to:")
        print(f"  {filepath}")
        print(f"  Size: {os.path.getsize(filepath) / 1024:.2f} KB")

generate_button.on_click(generate_embedding)

display(prompt_input, generate_button, output_area)

save_button.on_click(save_embedding)

display(save_button, save_output)

Textarea(value='a puffy european robin sitting on a tree branch', description='Prompt:', layout=Layout(height=‚Ä¶

Button(button_style='success', description='Generate Embedding', style=ButtonStyle())

Output()

Button(button_style='primary', description='Save Embedding', style=ButtonStyle())

Output()

## Batch Generation from Text Input

Enter multiple prompts (one per line) to generate and save embeddings for all of them.

In [13]:
# Batch generate T5 embeddings from text input
from IPython.display import display

batch_text_input = widgets.Textarea(
    value='A curious Raggiana bird-of-paradise peeking through dense green leaves.\nA magnificent riflebird with iridescent feathers perched on a mossy log.\nA vibrant King of Saxony bird-of-paradise showing off its long head plumes.\nA stunning Superb bird-of-paradife doing a courtship dance on the forest floor.',
    placeholder='Enter prompts, one per line',
    description='Prompts:',
    layout=widgets.Layout(width='80%', height='150px')
)

batch_text_button = widgets.Button(
    description='Batch Generate & Save',
    button_style='warning'
)

batch_text_output = widgets.Output()

def batch_generate_from_text(b):
    with batch_text_output:
        batch_text_output.clear_output()
        
        # Parse prompts (one per line)
        prompts = [p.strip() for p in batch_text_input.value.strip().split('\n') if p.strip()]
        
        if not prompts:
            print("No prompts provided!")
            return
        
        print(f"Generating {len(prompts)} T5 embeddings...\n")
        
        for i, prompt in enumerate(prompts, 1):
            print(f"[{i}/{len(prompts)}] '{prompt[:60]}{'...' if len(prompt) > 60 else ''}'")
            
            # Tokenize
            tokens = tokenizer(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_tensors="pt"
            )
            
            # Get token strings for filename
            token_ids = tokens['input_ids'][0].tolist()
            token_strings = [tokenizer.decode([tid]) for tid in token_ids]
            
            # Generate embedding
            with torch.no_grad():
                tokens = {k: v.to(device) for k, v in tokens.items()}
                outputs = t5_model(**tokens)
                embedding = outputs.last_hidden_state.float().cpu().numpy()[0]
            
            # Create filename from first 4 tokens
            filename_tokens = []
            for token in token_strings:
                cleaned = token.strip().replace('‚ñÅ', '').replace('</s>', '')
                if cleaned and cleaned != '<pad>':
                    filename_tokens.append(cleaned)
                if len(filename_tokens) >= 4:
                    break
            
            filename = "_".join(filename_tokens) + ".json"
            filepath = EMBEDDINGS_DIR / filename
            
            # Save embedding
            embedding_data = {
                "prompt": prompt,
                "embedding": embedding.tolist(),
                "shape": list(embedding.shape)
            }
            
            with open(filepath, 'w') as f:
                json.dump(embedding_data, f)
            
            print(f"   ‚úì Saved: {filename}")
        
        print(f"\n‚úì All {len(prompts)} embeddings saved to:")
        print(f"  {EMBEDDINGS_DIR}")

batch_text_button.on_click(batch_generate_from_text)

print("Batch T5 Embedding Generator")
print(f"Output directory: {EMBEDDINGS_DIR}")
print("Enter prompts (one per line):\n")
display(batch_text_input, batch_text_button, batch_text_output)

Batch T5 Embedding Generator
Output directory: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/embeddings/T5
Enter prompts (one per line):



Textarea(value='A curious Raggiana bird-of-paradise peeking through dense green leaves.\nA magnificent riflebi‚Ä¶



Output()

## Batch Generation from Example Prompts File

Load prompts from `misc/example_prompts.txt` and generate embeddings. Files are saved to `examples/` subfolder.

In [10]:
# import os
# import json
# import torch
# import ipywidgets as widgets
# from pathlib import Path
# from IPython.display import display

# # ------------------------------------------------------------------
# # Paths
# # ------------------------------------------------------------------

# current_dir = Path.cwd()
# BASE_EXAMPLES_DIR = current_dir.parent / "data/embeddings/T5/examples"
# os.makedirs(BASE_EXAMPLES_DIR, exist_ok=True)

# prompts_file = current_dir.parent / "misc/example_prompts.txt"

# # ------------------------------------------------------------------
# # T5 UX config
# # ------------------------------------------------------------------

# T5_CONFIG = {
#     'short': {
#         'section_name': 'Short prompts',
#         'subdir': 'short',
#         'max_length': 77
#     },
#     '77_tokens': {
#         'section_name': 'Short prompts',
#         'subdir': '77_tokens',
#         'max_length': 77
#     },
#     '512_tokens': {
#         'section_name': 'T5-XXL prompts',
#         'subdir': '512_tokens',
#         'max_length': 512
#     }
# }

# # ------------------------------------------------------------------
# # Prompt loader (section-aware)
# # ------------------------------------------------------------------

# def load_t5_prompts_from_file(filepath, section_name):
#     """Load a specific prompt section from example_prompts.txt."""
#     if not filepath.exists():
#         return []

#     with open(filepath, 'r', encoding='utf-8') as f:
#         content = f.read()

#     sections = content.split('#')
#     target_section = None

#     for section in sections:
#         if section_name in section:
#             target_section = section
#             break

#     if target_section is None:
#         return []

#     prompts = []
#     for line in target_section.splitlines():
#         line = line.strip()
#         if line and not line.startswith('#'):
#             prompts.append(line)

#     return prompts

# # ------------------------------------------------------------------
# # Widgets
# # ------------------------------------------------------------------

# t5_token_selector = widgets.Dropdown(
#     options=[
#         ('Short prompts', 'short'),
#         ('77 tokens (short)', '77_tokens'),
#         ('512 tokens (T5-XXL)', '512_tokens')
#     ],
#     value='512_tokens',
#     description='Prompt Set:',
#     style={'description_width': 'initial'}
# )

# file_batch_button = widgets.Button(
#     description='Generate from File',
#     button_style='info'
# )

# file_batch_output = widgets.Output()

# # ------------------------------------------------------------------
# # Batch generation
# # ------------------------------------------------------------------

# def batch_generate_from_file(b):
#     with file_batch_output:
#         file_batch_output.clear_output()

#         selection = t5_token_selector.value
#         config = T5_CONFIG[selection]

#         max_length = config['max_length']
#         output_dir = BASE_EXAMPLES_DIR / config['subdir']
#         os.makedirs(output_dir, exist_ok=True)

#         prompts = load_t5_prompts_from_file(
#             prompts_file,
#             config['section_name']
#         )

#         if not prompts:
#             print(f"‚ùå No prompts found for section '{config['section_name']}'.")
#             return

#         print(f"Loaded {len(prompts)} prompts")
#         print(f"Section: {config['section_name']}")
#         print(f"Max tokens: {max_length}")
#         print(f"Output: {output_dir}\n")

#         for i, prompt in enumerate(prompts, 1):
#             print(f"[{i}/{len(prompts)}] {prompt[:60]}{'...' if len(prompt) > 60 else ''}")

#             tokens = tokenizer(
#                 prompt,
#                 padding="max_length",
#                 max_length=max_length,
#                 truncation=True,
#                 return_tensors="pt"
#             )

#             with torch.no_grad():
#                 tokens = {k: v.to(device) for k, v in tokens.items()}
#                 outputs = t5_model(**tokens)
#                 embedding = outputs.last_hidden_state.float().cpu().numpy()[0]

#             # Filename tokens
#             token_ids = tokens['input_ids'][0].tolist()
#             token_strings = [tokenizer.decode([tid]) for tid in token_ids]

#             filename_tokens = []
#             for token in token_strings:
#                 cleaned = token.strip().replace('‚ñÅ', '').replace('</s>', '')
#                 if cleaned and cleaned != '<pad>':
#                     filename_tokens.append(cleaned)
#                 if len(filename_tokens) >= 4:
#                     break

#             filename = "_".join(filename_tokens) + ".json"
#             filepath = output_dir / filename

#             with open(filepath, 'w') as f:
#                 json.dump({
#                     "prompt": prompt,
#                     "embedding": embedding.tolist(),
#                     "shape": list(embedding.shape),
#                     "max_length": max_length,
#                     "prompt_set": config['section_name']
#                 }, f)

#             print(f"   ‚úì Saved: {filename}")

#         print("\n‚úÖ All T5 embeddings generated successfully.")

# # ------------------------------------------------------------------
# # UI wiring
# # ------------------------------------------------------------------

# file_batch_button.on_click(batch_generate_from_file)

# print("Generate T5 embeddings from example_prompts.txt")
# print(f"Source: {prompts_file}")
# print(f"Base output: {BASE_EXAMPLES_DIR}\n")

# display(t5_token_selector, file_batch_button, file_batch_output)


Generate T5 embeddings from example_prompts.txt
Source: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/misc/example_prompts.txt
Base output: /shares/weddigen.ki.uzh/laura_wagner/latent_vandalism_workshop/data/embeddings/T5/examples



Dropdown(description='Prompt Set:', index=2, options=(('Short prompts', 'short'), ('77 tokens (short)', '77_to‚Ä¶

Button(button_style='info', description='Generate from File', style=ButtonStyle())

Output()

---
<sub>Latent Vandalism Workshop ‚Ä¢ Laura Wagner, 2026 ‚Ä¢ [laurajul.github.io](https://laurajul.github.io/)</sub>