# Generate Images from Saved Embeddings using FLUX

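This notebook loads precomputed T5 and CLIP text embeddings from `.pt` checkpoint files and passes them directly to a FLUX pipeline (as `prompt_embeds` / `pooled_prompt_embeds`) to generate images, so no text prompts are needed.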

In [None]:
import torch
import os
from pathlib import Path
from diffusers import FluxPipeline
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np

## Configuration

In [None]:
# Paths: where the saved embedding checkpoints live and where images are written.
EMBEDDINGS_DIR = "checkpoints/embeddings/vegas-2n-2c-400s-kd-debug"
OUTPUT_DIR = "output_images/generated_from_embeddings"

# FLUX checkpoint. "schnell" is the fast few-step variant; switch to
# "black-forest-labs/FLUX.1-dev" for higher quality.
FLUX_MODEL = "black-forest-labs/FLUX.1-schnell"

# Denoising steps (schnell is optimized for 1-4 steps).
NUM_INFERENCE_STEPS = 4

# bfloat16 on GPU, full float32 precision on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Read the server checkpoints: they contain the computed embeddings, whereas
# the per-node checkpoints were found to hold only None values.
USE_SERVER_EMBEDDINGS = True

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Dtype: {DTYPE}")
print(f"Embeddings directory: {EMBEDDINGS_DIR}")
print(f"Using server embeddings: {USE_SERVER_EMBEDDINGS}")
print(f"Output directory: {OUTPUT_DIR}")

## Load FLUX Pipeline

In [None]:
# Load the FLUX pipeline in the configured precision and move it to the target device.
print(f"Loading FLUX model: {FLUX_MODEL}...")
pipe = FluxPipeline.from_pretrained(FLUX_MODEL, torch_dtype=DTYPE).to(DEVICE)

# Silence the pipeline's own per-call progress bar to keep cell output readable.
pipe.set_progress_bar_config(disable=True)

print("✓ FLUX model loaded successfully!")

## Load Embedding Files

In [None]:
def load_embedding_files(embeddings_dir, use_server_embeddings=True):
    """Find the embedding checkpoints in ``embeddings_dir`` and list them.

    Args:
        embeddings_dir: Directory that holds the saved ``.pt`` checkpoints.
        use_server_embeddings: When True, match ``server_embeddings_*.pt``;
            otherwise match the per-node ``node_*_embeddings_*.pt`` files
            (which may contain None placeholders instead of tensors).

    Returns:
        A sorted list of ``pathlib.Path`` objects for the matching files.

    Raises:
        FileNotFoundError: If no file in the directory matches the pattern.
    """
    pattern = "server_embeddings_*.pt" if use_server_embeddings else "node_*_embeddings_*.pt"
    matches = sorted(Path(embeddings_dir).glob(pattern))

    if len(matches) == 0:
        raise FileNotFoundError(f"No files matching '{pattern}' found in {embeddings_dir}")

    print(f"Found {len(matches)} embedding files:")
    print("\n".join(f"  - {match.name}" for match in matches))

    return matches

# Discover the checkpoint files to process (server or per-node, per the config flag).
embedding_files = load_embedding_files(
    EMBEDDINGS_DIR,
    use_server_embeddings=USE_SERVER_EMBEDDINGS,
)

## Inspect Embedding Structure

In [None]:
# Load the first checkpoint and print its structure as a sanity check.
sample_file = embedding_files[0]
print(f"\nInspecting: {sample_file.name}")

data = torch.load(sample_file, map_location='cpu')

print(f"\nTop-level keys: {list(data.keys())}")
# Metadata is read with .get(): checkpoints may not carry every field (the
# generation code also treats 'node_id' as optional, defaulting to 'server').
print(f"Node ID: {data.get('node_id', 'server')}")
print(f"Round: {data.get('round')}")
print(f"Timestamp: {data.get('timestamp')}")
records = data.get('embeddings') or []
print(f"Number of embeddings: {len(records)}")

if records:
    print(f"\nFirst embedding keys: {list(records[0].keys())}")

    # A record is usable only if both tensors are present (not None / not missing).
    valid_embeddings = [
        e for e in records
        if e.get('t5_embedding') is not None and e.get('clip_embedding') is not None
    ]

    print(f"Valid embeddings (non-None): {len(valid_embeddings)}/{len(records)}")

    if valid_embeddings:
        sample_emb = valid_embeddings[0]
        print("\nSample embedding info:")
        print(f"  Class: {sample_emb.get('class_name', 'unknown')}")
        print(f"  Split: {sample_emb.get('split', 'global')}")
        # Shapes are labelled by their stored key. The two keys have been
        # observed to be swapped (T5 sequence under 'clip_embedding', pooled
        # CLIP vector under 't5_embedding'), so "T5"/"CLIP" labels would mislead.
        print(f"  't5_embedding' shape: {sample_emb['t5_embedding'].shape}")
        print(f"  'clip_embedding' shape: {sample_emb['clip_embedding'].shape}")
    else:
        print("\n⚠ Warning: No valid embeddings found (all are None)")
        print("This might happen if embeddings were saved in 'embeddings_only' mode")
        print("but the model wasn't run to generate the actual embeddings.")

## Generate Images from Embeddings

In [None]:
def _prepare_flux_embeddings(emb_data):
    """Return ``(prompt_embeds, pooled_prompt_embeds)`` for one embedding record.

    Saved checkpoints were observed to store the two tensors under swapped
    keys: 't5_embedding' holding the pooled CLIP vector (e.g. [1, 768]) and
    'clip_embedding' holding the T5 sequence (e.g. [1, seq_len, 4096]).
    Instead of hard-coding that swap, the tensor with more elements is taken
    as the T5 sequence embedding and the other as the pooled CLIP embedding,
    so records saved with either key convention work. On a size tie the
    historical mapping ('clip_embedding' -> prompt_embeds) is kept.

    Args:
        emb_data: Mapping with tensors under 't5_embedding' and 'clip_embedding'.

    Returns:
        ``prompt_embeds`` reshaped to 3-D [batch, seq_len, dim] (a 2-D input
        gets a leading batch dim) and ``pooled_prompt_embeds`` reshaped to
        2-D [batch, dim]. Device and dtype are left unchanged.
    """
    stored_t5 = emb_data['t5_embedding']
    stored_clip = emb_data['clip_embedding']

    if stored_t5.numel() > stored_clip.numel():
        prompt_embeds, pooled_prompt_embeds = stored_t5, stored_clip
    else:
        prompt_embeds, pooled_prompt_embeds = stored_clip, stored_t5

    # FLUX expects pooled embeddings as [batch, dim]; this also fixes 1-D [dim]
    # and [1, 1, dim] inputs.
    pooled_prompt_embeds = pooled_prompt_embeds.reshape(-1, pooled_prompt_embeds.shape[-1])
    # FLUX expects the sequence embeddings as [batch, seq_len, dim].
    if prompt_embeds.dim() == 2:
        prompt_embeds = prompt_embeds.unsqueeze(0)
    return prompt_embeds, pooled_prompt_embeds


def _image_prefix(checkpoint_path):
    """Return the image filename prefix for a checkpoint, e.g. 'server' or 'node_0'.

    Taken from the text before '_embeddings' in the file stem
    ('server_embeddings_r1.pt' -> 'server', 'node_0_embeddings_r1.pt' ->
    'node_0'), falling back to the whole stem. This keeps images generated from
    different checkpoints of the same round from overwriting each other.
    """
    head, found, _ = checkpoint_path.stem.partition('_embeddings')
    return head if found else checkpoint_path.stem


def _safe_filename_part(value):
    """Return ``value`` as a string with path separators replaced by '_'."""
    return str(value).replace('/', '_').replace('\\', '_')


def generate_images_from_checkpoint(checkpoint_path, pipe, output_dir, max_images=None):
    """
    Generate images from the embeddings stored in a checkpoint file.

    Records missing either tensor (key absent or value None) are skipped.
    Image ``i`` is generated with seed ``42 + i`` and saved as
    ``{prefix}_r{round}_{split}_{class_name}_{i}.png``. ``prefix`` comes from
    the checkpoint name: 'server' for server_embeddings_*.pt, 'node_<id>' for
    node_<id>_embeddings_*.pt. A failure on one image is printed with its
    traceback and generation continues with the next record.

    Args:
        checkpoint_path: Path (or str) to .pt checkpoint file
        pipe: FluxPipeline instance
        output_dir: Directory to save generated images
        max_images: Maximum number of images to generate (None = all)

    Returns:
        List of generated image paths (str)
    """
    import traceback

    checkpoint_path = Path(checkpoint_path)
    data = torch.load(checkpoint_path, map_location='cpu')
    node_id = data.get('node_id', 'server')
    round_num = data['round']
    prefix = _image_prefix(checkpoint_path)

    # .get() so records lacking either key are skipped instead of raising KeyError.
    valid_embeddings = [
        e for e in data['embeddings']
        if e.get('t5_embedding') is not None and e.get('clip_embedding') is not None
    ]

    if not valid_embeddings:
        print(f"⚠ No valid embeddings in {checkpoint_path.name}")
        return []

    if max_images is not None:
        valid_embeddings = valid_embeddings[:max_images]

    print(f"\nProcessing {checkpoint_path.name}:")
    print(f"  Node/Server {node_id}, Round {round_num}")
    print(f"  Generating {len(valid_embeddings)} images...")

    generated_paths = []

    for i, emb_data in enumerate(tqdm(valid_embeddings, desc="Generating")):
        try:
            prompt_embeds, pooled_prompt_embeds = _prepare_flux_embeddings(emb_data)
            prompt_embeds = prompt_embeds.to(device=DEVICE, dtype=DTYPE)
            pooled_prompt_embeds = pooled_prompt_embeds.to(device=DEVICE, dtype=DTYPE)

            print(f"  Image {i}: prompt_embeds={prompt_embeds.shape}, pooled={pooled_prompt_embeds.shape}")

            result = pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=NUM_INFERENCE_STEPS,
                output_type="pil",
                generator=torch.Generator(device=DEVICE).manual_seed(42 + i),  # Different seed per image
            )
            image = result.images[0]

            # Sanitize metadata so a '/' in a class or split name cannot
            # redirect the save path into a nonexistent subdirectory.
            class_name = _safe_filename_part(emb_data.get('class_name', 'unknown'))
            split = _safe_filename_part(emb_data.get('split', 'global'))

            filename = f"{prefix}_r{round_num}_{split}_{class_name}_{i}.png"
            save_path = os.path.join(output_dir, filename)

            image.save(save_path)
            generated_paths.append(save_path)

            # Release GPU memory before the next image.
            del prompt_embeds, pooled_prompt_embeds, result
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            # Best-effort: report the failure and move on to the next record.
            print(f"  ✗ Error generating image {i}: {e}")
            traceback.print_exc()

    print(f"  ✓ Generated {len(generated_paths)} images")
    return generated_paths

## Generate Images from All Checkpoints

In [None]:
# Cap on images per checkpoint file (handy for quick tests); None means no cap.
MAX_IMAGES_PER_FILE = 10

# Run every discovered checkpoint and gather the saved image paths in order.
all_generated_images = []
for ckpt_path in embedding_files:
    all_generated_images += generate_images_from_checkpoint(
        ckpt_path, pipe, OUTPUT_DIR, max_images=MAX_IMAGES_PER_FILE
    )

print(f"\n{'='*60}")
print(f"Total images generated: {len(all_generated_images)}")
print(f"Saved to: {OUTPUT_DIR}")
print(f"{'='*60}")

## Display Sample Generated Images

In [None]:
# Show up to the first six generated images in a 2x3 grid.
num_display = min(6, len(all_generated_images))

if num_display > 0:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, img_path in zip(axes, all_generated_images[:num_display]):
        # Open via a context manager so the file handle is closed once
        # imshow has taken the pixel data.
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(Path(img_path).name, fontsize=8)
        ax.axis('off')

    # Hide grid cells that received no image.
    for ax in axes[num_display:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("No images to display")

## Generate Images from Specific Embeddings

Use this section to generate images from a single checkpoint file, selected by node ID and round, instead of every file in the embeddings directory.

In [None]:
# Example: generate from one specific node/round checkpoint.
target_node = 0
target_round = 1

target_file = Path(EMBEDDINGS_DIR) / f"node_{target_node}_embeddings_r{target_round}.pt"

if target_file.exists():
    print(f"Generating from: {target_file.name}")
    generated = generate_images_from_checkpoint(
        target_file,
        pipe,
        OUTPUT_DIR,
        max_images=5,  # Generate just 5 images
    )

    # Display results side by side.
    if generated:
        # squeeze=False always returns a 2-D array of Axes, so a single image
        # needs no special case.
        fig, axes = plt.subplots(
            1, len(generated), figsize=(4 * len(generated), 4), squeeze=False
        )

        for ax, img_path in zip(axes[0], generated):
            # Context manager closes the file once imshow has the pixel data.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(Path(img_path).name, fontsize=8)
            ax.axis('off')

        plt.tight_layout()
        plt.show()
else:
    print(f"File not found: {target_file}")

## Batch Generate with Custom Parameters

In [None]:
def batch_generate_images(embeddings_list, pipe, batch_size=4):
    """
    Generate images for many embedding records, ``batch_size`` records per pipeline call.

    Each record must hold tensors (not None) under both 't5_embedding' and
    'clip_embedding'. Saved checkpoints were observed to store these under
    swapped keys (T5 sequence under 'clip_embedding', pooled CLIP vector under
    't5_embedding'). Each record is therefore oriented by tensor size: the
    larger tensor becomes ``prompt_embeds`` and the smaller
    ``pooled_prompt_embeds``.

    A batch that fails is reported and skipped, for example on mismatched
    sequence lengths or a pipeline error. In that case the returned list is
    shorter than ``embeddings_list`` and no longer index-aligned with it.

    Args:
        embeddings_list: List of embedding dictionaries
        pipe: FluxPipeline instance
        batch_size: Number of images to generate at once (must be >= 1)

    Returns:
        List of PIL Images

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _orient(emb):
        # Larger tensor = T5 sequence embedding, smaller = pooled CLIP vector.
        # On a tie, keep the mapping 'clip_embedding' -> prompt_embeds.
        stored_t5, stored_clip = emb['t5_embedding'], emb['clip_embedding']
        if stored_t5.numel() > stored_clip.numel():
            seq, pooled = stored_t5, stored_clip
        else:
            seq, pooled = stored_clip, stored_t5
        if seq.dim() == 2:
            seq = seq.unsqueeze(0)  # [seq_len, dim] -> [1, seq_len, dim]
        return seq, pooled.reshape(-1, pooled.shape[-1])  # -> [n, dim]

    all_images = []

    for start in tqdm(range(0, len(embeddings_list), batch_size), desc="Batches"):
        batch = embeddings_list[start:start + batch_size]
        prompt_embeds_batch = pooled_prompt_embeds_batch = result = None

        try:
            pairs = [_orient(e) for e in batch]
            # Concatenate (not stack) along dim 0: each record already has a
            # leading batch dim, so stacking would create a 4-D tensor.
            prompt_embeds_batch = torch.cat([seq for seq, _ in pairs]).to(device=DEVICE, dtype=DTYPE)
            pooled_prompt_embeds_batch = torch.cat([pooled for _, pooled in pairs]).to(device=DEVICE, dtype=DTYPE)

            result = pipe(
                prompt_embeds=prompt_embeds_batch,
                pooled_prompt_embeds=pooled_prompt_embeds_batch,
                num_inference_steps=NUM_INFERENCE_STEPS,
                output_type="pil",
            )
            all_images.extend(result.images)

        except Exception as e:
            print(f"Error in batch {start//batch_size}: {e}")

        finally:
            # Free GPU memory whether or not the batch succeeded.
            del prompt_embeds_batch, pooled_prompt_embeds_batch, result
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return all_images

# Example usage:
# Load embeddings from a file
# data = torch.load(embedding_files[0], map_location='cpu')
# valid_embs = [
#     e for e in data['embeddings']
#     if e['t5_embedding'] is not None and e['clip_embedding'] is not None
# ][:8]
# images = batch_generate_images(valid_embs, pipe, batch_size=2)

## Cleanup

In [None]:
import gc

# Free up GPU memory held by the pipeline.
if torch.cuda.is_available():
    # Guarded so re-running this cell after `pipe` is gone does not raise NameError.
    if 'pipe' in globals():
        del pipe
    # Collect garbage first so objects kept alive only through reference
    # cycles release their CUDA tensors before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    print("✓ GPU memory cleared")