From c46a649130abd310381e021f943ce661a35cffd1 Mon Sep 17 00:00:00 2001 From: davidb Date: Wed, 8 Oct 2025 14:42:09 +0000 Subject: [PATCH 01/52] Add Photon model and pipeline support This commit adds support for the Photon image generation model: - PhotonTransformer2DModel: Core transformer architecture - PhotonPipeline: Text-to-image generation pipeline - Attention processor updates for Photon-specific attention mechanism - Conversion script for loading Photon checkpoints - Documentation and tests --- docs/source/en/api/pipelines/photon.md | 171 ++++ scripts/convert_photon_to_diffusers.py | 353 ++++++++ src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/attention_processor.py | 58 ++ src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_photon.py | 812 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/photon/__init__.py | 5 + .../pipelines/photon/pipeline_output.py | 35 + .../pipelines/photon/pipeline_photon.py | 643 ++++++++++++++ .../test_models_transformer_photon.py | 85 ++ 12 files changed, 2166 insertions(+) create mode 100644 docs/source/en/api/pipelines/photon.md create mode 100644 scripts/convert_photon_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_photon.py create mode 100644 src/diffusers/pipelines/photon/__init__.py create mode 100644 src/diffusers/pipelines/photon/pipeline_output.py create mode 100644 src/diffusers/pipelines/photon/pipeline_photon.py create mode 100644 tests/models/transformers/test_models_transformer_photon.py diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md new file mode 100644 index 000000000000..270284673f92 --- /dev/null +++ b/docs/source/en/api/pipelines/photon.md @@ -0,0 +1,171 @@ + + +# PhotonPipeline + +
+ LoRA +
+ +Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. + +Key features: + +- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks +- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling +- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) +- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support +- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality + +## Available models: +We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions. +Both **fine-tuned** and **non-fine-tuned** versions are available: + +- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions. +- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**. + + +| Model | Recommended dtype | Resolution | Fine-tuned | +|:-----:|:-----------------:|:----------:|:----------:| +| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | `torch.bfloat16` | 256x256 | No | +| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | `torch.bfloat16` | 256x256 | Yes | +| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | `torch.bfloat16` | 512x512 | No | +| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft) | `torch.bfloat16` | 512x512 | Yes | +| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | `torch.bfloat16` | 512x512 | No | +| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | `torch.bfloat16` | 512x512 | Yes | + +Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information. + +## Loading the Pipeline + +```py +from diffusers.pipelines.photon import PhotonPipeline + +# Load pipeline - VAE and text encoder will be loaded from HuggingFace +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i") +pipe.to("cuda") + +prompt = "A vast night sky over a quiet city suddenly blazes with enormous glowing neon letters spelling “PHOTON.” The word hums and flickers dramatically, as if trying a little too hard to look epic. The soft glow bathes the rooftops and streets below in blue and pink light. A few people look up, squinting, some taking selfies; a cat blinks lazily at the sky’s new centerpiece. The air feels cinematic and electric — like a sci-fi movie that doesn’t take itself too seriously. Mist swirls around the neon glow, adding a dreamy, aesthetic touch to the humor of it all." 
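+# The call below uses the defaults documented under "Generation Parameters"
+# (28 denoising steps, guidance scale 4.0). For reproducible results you can also pass a
+# seeded generator, e.g. `generator=torch.Generator("cuda").manual_seed(0)`
+# (requires `import torch`); shown here only as an illustrative option.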
+image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] +image.save("photon_output.png") +``` + +### Manual Component Loading + +You can also load components individually: + +```py +import torch +from diffusers import PhotonPipeline +from diffusers.models import AutoencoderKL, AutoencoderDC +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from transformers import T5GemmaModel, GemmaTokenizerFast + +# Load transformer +transformer = PhotonTransformer2DModel.from_pretrained( + "Photoroom/photon-512-t2i", subfolder="transformer" +) + +# Load scheduler +scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "Photoroom/photon-512-t2i", subfolder="scheduler" +) + +# Load T5Gemma text encoder +t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") +text_encoder = t5gemma_model.encoder +tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") +tokenizer.model_max_length = 256 +# Load VAE - choose either Flux VAE or DC-AE +# Flux VAE (16 latent channels): +vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") +# Or DC-AE (32 latent channels): +# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") + +pipe = PhotonPipeline( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae +) +pipe.to("cuda") +``` + +## VAE Variants + +Photon supports two VAE configurations: + +### Flux VAE (AutoencoderKL) +- **Compression**: 8x spatial compression +- **Latent channels**: 16 +- **Model**: `black-forest-labs/FLUX.1-dev` (subfolder: "vae") +- **Use case**: Balanced quality and speed + +### DC-AE (AutoencoderDC) +- **Compression**: 32x spatial compression +- **Latent channels**: 32 +- **Model**: `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers` +- **Use case**: Higher compression for faster processing + +The VAE type is automatically determined from the checkpoint's `model_index.json` configuration. + +## Generation Parameters + +Key parameters for image generation: + +- **num_inference_steps**: Number of denoising steps (default: 28). More steps generally improve quality at the cost of speed. +- **guidance_scale**: Classifier-free guidance strength (default: 4.0). Higher values produce images more closely aligned with the prompt. +- **height/width**: Output image dimensions (default: 512x512). Can be customized in the checkpoint configuration. + +```py +# Example with custom parameters +import torch +from diffusers.pipelines.photon import PhotonPipeline + +pipe = pipe( + prompt="A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. 
Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery.", + num_inference_steps=28, + guidance_scale=4.0, + height=512, + width=512, + generator=torch.Generator("cuda").manual_seed(42) +).images[0] +``` + +## Memory Optimization + +For memory-constrained environments: + +```py +import torch +from diffusers.pipelines.photon import PhotonPipeline + +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() # Offload components to CPU when not in use + +# Or use sequential CPU offload for even lower memory +pipe.enable_sequential_cpu_offload() +``` + +## PhotonPipeline + +[[autodoc]] PhotonPipeline + - all + - __call__ + +## PhotonPipelineOutput + +[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py new file mode 100644 index 000000000000..ac6aba458046 --- /dev/null +++ b/scripts/convert_photon_to_diffusers.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +""" +Script to convert Photon checkpoint from original codebase to diffusers format. +""" + +import argparse +import json +import os +import sys +from dataclasses import asdict, dataclass +from typing import Dict, Tuple + +import torch +from safetensors.torch import save_file + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon import PhotonPipeline + + +DEFAULT_RESOLUTION = 512 + + +@dataclass(frozen=True) +class PhotonBase: + context_in_dim: int = 2304 + hidden_size: int = 1792 + mlp_ratio: float = 3.5 + num_heads: int = 28 + depth: int = 16 + axes_dim: Tuple[int, int] = (32, 32) + theta: int = 10_000 + time_factor: float = 1000.0 + time_max_period: int = 10_000 + + +@dataclass(frozen=True) +class PhotonFlux(PhotonBase): + in_channels: int = 16 + patch_size: int = 2 + + +@dataclass(frozen=True) +class PhotonDCAE(PhotonBase): + in_channels: int = 32 + patch_size: int = 1 + + +def build_config(vae_type: str) -> Tuple[dict, int]: + if vae_type == "flux": + cfg = PhotonFlux() + elif vae_type == "dc-ae": + cfg = PhotonDCAE() + else: + raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") + + config_dict = asdict(cfg) + config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + return config_dict + + +def create_parameter_mapping(depth: int) -> dict: + """Create mapping from old parameter names to new diffusers names.""" + + # Key mappings for structural changes + mapping = {} + + # RMSNorm: scale -> weight + for i in range(depth): + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" + + # Attention: attn_out -> attention.to_out.0 + mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" + + return mapping + + +def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]: + """Convert old checkpoint parameters to new diffusers format.""" + + print("Converting checkpoint parameters...") + + mapping = create_parameter_mapping(depth) + converted_state_dict = {} + + for key, value in old_state_dict.items(): + new_key = key + + # Apply specific mappings if needed + if key in mapping: + new_key = mapping[key] + print(f" Mapped: {key} -> {new_key}") + + # Handle img_qkv_proj -> split to to_q, to_k, to_v + if "img_qkv_proj.weight" in key: + print(f" Found QKV projection: {key}") + # Split QKV weight into separate Q, K, V projections + qkv_weight = value + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + + # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) + parts = key.split(".") + layer_idx = None + for i, part in enumerate(parts): + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + layer_idx = parts[i + 1] + break + + if layer_idx is not None: + converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight + print(f" Split QKV for layer {layer_idx}") + + # Also keep the original img_qkv_proj for backward compatibility + converted_state_dict[new_key] = value + else: + converted_state_dict[new_key] = value + + print(f"✓ Converted {len(converted_state_dict)} parameters") + return converted_state_dict + + +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: + """Create and load PhotonTransformer2DModel from old checkpoint.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load old checkpoint + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + old_checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Handle different checkpoint formats + if isinstance(old_checkpoint, dict): + if "model" in old_checkpoint: + state_dict = old_checkpoint["model"] + elif "state_dict" in old_checkpoint: + state_dict = old_checkpoint["state_dict"] + else: + state_dict = old_checkpoint + else: + state_dict = old_checkpoint + + print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") + + # Convert parameter names if needed + model_depth = int(config.get("depth", 16)) + converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) + + # Create transformer with config + print("Creating PhotonTransformer2DModel...") + transformer = PhotonTransformer2DModel(**config) + + # Load state dict + print("Loading 
converted parameters...") + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"⚠ Missing keys: {missing_keys}") + if unexpected_keys: + print(f"⚠ Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("✓ All parameters loaded successfully!") + + return transformer + + +def create_scheduler_config(output_path: str): + """Create FlowMatchEulerDiscreteScheduler config.""" + + scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": 1.0} + + scheduler_path = os.path.join(output_path, "scheduler") + os.makedirs(scheduler_path, exist_ok=True) + + with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f: + json.dump(scheduler_config, f, indent=2) + + print("✓ Created scheduler config") + + +def download_and_save_vae(vae_type: str, output_path: str): + """Download and save VAE to local directory.""" + from diffusers import AutoencoderDC, AutoencoderKL + + vae_path = os.path.join(output_path, "vae") + os.makedirs(vae_path, exist_ok=True) + + if vae_type == "flux": + print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...") + vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") + else: # dc-ae + print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...") + vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers") + + vae.save_pretrained(vae_path) + print(f"✓ Saved VAE to {vae_path}") + + +def download_and_save_text_encoder(output_path: str): + """Download and save T5Gemma text encoder and tokenizer.""" + from transformers import GemmaTokenizerFast + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + + text_encoder_path = os.path.join(output_path, "text_encoder") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(text_encoder_path, exist_ok=True) + os.makedirs(tokenizer_path, exist_ok=True) + + print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") + t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") + + t5gemma_model.save_pretrained(text_encoder_path) + print(f"✓ Saved T5Gemma model to {text_encoder_path}") + + print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") + tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + tokenizer.model_max_length = 256 + tokenizer.save_pretrained(tokenizer_path) + print(f"✓ Saved tokenizer to {tokenizer_path}") + + +def create_model_index(vae_type: str, default_image_size: int, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" + + model_index = { + "_class_name": "PhotonPipeline", + "_diffusers_version": "0.31.0.dev0", + "_name_or_path": os.path.basename(output_path), + "default_sample_size": default_image_size, + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": ["transformers", "T5GemmaModel"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], + "transformer": ["diffusers", "PhotonTransformer2DModel"], + "vae": ["diffusers", vae_class], + } + + model_index_path = os.path.join(output_path, "model_index.json") + with open(model_index_path, "w") as f: + json.dump(model_index, f, indent=2) + + +def main(args): + # Validate inputs + if not os.path.exists(args.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: 
{args.checkpoint_path}") + + config = build_config(args.vae_type) + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + print(f"✓ Output directory: {args.output_path}") + + # Create transformer from checkpoint + transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) + + # Save transformer + transformer_path = os.path.join(args.output_path, "transformer") + os.makedirs(transformer_path, exist_ok=True) + + # Save config + with open(os.path.join(transformer_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save model weights as safetensors + state_dict = transformer.state_dict() + save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + print(f"✓ Saved transformer to {transformer_path}") + + # Create scheduler config + create_scheduler_config(args.output_path) + + download_and_save_vae(args.vae_type, args.output_path) + download_and_save_text_encoder(args.output_path) + + # Create model_index.json + create_model_index(args.vae_type, args.resolution, args.output_path) + + # Verify the pipeline can be loaded + try: + pipeline = PhotonPipeline.from_pretrained(args.output_path) + print("Pipeline loaded successfully!") + print(f"Transformer: {type(pipeline.transformer).__name__}") + print(f"VAE: {type(pipeline.vae).__name__}") + print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f"Scheduler: {type(pipeline.scheduler).__name__}") + + # Display model info + num_params = sum(p.numel() for p in pipeline.transformer.parameters()) + print(f"✓ Transformer parameters: {num_params:,}") + + except Exception as e: + print(f"Pipeline verification failed: {e}") + return False + + print("Conversion completed successfully!") + print(f"Converted pipeline saved to: {args.output_path}") + print(f"VAE type: {args.vae_type}") + + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") + + parser.add_argument( + "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)" + ) + + parser.add_argument( + "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline" + ) + + parser.add_argument( + "--vae_type", + type=str, + choices=["flux", "dc-ae"], + required=True, + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", + ) + + parser.add_argument( + "--resolution", + type=int, + choices=[256, 512, 1024], + default=DEFAULT_RESOLUTION, + help="Target resolution for the model (256, 512, or 1024). 
Affects the transformer's sample_size.", + ) + + args = parser.parse_args() + + try: + success = main(args) + if not success: + sys.exit(1) + except Exception as e: + print(f"Conversion failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aa500b149441..c2528bc50fe5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -232,6 +232,7 @@ "MultiControlNetModel", "OmniGenTransformer2DModel", "ParallelConfig", + "PhotonTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", "QwenImageControlNetModel", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 8d029bf5d31c..f3164e48cfbf 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,6 +96,7 @@ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] + _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 66455d733aee..c325bc71ae84 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5605,6 +5605,63 @@ def __new__(cls, *args, **kwargs): return processor +class PhotonAttnProcessor2_0: + r""" + Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with + diffusers Attention module while handling Photon-specific logic. + """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Photon attention using standard diffusers interface. + + Expected tensor formats from PhotonBlock.attn_forward(): + - hidden_states: Image queries with RoPE applied [B, H, L_img, D] + - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + + image + spatial conditioning) + - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + """ + + if encoder_hidden_states is None: + raise ValueError( + "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by PhotonBlock.attn_forward()." 
+ ) + + # Unpack the combined key+value tensor + # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] + key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + + # Apply scaled dot-product attention with Photon's processed tensors + # hidden_states is image queries [B, H, L_img, D] + attn_output = torch.nn.functional.scaled_dot_product_attention( + hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + ) + + # Reshape from [B, H, L_img, D] to [B, L_img, H*D] + batch_size, num_heads, seq_len, head_dim = attn_output.shape + attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + + # Apply output projection + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + + return attn_output + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, @@ -5653,6 +5710,7 @@ def __new__(cls, *args, **kwargs): PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, + PhotonAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6b80ea6c82a5..ab5311518ba7 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -32,6 +32,7 @@ from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel + from .transformer_photon import PhotonTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py new file mode 100644 index 000000000000..7b29c8bfdafb --- /dev/null +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -0,0 +1,812 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
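+#
+# This module implements the Photon transformer: a simplified MMDiT-style architecture in
+# which text tokens serve only as cross-attention context (they are not updated by the
+# transformer blocks), image tokens carry 2D rotary positional embeddings (EmbedND /
+# apply_rope), and every block is modulated by the timestep embedding via shift/scale/gate
+# parameters (Modulation, LastLayer).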
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import Tensor, nn +from torch.nn.functional import fold, unfold + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention_processor import Attention, AttentionProcessor, PhotonAttnProcessor2_0 +from ..embeddings import get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +logger = logging.get_logger(__name__) + + +def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor: + r""" + Generates 2D patch coordinate indices for a batch of images. + + Parameters: + batch_size (`int`): + Number of images in the batch. + height (`int`): + Height of the input images (in pixels). + width (`int`): + Width of the input images (in pixels). + patch_size (`int`): + Size of the square patches that the image is divided into. + device (`torch.device`): + The device on which to create the tensor. + + Returns: + `torch.Tensor`: + Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the + image grid. + """ + + img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :] + return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1) + + +def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + r""" + Applies rotary positional embeddings (RoPE) to a query tensor. + + Parameters: + xq (`torch.Tensor`): + Input tensor of shape `(..., dim)` representing the queries. + freqs_cis (`torch.Tensor`): + Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs. + + Returns: + `torch.Tensor`: + Tensor of the same shape as `xq` with rotary embeddings applied. + """ + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq) + + +class EmbedND(nn.Module): + r""" + N-dimensional rotary positional embedding. + + This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding + dimension. The embeddings are combined and returned as a single tensor + + Parameters: + dim (int): + Base embedding dimension (must be even). + theta (int): + Scaling factor that controls the frequency spectrum of the rotary embeddings. + axes_dim (list[int]): + List of embedding dimensions for each axis (each must be even). 
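+
+    Example:
+        A minimal shape sketch (illustrative values matching the default Photon
+        configuration, where `dim = hidden_size // num_heads = 64` and `axes_dim = [32, 32]`):
+
+        ```python
+        import torch
+
+        embedder = EmbedND(dim=64, theta=10_000, axes_dim=[32, 32])
+        ids = torch.zeros(1, 256, 2)  # (batch, num_patches, 2) patch coordinates
+        pe = embedder(ids)  # per-position rotation components, broadcast over attention heads
+        ```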
+ """ + + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2) + + def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = pos.unsqueeze(-1) * omega.unsqueeze(0) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = self.rope_rearrange(out) + return out.float() + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +class MLPEmbedder(nn.Module): + r""" + A simple 2-layer MLP used for embedding inputs. + + Parameters: + in_dim (`int`): + Dimensionality of the input features. + hidden_dim (`int`): + Dimensionality of the hidden and output embedding space. + + Returns: + `torch.Tensor`: + Tensor of shape `(..., hidden_dim)` containing the embedded representations. + """ + + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class QKNorm(torch.nn.Module): + r""" + Applies RMS normalization to query and key tensors separately before attention which can help stabilize training + and improve numerical precision. + + Parameters: + dim (`int`): + Dimensionality of the query and key vectors. + + Returns: + (`torch.Tensor`, `torch.Tensor`): + A tuple `(q, k)` where both are normalized and cast to the same dtype as the value tensor `v`. + """ + + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim, eps=1e-6) + self.key_norm = RMSNorm(dim, eps=1e-6) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + r""" + Modulation network that generates scale, shift, and gating parameters. + + Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into + two `ModulationOut` objects. + + Parameters: + dim (`int`): + Dimensionality of the input vector. The output will have `6 * dim` features internally. + + Returns: + (`ModulationOut`, `ModulationOut`): + A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift, + gate). + """ + + def __init__(self, dim: int): + super().__init__() + self.lin = nn.Linear(dim, 6 * dim, bias=True) + nn.init.constant_(self.lin.weight, 0) + nn.init.constant_(self.lin.bias, 0) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) + return ModulationOut(*out[:3]), ModulationOut(*out[3:]) + + +class PhotonBlock(nn.Module): + r""" + Multimodal transformer block with text–image cross-attention, modulation, and MLP. + + Parameters: + hidden_size (`int`): + Dimension of the hidden representations. + num_heads (`int`): + Number of attention heads. 
+ mlp_ratio (`float`, *optional*, defaults to 4.0): + Expansion ratio for the hidden dimension inside the MLP. + qk_scale (`float`, *optional*): + Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``. + + Attributes: + img_pre_norm (`nn.LayerNorm`): + Pre-normalization applied to image tokens before QKV projection. + img_qkv_proj (`nn.Linear`): + Linear projection to produce image queries, keys, and values. + qk_norm (`QKNorm`): + RMS normalization applied separately to image queries and keys. + txt_kv_proj (`nn.Linear`): + Linear projection to produce text keys and values. + k_norm (`RMSNorm`): + RMS normalization applied to text keys. + attention (`Attention`): + Multi-head attention module for cross-attention between image, text, and optional spatial conditioning + tokens. + post_attention_layernorm (`nn.LayerNorm`): + Normalization applied after attention. + gate_proj / up_proj / down_proj (`nn.Linear`): + Feedforward layers forming the gated MLP. + mlp_act (`nn.GELU`): + Nonlinear activation used in the MLP. + modulation (`Modulation`): + Produces scale/shift/gating parameters for modulated layers. + + Methods: + attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None): + Compute cross-attention between image and text tokens, with optional spatial conditioning and attention + masking. + + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + pe (`torch.Tensor`): + Rotary positional embeddings to apply to queries and keys. + modulation (`ModulationOut`): + Scale and shift parameters for modulating image tokens. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)` where 0 marks padding. + + Returns: + `torch.Tensor`: + Attention output of shape `(B, L_img, hidden_size)`. 
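+
+    Example:
+        A rough usage sketch with the default Photon dimensions (hidden_size=1792,
+        num_heads=28); `pe` is built here the same way `PhotonTransformer2DModel` builds it
+        internally, and all tensor values are illustrative:
+
+        ```python
+        import torch
+
+        block = PhotonBlock(hidden_size=1792, num_heads=28, mlp_ratio=3.5)
+        ids = get_image_ids(1, 32, 32, patch_size=2, device=torch.device("cpu"))
+        pe = EmbedND(dim=64, theta=10_000, axes_dim=[32, 32])(ids)  # 16x16 = 256 patches
+        img = torch.randn(1, 256, 1792)  # image tokens
+        txt = torch.randn(1, 8, 1792)    # text tokens already projected to hidden_size
+        vec = torch.randn(1, 1792)       # timestep embedding
+        out = block(img=img, txt=txt, vec=vec, pe=pe)  # (1, 256, 1792)
+        ```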
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.hidden_size = hidden_size + + # img qkv + self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False) + self.qk_norm = QKNorm(self.head_dim) + + # txt kv + self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) + self.k_norm = RMSNorm(self.head_dim, eps=1e-6) + + self.attention = Attention( + query_dim=hidden_size, + heads=num_heads, + dim_head=self.head_dim, + bias=False, + out_bias=False, + processor=PhotonAttnProcessor2_0(), + ) + + # mlp + self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False) + self.mlp_act = nn.GELU(approximate="tanh") + + self.modulation = Modulation(hidden_size) + + def _attn_forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + modulation: ModulationOut, + spatial_conditioning: None | Tensor = None, + attention_mask: None | Tensor = None, + ) -> Tensor: + # image tokens proj and norm + img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift + + img_qkv = self.img_qkv_proj(img_mod) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.qk_norm(img_q, img_k, img_v) + + # txt tokens proj and norm + txt_kv = self.txt_kv_proj(txt) + txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) + txt_k = self.k_norm(txt_k) + + # compute attention + img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys + attn_mask: Tensor | None = None + if attention_mask is not None: + bs, _, l_img, _ = img_q.shape + l_txt = txt_k.shape[2] + + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") + + device = img_q.device + + ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) + + mask_parts = [ + attention_mask.to(torch.bool), + ones_img, + ] + joint_mask = torch.cat(mask_parts, dim=-1) # (B, L_all) + + # repeat across heads and query positions + attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all) + + kv_packed = torch.cat([k, v], dim=-1) + + attn = self.attention( + hidden_states=img_q, + encoder_hidden_states=kv_packed, + attention_mask=attn_mask, + ) + + return attn + + def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: + x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift + return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + spatial_conditioning: Tensor 
| None = None, + attention_mask: Tensor | None = None, + **_: dict[str, Any], + ) -> Tensor: + r""" + Runs modulation-gated cross-attention and MLP, with residual connections. + + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + vec (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or + broadcastable). + pe (`torch.Tensor`): + Rotary positional embeddings applied inside attention. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is + enabled in the block. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. + **_: + Ignored additional keyword arguments for API compatibility. + + Returns: + `torch.Tensor`: + Updated image tokens of shape `(B, L_img, hidden_size)`. + """ + + mod_attn, mod_mlp = self.modulation(vec) + + img = img + mod_attn.gate * self._attn_forward( + img, + txt, + pe, + mod_attn, + spatial_conditioning=spatial_conditioning, + attention_mask=attention_mask, + ) + img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp) + return img + + +class LastLayer(nn.Module): + r""" + Final projection layer with adaptive LayerNorm modulation. + + This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level + outputs. + + Parameters: + hidden_size (`int`): + Dimensionality of the input tokens. + patch_size (`int`): + Size of the square image patches. + out_channels (`int`): + Number of output channels per pixel (e.g. RGB = 3). + + Forward Inputs: + x (`torch.Tensor`): + Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches. + vec (`torch.Tensor`): + Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive + LayerNorm. + + Returns: + `torch.Tensor`: + Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`. + """ + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +def img2seq(img: Tensor, patch_size: int) -> Tensor: + r""" + Flattens an image tensor into a sequence of non-overlapping patches. + + Parameters: + img (`torch.Tensor`): + Input image tensor of shape `(B, C, H, W)`. + patch_size (`int`): + Size of each square patch. Must evenly divide both `H` and `W`. + + Returns: + `torch.Tensor`: + Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W + // patch_size)` is the number of patches. + """ + return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + + +def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: + r""" + Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). 
+ + Parameters: + seq (`torch.Tensor`): + Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // + patch_size)`. + patch_size (`int`): + Size of each square patch. + shape (`tuple` or `torch.Tensor`): + The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as + height and width. + + Returns: + `torch.Tensor`: + Reconstructed image tensor of shape `(B, C, H, W)`. + """ + if isinstance(shape, tuple): + shape = shape[-2:] + elif isinstance(shape, torch.Tensor): + shape = (int(shape[0]), int(shape[1])) + else: + raise NotImplementedError(f"shape type {type(shape)} not supported") + return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + +class PhotonTransformer2DModel(ModelMixin, ConfigMixin): + r""" + Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA + scaling. + + Parameters: + in_channels (`int`, *optional*, defaults to 16): + Number of input channels in the latent image. + patch_size (`int`, *optional*, defaults to 2): + Size of the square patches used to flatten the input image. + context_in_dim (`int`, *optional*, defaults to 2304): + Dimensionality of the text conditioning input. + hidden_size (`int`, *optional*, defaults to 1792): + Dimension of the hidden representation. + mlp_ratio (`float`, *optional*, defaults to 3.5): + Expansion ratio for the hidden dimension inside MLP blocks. + num_heads (`int`, *optional*, defaults to 28): + Number of attention heads. + depth (`int`, *optional*, defaults to 16): + Number of transformer blocks. + axes_dim (`list[int]`, *optional*): + List of dimensions for each positional embedding axis. Defaults to `[32, 32]`. + theta (`int`, *optional*, defaults to 10000): + Frequency scaling factor for rotary embeddings. + time_factor (`float`, *optional*, defaults to 1000.0): + Scaling factor applied in timestep embeddings. + time_max_period (`int`, *optional*, defaults to 10000): + Maximum frequency period for timestep embeddings. + + Attributes: + pe_embedder (`EmbedND`): + Multi-axis rotary embedding generator for positional encodings. + img_in (`nn.Linear`): + Projection layer for image patch tokens. + time_in (`MLPEmbedder`): + Embedding layer for timestep embeddings. + txt_in (`nn.Linear`): + Projection layer for text conditioning. + blocks (`nn.ModuleList`): + Stack of transformer blocks (`PhotonBlock`). + final_layer (`LastLayer`): + Projection layer mapping hidden tokens back to patch outputs. + + Methods: + attn_processors: + Returns a dictionary of all attention processors in the model. + set_attn_processor(processor): + Replaces attention processors across all attention layers. + process_inputs(image_latent, txt): + Converts inputs into patch tokens, encodes text, and produces positional encodings. + compute_timestep_embedding(timestep, dtype): + Creates a timestep embedding of dimension 256, scaled and projected. + forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask, + **block_kwargs): + Runs the sequence of transformer blocks over image and text tokens. + forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None, + attention_kwargs=None, return_dict=True): + Full forward pass from latent input to reconstructed output image. 
+ + Returns: + `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing: + - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`. + """ + + config_name = "config.json" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + patch_size: int = 2, + context_in_dim: int = 2304, + hidden_size: int = 1792, + mlp_ratio: float = 3.5, + num_heads: int = 28, + depth: int = 16, + axes_dim: list = None, + theta: int = 10000, + time_factor: float = 1000.0, + time_max_period: int = 10000, + ): + super().__init__() + + if axes_dim is None: + axes_dim = [32, 32] + + # Store parameters directly + self.in_channels = in_channels + self.patch_size = patch_size + self.out_channels = self.in_channels * self.patch_size**2 + + self.time_factor = time_factor + self.time_max_period = time_max_period + + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") + + pe_dim = hidden_size // num_heads + + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) + + self.blocks = nn.ModuleList( + [ + PhotonBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=mlp_ratio, + ) + for i in range(depth) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. 
Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: + txt = self.txt_in(txt) + img = img2seq(image_latent, self.patch_size) + bs, _, h, w = image_latent.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + pe = self.pe_embedder(img_ids) + return img, txt, pe + + def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + return self.time_in( + get_timestep_embedding( + timesteps=timestep, + embedding_dim=256, + max_period=self.time_max_period, + scale=self.time_factor, + flip_sin_to_cos=True, # Match original cos, sin order + ).to(dtype) + ) + + def _forward_transformers( + self, + image_latent: Tensor, + cross_attn_conditioning: Tensor, + timestep: Optional[Tensor] = None, + time_embedding: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + **block_kwargs: Any, + ) -> Tensor: + img = self.img_in(image_latent) + + if time_embedding is not None: + vec = time_embedding + else: + if timestep is None: + raise ValueError("Please provide either a timestep or a timestep_embedding") + vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img = self._gradient_checkpointing_func( + block.__call__, + img, + cross_attn_conditioning, + vec, + block_kwargs.get("pe"), + block_kwargs.get("spatial_conditioning"), + attention_mask, + ) + else: + img = block( + img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs + ) + + img = self.final_layer(img, vec) + return img + + def forward( + self, + image_latent: Tensor, + timestep: Tensor, + cross_attn_conditioning: Tensor, + micro_conditioning: Tensor, + cross_attn_mask: None | Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + r""" + Forward pass of the PhotonTransformer2DModel. + + The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of + transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. + + Parameters: + image_latent (`torch.Tensor`): + Input latent image tensor of shape `(B, C, H, W)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. + cross_attn_conditioning (`torch.Tensor`): + Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. + micro_conditioning (`torch.Tensor`): + Extra conditioning vector (currently unused, reserved for future use). + cross_attn_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. + attention_kwargs (`dict`, *optional*): + Additional arguments passed to attention layers. If using the PEFT backend, the key `"scale"` controls + LoRA scaling (default: 1.0). 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `Transformer2DModelOutput` or a tuple. + + Returns: + `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple: + + - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning) + img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) + output = seq2img(img_seq, self.patch_size, image_latent.shape) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c438caed571f..1fa8dcf0c8b8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] + _import_structure["photon"] = ["PhotonPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py new file mode 100644 index 000000000000..559c9d0b1d2d --- /dev/null +++ b/src/diffusers/pipelines/photon/__init__.py @@ -0,0 +1,5 @@ +from .pipeline_output import PhotonPipelineOutput +from .pipeline_photon import PhotonPipeline + + +__all__ = ["PhotonPipeline", "PhotonPipelineOutput"] diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py new file mode 100644 index 000000000000..ca0674d94b6c --- /dev/null +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class PhotonPipelineOutput(BaseOutput): + """ + Output class for Photon pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py new file mode 100644 index 000000000000..750d3aacc754 --- /dev/null +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -0,0 +1,643 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import ftfy +import torch +from transformers import ( + AutoTokenizer, + GemmaTokenizerFast, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.image_processor import PixArtImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img +from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor + + +DEFAULT_RESOLUTION = 512 + +ASPECT_RATIO_256_BIN = { + "0.46": [160, 352], + "0.6": [192, 320], + "0.78": [224, 288], + "1.0": [256, 256], + "1.29": [288, 224], + "1.67": [320, 192], + "2.2": [352, 160], +} + +ASPECT_RATIO_512_BIN = { + "0.5": [352, 704], + "0.57": [384, 672], + "0.6": [384, 640], + "0.68": [416, 608], + "0.78": [448, 576], + "0.88": [480, 544], + "1.0": [512, 512], + "1.13": [544, 480], + "1.29": [576, 448], + "1.46": [608, 416], + "1.67": [640, 384], + "1.75": [672, 384], + "2.0": [704, 352], +} + +logger = logging.get_logger(__name__) + + +class TextPreprocessor: + """Text preprocessing utility for PhotonPipeline.""" + + def __init__(self): + """Initialize text preprocessor.""" + self.bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) + + def clean_text(self, text: str) -> str: + """Clean text using comprehensive text processing logic.""" + # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py + text = str(text) + text = ul.unquote_plus(text) + text = text.strip().lower() + text = re.sub("", "person", text) + + # Remove all urls: + text = re.sub( + r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))", + "", + text, + ) # regex for urls + + # @ + text = re.sub(r"@[\w\d]+\b", "", text) + + # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs + text = re.sub(r"[\u31c0-\u31ef]+", "", text) + text = re.sub(r"[\u31f0-\u31ff]+", "", text) + text = re.sub(r"[\u3200-\u32ff]+", "", text) + text = re.sub(r"[\u3300-\u33ff]+", "", 
+        text = re.sub(r"[\u3400-\u4dbf]+", "", text)
+        text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
+        text = re.sub(r"[\u4e00-\u9fff]+", "", text)
+
+        # all types of dash --> "-"
+        text = re.sub(
+            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
+            "-",
+            text,
+        )
+
+        # standardize quotation marks
+        text = re.sub(r"[`´«»“”¨]", '"', text)
+        text = re.sub(r"[‘’]", "'", text)
+
+        # &quot; and &amp
+        text = re.sub(r"&quot;?", "", text)
+        text = re.sub(r"&amp", "", text)
+
+        # ip addresses:
+        text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
+
+        # article ids:
+        text = re.sub(r"\d:\d\d\s+$", "", text)
+
+        # \n
+        text = re.sub(r"\\n", " ", text)
+
+        # "#123", "#12345..", "123456.."
+        text = re.sub(r"#\d{1,3}\b", "", text)
+        text = re.sub(r"#\d{5,}\b", "", text)
+        text = re.sub(r"\b\d{6,}\b", "", text)
+
+        # filenames:
+        text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
+
+        # Clean punctuation
+        text = re.sub(r"[\"\']{2,}", r'"', text)  # """AUSVERKAUFT"""
+        text = re.sub(r"[\.]{2,}", r" ", text)
+
+        text = re.sub(self.bad_punct_regex, r" ", text)  # ***AUSVERKAUFT***, #AUSVERKAUFT
+        text = re.sub(r"\s+\.\s+", r" ", text)  # " . "
+
+        # this-is-my-cute-cat / this_is_my_cute_cat
+        regex2 = re.compile(r"(?:\-|\_)")
+        if len(re.findall(regex2, text)) > 3:
+            text = re.sub(regex2, " ", text)
+
+        # Basic cleaning
+        text = ftfy.fix_text(text)
+        text = html.unescape(html.unescape(text))
+        text = text.strip()
+
+        # Clean alphanumeric patterns
+        text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text)  # jc6640
+        text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text)  # jc6640vc
+        text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text)  # 6640vc231
+
+        # Common spam patterns
+        text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
+        text = re.sub(r"(free\s)?download(\sfree)?", "", text)
+        text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
+        text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
+        text = re.sub(r"\bpage\s+\d+\b", "", text)
+
+        text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text)  # j2d1a2a...
+        text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
+
+        # Final cleanup
+        text = re.sub(r"\b\s+\:\s+", r": ", text)
+        text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
+        text = re.sub(r"\s+", " ", text)
+
+        text = text.strip()
+
+        text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
+        text = re.sub(r"^[\'\_,\-\:;]", r"", text)
+        text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
+        text = re.sub(r"^\.\S+$", "", text)
+
+        return text.strip()
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import PhotonPipeline
+
+        >>> # Load pipeline with from_pretrained
+        >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
+        >>> pipe.to("cuda")
+
+        >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
+        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+        >>> image.save("photon_output.png")
+        ```
+"""
+
+
+class PhotonPipeline(
+    DiffusionPipeline,
+    LoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+):
+    r"""
+    Pipeline for text-to-image generation using Photon Transformer.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ + Args: + transformer ([`PhotonTransformer2DModel`]): + The Photon transformer model to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5EncoderModel`]): + Standard text encoder model for encoding prompts. + tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): + Tokenizer for the text encoder. + vae ([`AutoencoderKL`] or [`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents"] + _optional_components = [] + + def __init__( + self, + transformer: PhotonTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: Union[T5EncoderModel, Any], + tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], + vae: Union[AutoencoderKL, AutoencoderDC], + default_sample_size: Optional[int] = DEFAULT_RESOLUTION, + ): + super().__init__() + + if PhotonTransformer2DModel is None: + raise ImportError( + "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." + ) + + # Extract encoder if text_encoder is T5GemmaModel + if hasattr(text_encoder, "encoder"): + text_encoder = text_encoder.encoder + + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.text_preprocessor = TextPreprocessor() + + self.register_modules( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) + + self.register_to_config(default_sample_size=default_sample_size) + + # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC + self._enhance_vae_properties() + + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _enhance_vae_properties(self): + """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC.""" + if not hasattr(self, "vae") or self.vae is None: + return + + if hasattr(self.vae, "spatial_compression_ratio"): + # AutoencoderDC already has this property + pass + elif hasattr(self.vae, "config") and hasattr(self.vae.config, "block_out_channels"): + # AutoencoderKL: calculate from block_out_channels + self.vae.spatial_compression_ratio = 2 ** (len(self.vae.config.block_out_channels) - 1) + else: + # Fallback + self.vae.spatial_compression_ratio = 8 + + if hasattr(self.vae, "config"): + self.vae.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + else: + self.vae.scaling_factor = 0.18215 + + if hasattr(self.vae, "config"): + shift_factor = getattr(self.vae.config, "shift_factor", None) + if shift_factor is None: # AutoencoderDC case + self.vae.shift_factor = 0.0 + else: + self.vae.shift_factor = shift_factor + else: + self.vae.shift_factor = 0.0 + + if hasattr(self.vae, "config") and hasattr(self.vae.config, "latent_channels"): + # AutoencoderDC has latent_channels in config + self.vae.latent_channels = int(self.vae.config.latent_channels) + elif hasattr(self.vae, "config") and hasattr(self.vae.config, "in_channels"): + # AutoencoderKL has in_channels in config + self.vae.latent_channels = int(self.vae.config.in_channels) + else: + # Fallback based on VAE type - DC-AE typically has 32, AutoencoderKL has 4/16 + if hasattr(self.vae, 
"spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: + self.vae.latent_channels = 32 # DC-AE default + else: + self.vae.latent_channels = 16 # FluxVAE default + + @property + def vae_scale_factor(self): + """Compatibility property that returns spatial compression ratio.""" + return getattr(self.vae, "spatial_compression_ratio", 8) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + """Prepare initial latents for the diffusion process.""" + if latents is None: + latent_height, latent_width = ( + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device): + """Encode text prompt using standard text encoder and tokenizer.""" + if isinstance(prompt, str): + prompt = [prompt] + + return self._encode_prompt_standard(prompt, device) + + def _encode_prompt_standard(self, prompt: List[str], device: torch.device): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + # Clean text using modular preprocessor + cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt] + cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt] + + all_prompts = cleaned_prompts + cleaned_uncond_prompts + + tokens = self.tokenizer( + all_prompts, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + + input_ids = tokens["input_ids"].to(device) + attention_mask = tokens["attention_mask"].bool().to(device) + + with torch.no_grad(): + with torch.autocast("cuda", enabled=False): + emb = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + all_embeddings = emb["last_hidden_state"] + + # Split back into conditional and unconditional + batch_size = len(prompt) + text_embeddings = all_embeddings[:batch_size] + uncond_text_embeddings = all_embeddings[batch_size:] + + cross_attn_mask = attention_mask[:batch_size] + uncond_cross_attn_mask = attention_mask[batch_size:] + + return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + guidance_scale: float, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ): + """Check that all inputs are in correct format.""" + if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}." 
+ ) + + if guidance_scale < 1.0: + raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") + + if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + use_resolution_binning: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple. 
+ use_resolution_binning (`bool`, *optional*, defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images at optimal resolutions. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed + in the `._callback_tensor_inputs` attribute. + + Examples: + + Returns: + [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + # 0. Default height and width from config + default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION) + height = height or default_sample_size + width = width or default_sample_size + + if use_resolution_binning: + if default_sample_size <= 256: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + + # Store original dimensions + orig_height, orig_width = height, width + # Map to closest resolution in the bin + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError("prompt must be provided as a string or list of strings") + + device = self._execution_device + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, device + ) + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 4. Prepare latent variables + num_channels_latents = self.vae.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents for CFG + latents_in = torch.cat([latents, latents], dim=0) + + # Cross-attention batch (uncond, cond) + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + + # Process inputs for transformer + img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed) + + # Forward through transformer layers + img_seq = self.transformer._forward_transformers( + img_seq, + txt, + time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, + attention_mask=ca_mask, + ) + + # Convert back to image format + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) + + # Apply CFG + noise_uncond, noise_text = noise_both.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post-processing + if output_type == "latent": + image = latents + else: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + # Resize back to original resolution if using binning + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return PhotonPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_photon.py new file mode 100644 index 000000000000..1491b83bf65c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_photon.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PhotonTransformer2DModel + main_input_name = "image_latent" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 16, 16) + + @property + def output_shape(self): + return (16, 16, 16) + + def prepare_dummy_input(self, height=16, width=16): + batch_size = 1 + num_latent_channels = 16 + sequence_length = 16 + embedding_dim = 1792 + + image_latent = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) + cross_attn_conditioning = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + micro_conditioning = torch.randn((batch_size, embedding_dim)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "image_latent": image_latent, + "timestep": timestep, + "cross_attn_conditioning": cross_attn_conditioning, + "micro_conditioning": micro_conditioning, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, # Smaller depth for testing + "axes_dim": [32, 32], + "theta": 10_000, + } + inputs_dict = self.prepare_dummy_input() + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"PhotonTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +if __name__ == "__main__": + unittest.main() From 8e78a99c65f727f644582bc8c032ba76e46aca1b Mon Sep 17 00:00:00 2001 From: davidb Date: Thu, 9 Oct 2025 16:06:35 +0000 Subject: [PATCH 02/52] just store the T5Gemma encoder --- scripts/convert_photon_to_diffusers.py | 8 +++++--- src/diffusers/pipelines/photon/pipeline_photon.py | 11 ++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index ac6aba458046..fc4161ff6275 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -217,8 +217,10 @@ def download_and_save_text_encoder(output_path: str): print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") - t5gemma_model.save_pretrained(text_encoder_path) - print(f"✓ Saved T5Gemma model to {text_encoder_path}") + # Extract and save only the encoder + t5gemma_encoder = t5gemma_model.encoder + t5gemma_encoder.save_pretrained(text_encoder_path) + print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}") print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") @@ -241,7 +243,7 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str) "_name_or_path": os.path.basename(output_path), "default_sample_size": default_image_size, "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["transformers", "T5GemmaModel"], + "text_encoder": ["transformers.models.t5gemma.modeling_t5gemma", "T5GemmaEncoder"], "tokenizer": 
["transformers", "GemmaTokenizerFast"], "transformer": ["diffusers", "PhotonTransformer2DModel"], "vae": ["diffusers", vae_class], diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 750d3aacc754..7673b89ca1dc 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -26,6 +26,7 @@ T5EncoderModel, T5TokenizerFast, ) +from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder from diffusers.image_processor import PixArtImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -233,8 +234,8 @@ class PhotonPipeline( The Photon transformer model to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - text_encoder ([`T5EncoderModel`]): - Standard text encoder model for encoding prompts. + text_encoder ([`T5EncoderModel`] or [`T5GemmaEncoder`]): + Text encoder model for encoding prompts. Supports T5EncoderModel or T5GemmaEncoder. tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): Tokenizer for the text encoder. vae ([`AutoencoderKL`] or [`AutoencoderDC`]): @@ -250,7 +251,7 @@ def __init__( self, transformer: PhotonTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, - text_encoder: Union[T5EncoderModel, Any], + text_encoder: Union[T5EncoderModel, T5GemmaEncoder], tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], vae: Union[AutoencoderKL, AutoencoderDC], default_sample_size: Optional[int] = DEFAULT_RESOLUTION, @@ -262,10 +263,6 @@ def __init__( "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." ) - # Extract encoder if text_encoder is T5GemmaModel - if hasattr(text_encoder, "encoder"): - text_encoder = text_encoder.encoder - self.text_encoder = text_encoder self.tokenizer = tokenizer self.text_preprocessor = TextPreprocessor() From 3aeada78e9ce3c5e3bebaa202d16b414f6c34a0c Mon Sep 17 00:00:00 2001 From: davidb Date: Thu, 9 Oct 2025 16:17:40 +0000 Subject: [PATCH 03/52] enhance_vae_properties if vae is provided only --- src/diffusers/pipelines/photon/pipeline_photon.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 7673b89ca1dc..b30da7d62739 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -234,8 +234,8 @@ class PhotonPipeline( The Photon transformer model to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - text_encoder ([`T5EncoderModel`] or [`T5GemmaEncoder`]): - Text encoder model for encoding prompts. Supports T5EncoderModel or T5GemmaEncoder. + text_encoder ([`T5GemmaEncoder`]): + Text encoder model for encoding prompts. tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): Tokenizer for the text encoder. 
vae ([`AutoencoderKL`] or [`AutoencoderDC`]): @@ -251,9 +251,9 @@ def __init__( self, transformer: PhotonTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, - text_encoder: Union[T5EncoderModel, T5GemmaEncoder], + text_encoder: Union[T5GemmaEncoder], tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], - vae: Union[AutoencoderKL, AutoencoderDC], + vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None, default_sample_size: Optional[int] = DEFAULT_RESOLUTION, ): super().__init__() @@ -277,8 +277,9 @@ def __init__( self.register_to_config(default_sample_size=default_sample_size) - # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC - self._enhance_vae_properties() + if vae is not None: + # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC + self._enhance_vae_properties() self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) From d5c32727261db336da887dd1acfa85e78df3608e Mon Sep 17 00:00:00 2001 From: davidb Date: Thu, 9 Oct 2025 16:27:09 +0000 Subject: [PATCH 04/52] remove autocast for text encoder forwad --- src/diffusers/pipelines/photon/pipeline_photon.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index b30da7d62739..11d4fd0f0621 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -381,12 +381,11 @@ def _encode_prompt_standard(self, prompt: List[str], device: torch.device): attention_mask = tokens["attention_mask"].bool().to(device) with torch.no_grad(): - with torch.autocast("cuda", enabled=False): - emb = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - ) + emb = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) all_embeddings = emb["last_hidden_state"] From 234c5e352f04c72f6aff8fbc28f1f776b9622536 Mon Sep 17 00:00:00 2001 From: David Briand Date: Thu, 9 Oct 2025 16:19:26 +0000 Subject: [PATCH 05/52] BF16 example --- docs/source/en/api/pipelines/photon.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 270284673f92..71c9a02bcf10 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -76,7 +76,7 @@ from transformers import T5GemmaModel, GemmaTokenizerFast # Load transformer transformer = PhotonTransformer2DModel.from_pretrained( "Photoroom/photon-512-t2i", subfolder="transformer" -) +).to(dtype=torch.bfloat16) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( @@ -134,15 +134,15 @@ Key parameters for image generation: # Example with custom parameters import torch from diffusers.pipelines.photon import PhotonPipeline - -pipe = pipe( - prompt="A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. 
In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery.", - num_inference_steps=28, - guidance_scale=4.0, - height=512, - width=512, - generator=torch.Generator("cuda").manual_seed(42) -).images[0] +with torch.autocast("cuda", dtype=torch.bfloat16): + pipe = pipe( + prompt="A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery.", + num_inference_steps=28, + guidance_scale=4.0, + height=512, + width=512, + generator=torch.Generator("cuda").manual_seed(42) + ).images[0] ``` ## Memory Optimization From 49528a4bd0c439123e15a9a10df0a0dd3afd40af Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 09:46:20 +0000 Subject: [PATCH 06/52] conditioned CFG --- .../pipelines/photon/pipeline_photon.py | 113 +++++++++++------- 1 file changed, 72 insertions(+), 41 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 11d4fd0f0621..b05ca1f5ea1a 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -245,13 +245,13 @@ class PhotonPipeline( model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents"] - _optional_components = [] + _optional_components = ["vae"] def __init__( self, transformer: PhotonTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, - text_encoder: Union[T5GemmaEncoder], + text_encoder: T5GemmaEncoder, tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None, default_sample_size: Optional[int] = DEFAULT_RESOLUTION, @@ -330,6 +330,11 @@ def vae_scale_factor(self): """Compatibility property that returns spatial compression ratio.""" return getattr(self.vae, "spatial_compression_ratio", 8) + @property + def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled based on guidance scale.""" + return self._guidance_scale > 1.0 + def prepare_latents( self, batch_size: int, @@ -353,49 +358,67 @@ def prepare_latents( latents = latents.to(device) return latents - def 
encode_prompt(self, prompt: Union[str, List[str]], device: torch.device): + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + ): """Encode text prompt using standard text encoder and tokenizer.""" if isinstance(prompt, str): prompt = [prompt] - return self._encode_prompt_standard(prompt, device) - - def _encode_prompt_standard(self, prompt: List[str], device: torch.device): - """Encode prompt using standard text encoder and tokenizer with batch processing.""" - # Clean text using modular preprocessor - cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt] - cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt] - - all_prompts = cleaned_prompts + cleaned_uncond_prompts + return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) + def _tokenize_prompts(self, prompts: List[str], device: torch.device): + """Tokenize and clean prompts.""" + cleaned = [self.text_preprocessor.clean_text(text) for text in prompts] tokens = self.tokenizer( - all_prompts, + cleaned, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_attention_mask=True, return_tensors="pt", ) + return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device) + + def _encode_prompt_standard( + self, + prompt: List[str], + device: torch.device, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + ): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + batch_size = len(prompt) - input_ids = tokens["input_ids"].to(device) - attention_mask = tokens["attention_mask"].bool().to(device) + if do_classifier_free_guidance: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + prompts_to_encode = negative_prompt + prompt + else: + prompts_to_encode = prompt + + input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device) with torch.no_grad(): - emb = self.text_encoder( + embeddings = self.text_encoder( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, - ) - - all_embeddings = emb["last_hidden_state"] + )["last_hidden_state"] - # Split back into conditional and unconditional - batch_size = len(prompt) - text_embeddings = all_embeddings[:batch_size] - uncond_text_embeddings = all_embeddings[batch_size:] - - cross_attn_mask = attention_mask[:batch_size] - uncond_cross_attn_mask = attention_mask[batch_size:] + if do_classifier_free_guidance: + uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0) + uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0) + else: + text_embeddings = embeddings + cross_attn_mask = attention_mask + uncond_text_embeddings = None + uncond_cross_attn_mask = None return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask @@ -534,9 +557,11 @@ def __call__( device = self._execution_device + self._guidance_scale = guidance_scale + # 2. Encode input prompt text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( - prompt, device + prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance ) # 3. 
Prepare timesteps @@ -572,17 +597,22 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Duplicate latents for CFG - latents_in = torch.cat([latents, latents], dim=0) - - # Cross-attention batch (uncond, cond) - ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) - ca_mask = None - if cross_attn_mask is not None and uncond_cross_attn_mask is not None: - ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) - - # Normalize timestep for the transformer - t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + # Duplicate latents if using classifier-free guidance + if self.do_classifier_free_guidance: + latents_in = torch.cat([latents, latents], dim=0) + # Cross-attention batch (uncond, cond) + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + else: + latents_in = latents + ca_embed = text_embeddings + ca_mask = cross_attn_mask + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) # Process inputs for transformer img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed) @@ -597,11 +627,12 @@ def __call__( ) # Convert back to image format - noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) + noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG - noise_uncond, noise_text = noise_both.chunk(2, dim=0) - noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + if self.do_classifier_free_guidance: + noise_uncond, noise_text = noise_pred.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) # Compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample From 0a4183cfa874467e3d95ce8b2a33c2fe2f309278 Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 10:55:01 +0000 Subject: [PATCH 07/52] remove enhance vae and use vae.config directly when possible --- .../pipelines/photon/pipeline_photon.py | 72 +++++-------------- 1 file changed, 16 insertions(+), 56 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index b05ca1f5ea1a..b6f3a78e692c 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -277,58 +277,14 @@ def __init__( self.register_to_config(default_sample_size=default_sample_size) - if vae is not None: - # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC - self._enhance_vae_properties() - - self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) - - def _enhance_vae_properties(self): - """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC.""" - if not hasattr(self, "vae") or self.vae is None: - return - - if hasattr(self.vae, "spatial_compression_ratio"): - # AutoencoderDC already has this property - pass - elif hasattr(self.vae, "config") and hasattr(self.vae.config, "block_out_channels"): - # AutoencoderKL: 
calculate from block_out_channels - self.vae.spatial_compression_ratio = 2 ** (len(self.vae.config.block_out_channels) - 1) - else: - # Fallback - self.vae.spatial_compression_ratio = 8 - - if hasattr(self.vae, "config"): - self.vae.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) - else: - self.vae.scaling_factor = 0.18215 - - if hasattr(self.vae, "config"): - shift_factor = getattr(self.vae.config, "shift_factor", None) - if shift_factor is None: # AutoencoderDC case - self.vae.shift_factor = 0.0 - else: - self.vae.shift_factor = shift_factor - else: - self.vae.shift_factor = 0.0 - - if hasattr(self.vae, "config") and hasattr(self.vae.config, "latent_channels"): - # AutoencoderDC has latent_channels in config - self.vae.latent_channels = int(self.vae.config.latent_channels) - elif hasattr(self.vae, "config") and hasattr(self.vae.config, "in_channels"): - # AutoencoderKL has in_channels in config - self.vae.latent_channels = int(self.vae.config.in_channels) - else: - # Fallback based on VAE type - DC-AE typically has 32, AutoencoderKL has 4/16 - if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: - self.vae.latent_channels = 32 # DC-AE default - else: - self.vae.latent_channels = 16 # FluxVAE default + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) @property - def vae_scale_factor(self): - """Compatibility property that returns spatial compression ratio.""" - return getattr(self.vae, "spatial_compression_ratio", 8) + def vae_spatial_compression_ratio(self): + if hasattr(self.vae, "spatial_compression_ratio"): + return self.vae.spatial_compression_ratio + else: # Flux VAE + return 2 ** (len(self.vae.config.block_out_channels) - 1) @property def do_classifier_free_guidance(self): @@ -348,9 +304,10 @@ def prepare_latents( ): """Prepare initial latents for the diffusion process.""" if latents is None: + spatial_compression = self.vae_spatial_compression_ratio latent_height, latent_width = ( - height // self.vae.spatial_compression_ratio, - width // self.vae.spatial_compression_ratio, + height // spatial_compression, + width // spatial_compression, ) shape = (batch_size, num_channels_latents, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -431,9 +388,10 @@ def check_inputs( callback_on_step_end_tensor_inputs: Optional[List[str]] = None, ): """Check that all inputs are in correct format.""" - if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: + spatial_compression = self.vae_spatial_compression_ratio + if height % spatial_compression != 0 or width % spatial_compression != 0: raise ValueError( - f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}." + f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}." ) if guidance_scale < 1.0: @@ -574,7 +532,7 @@ def __call__( timesteps = self.scheduler.timesteps # 4. 
Prepare latent variables - num_channels_latents = self.vae.latent_channels + num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -652,7 +610,9 @@ def __call__( image = latents else: # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) - latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + latents = (latents / scaling_factor) + shift_factor # Decode using VAE (AutoencoderKL or AutoencoderDC) image = self.vae.decode(latents, return_dict=False)[0] # Resize back to original resolution if using binning From 7f1199b0bbd180d42a1e555704825f5ca4baf926 Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 12:15:14 +0000 Subject: [PATCH 08/52] move PhotonAttnProcessor2_0 in transformer_photon --- src/diffusers/models/attention_processor.py | 57 ------------------- .../models/transformers/transformer_photon.py | 56 +++++++++++++++++- .../pipelines/photon/pipeline_output.py | 2 +- .../pipelines/photon/pipeline_photon.py | 2 +- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c325bc71ae84..71b8ba685e78 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5605,62 +5605,6 @@ def __new__(cls, *args, **kwargs): return processor -class PhotonAttnProcessor2_0: - r""" - Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with - diffusers Attention module while handling Photon-specific logic. - """ - - def __init__(self): - if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): - raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: "Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Apply Photon attention using standard diffusers interface. - - Expected tensor formats from PhotonBlock.attn_forward(): - - hidden_states: Image queries with RoPE applied [B, H, L_img, D] - - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + - image + spatial conditioning) - - attention_mask: Custom attention mask [B, H, L_img, L_all] or None - """ - - if encoder_hidden_states is None: - raise ValueError( - "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " - "This should be provided by PhotonBlock.attn_forward()." 
- ) - - # Unpack the combined key+value tensor - # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] - key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] - - # Apply scaled dot-product attention with Photon's processed tensors - # hidden_states is image queries [B, H, L_img, D] - attn_output = torch.nn.functional.scaled_dot_product_attention( - hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask - ) - - # Reshape from [B, H, L_img, D] to [B, L_img, H*D] - batch_size, num_heads, seq_len, head_dim = attn_output.shape - attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) - - # Apply output projection - attn_output = attn.to_out[0](attn_output) - if len(attn.to_out) > 1: - attn_output = attn.to_out[1](attn_output) # dropout if present - - return attn_output - ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, @@ -5710,7 +5654,6 @@ def __call__( PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - PhotonAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 7b29c8bfdafb..fc7a89c8d644 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import Attention, AttentionProcessor, PhotonAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -79,7 +79,61 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq) +class PhotonAttnProcessor2_0: + r""" + Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with + diffusers Attention module while handling Photon-specific logic. + """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Photon attention using standard diffusers interface. + + Expected tensor formats from PhotonBlock.attn_forward(): + - hidden_states: Image queries with RoPE applied [B, H, L_img, D] + - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + + image + spatial conditioning) + - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + """ + + if encoder_hidden_states is None: + raise ValueError( + "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by PhotonBlock.attn_forward()." 
+ ) + + # Unpack the combined key+value tensor + # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] + key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + + # Apply scaled dot-product attention with Photon's processed tensors + # hidden_states is image queries [B, H, L_img, D] + attn_output = torch.nn.functional.scaled_dot_product_attention( + hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + ) + + # Reshape from [B, H, L_img, D] to [B, L_img, H*D] + batch_size, num_heads, seq_len, head_dim = attn_output.shape + attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + + # Apply output projection + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + return attn_output class EmbedND(nn.Module): r""" N-dimensional rotary positional embedding. diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py index ca0674d94b6c..d4b0ff462983 100644 --- a/src/diffusers/pipelines/photon/pipeline_output.py +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index b6f3a78e692c..4986893b068c 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From cabde58b4427d53fefe77c10115075e39427730b Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 13:40:49 +0000 Subject: [PATCH 09/52] remove einops dependency and now inherits from AttentionMixin --- .../models/transformers/transformer_photon.py | 82 ++++--------------- 1 file changed, 16 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index fc7a89c8d644..8c7180294c13 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -16,13 +16,12 @@ from typing import Any, Dict, Optional, Tuple, Union import torch -from einops import rearrange -from einops.layers.torch import Rearrange from torch import Tensor, nn from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import AttentionMixin from ..attention_processor import Attention, AttentionProcessor from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput @@ -134,6 +133,7 @@ def __call__( attn_output = attn.to_out[1](attn_output) # dropout if present return attn_output +# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py class EmbedND(nn.Module): r""" N-dimensional rotary positional embedding. @@ -155,7 +155,6 @@ def __init__(self, dim: int, theta: int, axes_dim: list[int]): self.dim = dim self.theta = theta self.axes_dim = axes_dim - self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2) def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 @@ -163,7 +162,9 @@ def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: omega = 1.0 / (theta**scale) out = pos.unsqueeze(-1) * omega.unsqueeze(0) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) - out = self.rope_rearrange(out) + # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2) + # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2) + out = out.reshape(*out.shape[:-1], 2, 2) return out.float() def forward(self, ids: Tensor) -> Tensor: @@ -378,12 +379,20 @@ def _attn_forward( img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift img_qkv = self.img_qkv_proj(img_mod) - img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + B, L, _ = img_qkv.shape + img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim) # (B, L, K, H, D) + img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D) + img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] img_q, img_k = self.qk_norm(img_q, img_k, img_v) # txt tokens proj and norm txt_kv = self.txt_kv_proj(txt) - txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) + # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) + B, L, _ = txt_kv.shape + txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim) # (B, L, K, H, D) + txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D) + txt_k, txt_v = txt_kv[0], txt_kv[1] txt_k = self.k_norm(txt_k) # compute attention @@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: return 
fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) -class PhotonTransformer2DModel(ModelMixin, ConfigMixin): +class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA scaling. @@ -689,65 +698,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: txt = self.txt_in(txt) img = img2seq(image_latent, self.patch_size) From bca1f7ce3f24e2d84e031b24138c12c938c2db0a Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 13:53:18 +0000 Subject: [PATCH 10/52] unify the structure of the forward block --- src/diffusers/models/attention_processor.py | 1 - .../models/transformers/transformer_photon.py | 108 ++++++++---------- .../pipelines/photon/pipeline_photon.py | 74 ++++++------ 3 files changed, 82 insertions(+), 101 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 71b8ba685e78..66455d733aee 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5605,7 +5605,6 @@ def __new__(cls, *args, **kwargs): return processor - ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 8c7180294c13..533eb356e006 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin -from ..attention_processor import Attention, AttentionProcessor +from ..attention_processor import Attention from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -78,6 +78,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq) + class PhotonAttnProcessor2_0: r""" Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with @@ -133,6 +134,8 @@ def __call__( attn_output = attn.to_out[1](attn_output) # dropout if present return attn_output + + # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py class EmbedND(nn.Module): r""" @@ -299,9 +302,8 @@ class PhotonBlock(nn.Module): Produces scale/shift/gating parameters for modulated layers. Methods: - attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None): - Compute cross-attention between image and text tokens, with optional spatial conditioning and attention - masking. + attn_forward(img, txt, pe, modulation, attention_mask=None): + Compute cross-attention between image and text tokens, with optional attention masking. Parameters: img (`torch.Tensor`): @@ -312,8 +314,6 @@ class PhotonBlock(nn.Module): Rotary positional embeddings to apply to queries and keys. modulation (`ModulationOut`): Scale and shift parameters for modulating image tokens. 
- spatial_conditioning (`torch.Tensor`, *optional*): - Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. attention_mask (`torch.Tensor`, *optional*): Boolean mask of shape `(B, L_txt)` where 0 marks padding. @@ -372,7 +372,6 @@ def _attn_forward( txt: Tensor, pe: Tensor, modulation: ModulationOut, - spatial_conditioning: None | Tensor = None, attention_mask: None | Tensor = None, ) -> Tensor: # image tokens proj and norm @@ -444,7 +443,6 @@ def forward( txt: Tensor, vec: Tensor, pe: Tensor, - spatial_conditioning: Tensor | None = None, attention_mask: Tensor | None = None, **_: dict[str, Any], ) -> Tensor: @@ -461,9 +459,6 @@ def forward( broadcastable). pe (`torch.Tensor`): Rotary positional embeddings applied inside attention. - spatial_conditioning (`torch.Tensor`, *optional*): - Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is - enabled in the block. attention_mask (`torch.Tensor`, *optional*): Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. **_: @@ -481,7 +476,6 @@ def forward( txt, pe, mod_attn, - spatial_conditioning=spatial_conditioning, attention_mask=attention_mask, ) img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp) @@ -698,14 +692,6 @@ def __init__( self.gradient_checkpointing = False - def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: - txt = self.txt_in(txt) - img = img2seq(image_latent, self.patch_size) - bs, _, h, w = image_latent.shape - img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) - pe = self.pe_embedder(img_ids) - return img, txt, pe - def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: return self.time_in( get_timestep_embedding( @@ -717,43 +703,6 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T ).to(dtype) ) - def _forward_transformers( - self, - image_latent: Tensor, - cross_attn_conditioning: Tensor, - timestep: Optional[Tensor] = None, - time_embedding: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - **block_kwargs: Any, - ) -> Tensor: - img = self.img_in(image_latent) - - if time_embedding is not None: - vec = time_embedding - else: - if timestep is None: - raise ValueError("Please provide either a timestep or a timestep_embedding") - vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) - - for block in self.blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - img = self._gradient_checkpointing_func( - block.__call__, - img, - cross_attn_conditioning, - vec, - block_kwargs.get("pe"), - block_kwargs.get("spatial_conditioning"), - attention_mask, - ) - else: - img = block( - img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs - ) - - img = self.final_layer(img, vec) - return img - def forward( self, image_latent: Tensor, @@ -797,6 +746,7 @@ def forward( lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 + if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) @@ -805,12 +755,50 @@ def forward( logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
) - img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning) - img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) - output = seq2img(img_seq, self.patch_size, image_latent.shape) + + # Process text conditioning + txt = self.txt_in(cross_attn_conditioning) + + # Convert image to sequence and embed + img = img2seq(image_latent, self.patch_size) + img = self.img_in(img) + + # Generate positional embeddings + bs, _, h, w = image_latent.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + pe = self.pe_embedder(img_ids) + + # Compute time embedding + vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) + + # Apply transformer blocks + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img = self._gradient_checkpointing_func( + block.__call__, + img, + txt, + vec, + pe, + cross_attn_mask, + ) + else: + img = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + attention_mask=cross_attn_mask, + ) + + # Final layer and convert back to image + img = self.final_layer(img, vec) + output = seq2img(img, self.patch_size, image_latent.shape) + if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 4986893b068c..c09f2b16081e 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -16,14 +16,13 @@ import inspect import re import urllib.parse as ul -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import ftfy import torch from transformers import ( AutoTokenizer, GemmaTokenizerFast, - T5EncoderModel, T5TokenizerFast, ) from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder @@ -31,7 +30,7 @@ from diffusers.image_processor import PixArtImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler @@ -45,29 +44,29 @@ DEFAULT_RESOLUTION = 512 ASPECT_RATIO_256_BIN = { - "0.46": [160, 352], - "0.6": [192, 320], - "0.78": [224, 288], - "1.0": [256, 256], - "1.29": [288, 224], - "1.67": [320, 192], - "2.2": [352, 160], + "0.46": [160, 352], + "0.6": [192, 320], + "0.78": [224, 288], + "1.0": [256, 256], + "1.29": [288, 224], + "1.67": [320, 192], + "2.2": [352, 160], } ASPECT_RATIO_512_BIN = { - "0.5": [352, 704], - "0.57": [384, 672], - "0.6": [384, 640], - "0.68": [416, 608], - "0.78": [448, 576], - "0.88": [480, 544], - "1.0": [512, 512], - "1.13": [544, 480], - "1.29": [576, 448], - "1.46": [608, 416], - "1.67": [640, 384], - "1.75": [672, 384], - "2.0": [704, 352], + "0.5": [352, 704], + "0.57": [384, 672], + "0.6": [384, 640], + "0.68": [416, 608], + "0.78": [448, 576], + "0.88": [480, 544], + "1.0": [512, 512], + "1.13": [544, 480], + "1.29": [576, 448], + 
"1.46": [608, 416], + "1.67": [640, 384], + "1.75": [672, 384], + "2.0": [704, 352], } logger = logging.get_logger(__name__) @@ -283,7 +282,7 @@ def __init__( def vae_spatial_compression_ratio(self): if hasattr(self.vae, "spatial_compression_ratio"): return self.vae.spatial_compression_ratio - else: # Flux VAE + else: # Flux VAE return 2 ** (len(self.vae.config.block_out_channels) - 1) @property @@ -461,8 +460,8 @@ def __call__( Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple. use_resolution_binning (`bool`, *optional*, defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using - predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back to - the requested resolution. Useful for generating non-square images at optimal resolutions. + predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back + to the requested resolution. Useful for generating non-square images at optimal resolutions. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. @@ -572,20 +571,15 @@ def __call__( # Normalize timestep for the transformer t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) - # Process inputs for transformer - img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed) - - # Forward through transformer layers - img_seq = self.transformer._forward_transformers( - img_seq, - txt, - time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype), - pe=pe, - attention_mask=ca_mask, - ) - - # Convert back to image format - noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) + # Forward through transformer + noise_pred = self.transformer( + image_latent=latents_in, + timestep=t_cont, + cross_attn_conditioning=ca_embed, + micro_conditioning=None, + cross_attn_mask=ca_mask, + return_dict=False, + )[0] # Apply CFG if self.do_classifier_free_guidance: From cdfa6361ef1a1116943aa1eb6cfe819bc6327690 Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 14:30:30 +0000 Subject: [PATCH 11/52] update doc --- docs/source/en/api/pipelines/photon.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 71c9a02bcf10..6eeb030f1023 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -56,7 +56,7 @@ from diffusers.pipelines.photon import PhotonPipeline pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i") pipe.to("cuda") -prompt = "A vast night sky over a quiet city suddenly blazes with enormous glowing neon letters spelling “PHOTON.” The word hums and flickers dramatically, as if trying a little too hard to look epic. The soft glow bathes the rooftops and streets below in blue and pink light. A few people look up, squinting, some taking selfies; a cat blinks lazily at the sky’s new centerpiece. The air feels cinematic and electric — like a sci-fi movie that doesn’t take itself too seriously. Mist swirls around the neon glow, adding a dreamy, aesthetic touch to the humor of it all." 
+prompt = prompt = "A digital painting or a heavily manipulated photograph, appearing as a surreal portrait of a young woman. The composition is a close-up, focusing on the face. The woman's face is partially obscured by fragmented, cracked, light teal and off-white pieces resembling peeling paint or decaying skin. These fragments are irregularly shaped and layered, creating a sense of depth and texture. The woman's skin is subtly illuminated, with a warm, golden light highlighting her features, particularly her lips and eyes. Her eyes are a striking light blue, contrasting with the cool tones of the fragmented elements. The overall color palette is muted, with teal, beige, and golden hues dominating. The atmosphere is melancholic and mysterious, with a hint of ethereal beauty. The style is surreal and painterly, blending realistic portraiture with abstract elements. The vibe is introspective and unsettling, suggesting themes of vulnerability, fragility, and hidden identity. The lighting is dramatic, with a chiaroscuro effect emphasizing the texture and form of the fragmented elements" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] image.save("photon_output.png") ``` @@ -85,12 +85,12 @@ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( # Load T5Gemma text encoder t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") -text_encoder = t5gemma_model.encoder +text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16) tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") tokenizer.model_max_length = 256 # Load VAE - choose either Flux VAE or DC-AE # Flux VAE (16 latent channels): -vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") +vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(dtype=torch.bfloat16) # Or DC-AE (32 latent channels): # vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") @@ -134,15 +134,15 @@ Key parameters for image generation: # Example with custom parameters import torch from diffusers.pipelines.photon import PhotonPipeline -with torch.autocast("cuda", dtype=torch.bfloat16): - pipe = pipe( - prompt="A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. 
Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery.", - num_inference_steps=28, - guidance_scale=4.0, - height=512, - width=512, - generator=torch.Generator("cuda").manual_seed(42) - ).images[0] +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16) +pipe = pipe( + prompt = "A digital painting or a heavily manipulated photograph, appearing as a surreal portrait of a young woman. The composition is a close-up, focusing on the face. The woman's face is partially obscured by fragmented, cracked, light teal and off-white pieces resembling peeling paint or decaying skin. These fragments are irregularly shaped and layered, creating a sense of depth and texture. The woman's skin is subtly illuminated, with a warm, golden light highlighting her features, particularly her lips and eyes. Her eyes are a striking light blue, contrasting with the cool tones of the fragmented elements. The overall color palette is muted, with teal, beige, and golden hues dominating. The atmosphere is melancholic and mysterious, with a hint of ethereal beauty. The style is surreal and painterly, blending realistic portraiture with abstract elements. The vibe is introspective and unsettling, suggesting themes of vulnerability, fragility, and hidden identity. The lighting is dramatic, with a chiaroscuro effect emphasizing the texture and form of the fragmented elements" + num_inference_steps=28, + guidance_scale=4.0, + height=512, + width=512, + generator=torch.Generator("cuda").manual_seed(42) +).images[0] ``` ## Memory Optimization @@ -153,7 +153,7 @@ For memory-constrained environments: import torch from diffusers.pipelines.photon import PhotonPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.float16) +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory From 03a7df3ab4ee59c832d503c2d8249f0ed75382c4 Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 19:17:15 +0000 Subject: [PATCH 12/52] update doc --- docs/source/en/api/pipelines/photon.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 6eeb030f1023..a326c50ae7c2 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -53,10 +53,10 @@ Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66 from diffusers.pipelines.photon import PhotonPipeline # Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i") +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.to("cuda") -prompt = prompt = "A digital painting or a heavily manipulated photograph, appearing as a surreal portrait of a young woman. The composition is a close-up, focusing on the face. The woman's face is partially obscured by fragmented, cracked, light teal and off-white pieces resembling peeling paint or decaying skin. These fragments are irregularly shaped and layered, creating a sense of depth and texture. The woman's skin is subtly illuminated, with a warm, golden light highlighting her features, particularly her lips and eyes. 
Her eyes are a striking light blue, contrasting with the cool tones of the fragmented elements. The overall color palette is muted, with teal, beige, and golden hues dominating. The atmosphere is melancholic and mysterious, with a hint of ethereal beauty. The style is surreal and painterly, blending realistic portraiture with abstract elements. The vibe is introspective and unsettling, suggesting themes of vulnerability, fragility, and hidden identity. The lighting is dramatic, with a chiaroscuro effect emphasizing the texture and form of the fragmented elements" +prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “PRX” in bright, sparkling light" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] image.save("photon_output.png") ``` @@ -134,9 +134,9 @@ Key parameters for image generation: # Example with custom parameters import torch from diffusers.pipelines.photon import PhotonPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16) +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) pipe = pipe( - prompt = "A digital painting or a heavily manipulated photograph, appearing as a surreal portrait of a young woman. The composition is a close-up, focusing on the face. The woman's face is partially obscured by fragmented, cracked, light teal and off-white pieces resembling peeling paint or decaying skin. These fragments are irregularly shaped and layered, creating a sense of depth and texture. The woman's skin is subtly illuminated, with a warm, golden light highlighting her features, particularly her lips and eyes. Her eyes are a striking light blue, contrasting with the cool tones of the fragmented elements. The overall color palette is muted, with teal, beige, and golden hues dominating. The atmosphere is melancholic and mysterious, with a hint of ethereal beauty. The style is surreal and painterly, blending realistic portraiture with abstract elements. The vibe is introspective and unsettling, suggesting themes of vulnerability, fragility, and hidden identity. 
The lighting is dramatic, with a chiaroscuro effect emphasizing the texture and form of the fragmented elements" + prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “PRX” in bright, sparkling light" num_inference_steps=28, guidance_scale=4.0, height=512, From c24883592908a4ea607e2e88db187913006e4b55 Mon Sep 17 00:00:00 2001 From: davidb Date: Fri, 10 Oct 2025 19:19:38 +0000 Subject: [PATCH 13/52] fix T5Gemma loading from hub --- scripts/convert_photon_to_diffusers.py | 12 +++++++----- src/diffusers/pipelines/photon/__init__.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index fc4161ff6275..2f060fd3cdc9 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -207,7 +207,6 @@ def download_and_save_vae(vae_type: str, output_path: str): def download_and_save_text_encoder(output_path: str): """Download and save T5Gemma text encoder and tokenizer.""" from transformers import GemmaTokenizerFast - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel text_encoder_path = os.path.join(output_path, "text_encoder") tokenizer_path = os.path.join(output_path, "tokenizer") @@ -215,11 +214,14 @@ def download_and_save_text_encoder(output_path: str): os.makedirs(tokenizer_path, exist_ok=True) print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") - # Extract and save only the encoder - t5gemma_encoder = t5gemma_model.encoder - t5gemma_encoder.save_pretrained(text_encoder_path) + # Save only the encoder + encoder = t5gemma_model.encoder + encoder.save_pretrained(text_encoder_path) + print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}") print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") @@ -243,7 +245,7 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str) "_name_or_path": os.path.basename(output_path), "default_sample_size": default_image_size, "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["transformers.models.t5gemma.modeling_t5gemma", "T5GemmaEncoder"], + "text_encoder": ["photon", "T5GemmaEncoder"], "tokenizer": ["transformers", "GemmaTokenizerFast"], "transformer": ["diffusers", "PhotonTransformer2DModel"], "vae": ["diffusers", vae_class], diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index 559c9d0b1d2d..6f376e440fb2 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -1,5 +1,16 @@ +from typing import TYPE_CHECKING + from .pipeline_output import PhotonPipelineOutput from .pipeline_photon import PhotonPipeline __all__ = ["PhotonPipeline", "PhotonPipelineOutput"] + +# Make T5GemmaEncoder importable from this module for pipeline loading +if TYPE_CHECKING: + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder +else: + try: + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + except ImportError: + pass From 94f469a36d304757c86cfd40440562b9dede972d Mon Sep 17 00:00:00 2001 From: davidb Date: Mon, 13 Oct 2025 09:18:13 +0000 Subject: [PATCH 14/52] fix timestep shift --- docs/source/en/api/pipelines/photon.md | 30 ++++++++++++++------------ scripts/convert_photon_to_diffusers.py | 23 
++++++++++++-------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index a326c50ae7c2..2f0f6b428a13 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -36,14 +36,16 @@ Both **fine-tuned** and **non-fine-tuned** versions are available: - **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**. -| Model | Recommended dtype | Resolution | Fine-tuned | -|:-----:|:-----------------:|:----------:|:----------:| -| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | `torch.bfloat16` | 256x256 | No | -| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | `torch.bfloat16` | 256x256 | Yes | -| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | `torch.bfloat16` | 512x512 | No | -| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft) | `torch.bfloat16` | 512x512 | Yes | -| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | `torch.bfloat16` | 512x512 | No | -| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | `torch.bfloat16` | 512x512 | Yes | +| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | +|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| +| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the 
[Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information. @@ -56,8 +58,8 @@ from diffusers.pipelines.photon import PhotonPipeline pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.to("cuda") -prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “PRX” in bright, sparkling light" -image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] +prompt = "A front-facing portrait of a lion the golden savanna at sunset." +image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] image.save("photon_output.png") ``` @@ -75,12 +77,12 @@ from transformers import T5GemmaModel, GemmaTokenizerFast # Load transformer transformer = PhotonTransformer2DModel.from_pretrained( - "Photoroom/photon-512-t2i", subfolder="transformer" + "Photoroom/photon-512-t2i-sft", subfolder="transformer" ).to(dtype=torch.bfloat16) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "Photoroom/photon-512-t2i", subfolder="scheduler" + "Photoroom/photon-512-t2i-sft", subfolder="scheduler" ) # Load T5Gemma text encoder @@ -136,7 +138,7 @@ import torch from diffusers.pipelines.photon import PhotonPipeline pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) pipe = pipe( - prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “PRX” in bright, sparkling light" + prompt = "A front-facing portrait of a lion the golden savanna at sunset." 
num_inference_steps=28, guidance_scale=4.0, height=512, @@ -153,7 +155,7 @@ For memory-constrained environments: import torch from diffusers.pipelines.photon import PhotonPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16) +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index 2f060fd3cdc9..c9c07f191ff9 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -172,10 +172,10 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph return transformer -def create_scheduler_config(output_path: str): +def create_scheduler_config(output_path: str, shift: float): """Create FlowMatchEulerDiscreteScheduler config.""" - scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": 1.0} + scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift} scheduler_path = os.path.join(output_path, "scheduler") os.makedirs(scheduler_path, exist_ok=True) @@ -207,6 +207,7 @@ def download_and_save_vae(vae_type: str, output_path: str): def download_and_save_text_encoder(output_path: str): """Download and save T5Gemma text encoder and tokenizer.""" from transformers import GemmaTokenizerFast + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel text_encoder_path = os.path.join(output_path, "text_encoder") tokenizer_path = os.path.join(output_path, "tokenizer") @@ -214,14 +215,11 @@ def download_and_save_text_encoder(output_path: str): os.makedirs(tokenizer_path, exist_ok=True) print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel - t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") - # Save only the encoder - encoder = t5gemma_model.encoder - encoder.save_pretrained(text_encoder_path) - + # Extract and save only the encoder + t5gemma_encoder = t5gemma_model.encoder + t5gemma_encoder.save_pretrained(text_encoder_path) print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}") print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") @@ -284,7 +282,7 @@ def main(args): print(f"✓ Saved transformer to {transformer_path}") # Create scheduler config - create_scheduler_config(args.output_path) + create_scheduler_config(args.output_path, args.shift) download_and_save_vae(args.vae_type, args.output_path) download_and_save_text_encoder(args.output_path) @@ -342,6 +340,13 @@ def main(args): default=DEFAULT_RESOLUTION, help="Target resolution for the model (256, 512, or 1024). 
Affects the transformer's sample_size.", ) + + parser.add_argument( + "--shift", + type=float, + default=3.0, + help="Shift for the scheduler", + ) args = parser.parse_args() From ec2381e039347cf1b485becd4edc1d062192a957 Mon Sep 17 00:00:00 2001 From: davidb Date: Mon, 13 Oct 2025 11:21:27 +0000 Subject: [PATCH 15/52] remove lora support from doc --- docs/source/en/api/pipelines/photon.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 2f0f6b428a13..62133f93c490 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -14,9 +14,6 @@ # PhotonPipeline -
- LoRA -
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. From b445016b000a59f863be3aa52f044aa5d4ccfaa4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:23:22 +0000 Subject: [PATCH 16/52] Rename EmbedND for PhotoEmbedND --- src/diffusers/models/transformers/transformer_photon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 533eb356e006..c7b5ca518651 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -136,8 +136,8 @@ def __call__( return attn_output -# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class EmbedND(nn.Module): +# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class PhotoEmbedND(nn.Module): r""" N-dimensional rotary positional embedding. @@ -672,7 +672,7 @@ def __init__( self.hidden_size = hidden_size self.num_heads = num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) From 7ae1af975e6d3562739e3cd30d00e1bef542eb7b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:29:43 +0000 Subject: [PATCH 17/52] remove modulation dataclass --- .../models/transformers/transformer_photon.py | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index c7b5ca518651..46565fd1d714 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import torch @@ -228,29 +227,20 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: k = self.key_norm(k) return q.to(v), k.to(v) - -@dataclass -class ModulationOut: - shift: Tensor - scale: Tensor - gate: Tensor - - class Modulation(nn.Module): r""" Modulation network that generates scale, shift, and gating parameters. Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into - two `ModulationOut` objects. + two tuples `(shift, scale, gate)`. Parameters: dim (`int`): Dimensionality of the input vector. The output will have `6 * dim` features internally. Returns: - (`ModulationOut`, `ModulationOut`): - A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift, - gate). + ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)): + Two tuples `(shift, scale, gate)`. 
""" def __init__(self, dim: int): @@ -259,9 +249,9 @@ def __init__(self, dim: int): nn.init.constant_(self.lin.weight, 0) nn.init.constant_(self.lin.bias, 0) - def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: + def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) - return ModulationOut(*out[:3]), ModulationOut(*out[3:]) + return tuple(out[:3]), tuple(out[3:]) class PhotonBlock(nn.Module): @@ -301,7 +291,7 @@ class PhotonBlock(nn.Module): modulation (`Modulation`): Produces scale/shift/gating parameters for modulated layers. - Methods: + Methods: attn_forward(img, txt, pe, modulation, attention_mask=None): Compute cross-attention between image and text tokens, with optional attention masking. @@ -312,8 +302,8 @@ class PhotonBlock(nn.Module): Text tokens of shape `(B, L_txt, hidden_size)`. pe (`torch.Tensor`): Rotary positional embeddings to apply to queries and keys. - modulation (`ModulationOut`): - Scale and shift parameters for modulating image tokens. + modulation ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)): + Tuple `(shift, scale, gate)` for modulating image tokens. attention_mask (`torch.Tensor`, *optional*): Boolean mask of shape `(B, L_txt)` where 0 marks padding. @@ -371,11 +361,12 @@ def _attn_forward( img: Tensor, txt: Tensor, pe: Tensor, - modulation: ModulationOut, + modulation: tuple[Tensor, Tensor, Tensor], attention_mask: None | Tensor = None, ) -> Tensor: # image tokens proj and norm - img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift + shift, scale, _gate = modulation + img_mod = (1 + scale) * self.img_pre_norm(img) + shift img_qkv = self.img_qkv_proj(img_mod) # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) @@ -433,8 +424,9 @@ def _attn_forward( return attn - def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: - x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift + def _ffn_forward(self, x: Tensor, modulation: tuple[Tensor, Tensor, Tensor]) -> Tensor: + shift, scale, _gate = modulation + x = (1 + scale) * self.post_attention_layernorm(x) + shift return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) def forward( @@ -470,15 +462,17 @@ def forward( """ mod_attn, mod_mlp = self.modulation(vec) + attn_shift, attn_scale, attn_gate = mod_attn + mlp_shift, mlp_scale, mlp_gate = mod_mlp - img = img + mod_attn.gate * self._attn_forward( + img = img + attn_gate * self._attn_forward( img, txt, pe, - mod_attn, + (attn_shift, attn_scale, attn_gate), attention_mask=attention_mask, ) - img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp) + img = img + mlp_gate * self._ffn_forward(img, (mlp_shift, mlp_scale, mlp_gate)) return img From a8216f779080c9af7cf410585fce4629d57dc97a Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:39:06 +0000 Subject: [PATCH 18/52] put _attn_forward and _ffn_forward logic in PhotonBlock's forward --- .../models/transformers/transformer_photon.py | 144 ++++++------------ 1 file changed, 48 insertions(+), 96 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 46565fd1d714..7cb3f1e000e7 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -292,24 +292,7 @@ class 
PhotonBlock(nn.Module): Produces scale/shift/gating parameters for modulated layers. Methods: - attn_forward(img, txt, pe, modulation, attention_mask=None): - Compute cross-attention between image and text tokens, with optional attention masking. - - Parameters: - img (`torch.Tensor`): - Image tokens of shape `(B, L_img, hidden_size)`. - txt (`torch.Tensor`): - Text tokens of shape `(B, L_txt, hidden_size)`. - pe (`torch.Tensor`): - Rotary positional embeddings to apply to queries and keys. - modulation ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)): - Tuple `(shift, scale, gate)` for modulating image tokens. - attention_mask (`torch.Tensor`, *optional*): - Boolean mask of shape `(B, L_txt)` where 0 marks padding. - - Returns: - `torch.Tensor`: - Attention output of shape `(B, L_img, hidden_size)`. + The forward method performs cross-attention and the MLP inline with modulation. """ def __init__( @@ -356,78 +339,7 @@ def __init__( self.modulation = Modulation(hidden_size) - def _attn_forward( - self, - img: Tensor, - txt: Tensor, - pe: Tensor, - modulation: tuple[Tensor, Tensor, Tensor], - attention_mask: None | Tensor = None, - ) -> Tensor: - # image tokens proj and norm - shift, scale, _gate = modulation - img_mod = (1 + scale) * self.img_pre_norm(img) + shift - - img_qkv = self.img_qkv_proj(img_mod) - # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) - B, L, _ = img_qkv.shape - img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim) # (B, L, K, H, D) - img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D) - img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] - img_q, img_k = self.qk_norm(img_q, img_k, img_v) - - # txt tokens proj and norm - txt_kv = self.txt_kv_proj(txt) - # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) - B, L, _ = txt_kv.shape - txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim) # (B, L, K, H, D) - txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # (K, B, H, L, D) - txt_k, txt_v = txt_kv[0], txt_kv[1] - txt_k = self.k_norm(txt_k) - - # compute attention - img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe) - k = torch.cat((txt_k, img_k), dim=2) - v = torch.cat((txt_v, img_v), dim=2) - - # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys - attn_mask: Tensor | None = None - if attention_mask is not None: - bs, _, l_img, _ = img_q.shape - l_txt = txt_k.shape[2] - - if attention_mask.dim() != 2: - raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") - if attention_mask.shape[-1] != l_txt: - raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") - - device = img_q.device - - ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) - - mask_parts = [ - attention_mask.to(torch.bool), - ones_img, - ] - joint_mask = torch.cat(mask_parts, dim=-1) # (B, L_all) - - # repeat across heads and query positions - attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all) - - kv_packed = torch.cat([k, v], dim=-1) - attn = self.attention( - hidden_states=img_q, - encoder_hidden_states=kv_packed, - attention_mask=attn_mask, - ) - - return attn - - def _ffn_forward(self, x: Tensor, modulation: tuple[Tensor, Tensor, Tensor]) -> Tensor: - shift, scale, _gate = modulation - x = (1 + scale) * self.post_attention_layernorm(x) + shift - return self.down_proj(self.mlp_act(self.gate_proj(x)) * 
self.up_proj(x)) def forward( self, @@ -465,14 +377,54 @@ def forward( attn_shift, attn_scale, attn_gate = mod_attn mlp_shift, mlp_scale, mlp_gate = mod_mlp - img = img + attn_gate * self._attn_forward( - img, - txt, - pe, - (attn_shift, attn_scale, attn_gate), - attention_mask=attention_mask, + # Inline attention forward + img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift + + img_qkv = self.img_qkv_proj(img_mod) + B, L, _ = img_qkv.shape + img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim) + img_qkv = img_qkv.permute(2, 0, 3, 1, 4) + img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] + img_q, img_k = self.qk_norm(img_q, img_k, img_v) + + txt_kv = self.txt_kv_proj(txt) + B, L, _ = txt_kv.shape + txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim) + txt_kv = txt_kv.permute(2, 0, 3, 1, 4) + txt_k, txt_v = txt_kv[0], txt_kv[1] + txt_k = self.k_norm(txt_k) + + img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn_mask_tensor: Tensor | None = None + if attention_mask is not None: + bs, _, l_img, _ = img_q.shape + l_txt = txt_k.shape[2] + + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") + + device = img_q.device + ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) + joint_mask = torch.cat([attention_mask.to(torch.bool), ones_img], dim=-1) + attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) + + kv_packed = torch.cat([k, v], dim=-1) + attn_out = self.attention( + hidden_states=img_q, + encoder_hidden_states=kv_packed, + attention_mask=attn_mask_tensor, ) - img = img + mlp_gate * self._ffn_forward(img, (mlp_shift, mlp_scale, mlp_gate)) + + img = img + attn_gate * attn_out + + # Inline FFN forward + x = (1 + mlp_scale) * self.post_attention_layernorm(img) + mlp_shift + img = img + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))) return img From 73261c8efc4d45ce78be023b0bfcf616f7913ac6 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:40:52 +0000 Subject: [PATCH 19/52] renam LastLayer for FinalLayer --- src/diffusers/models/transformers/transformer_photon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 7cb3f1e000e7..cd7e2cb3ef06 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -428,7 +428,7 @@ def forward( return img -class LastLayer(nn.Module): +class FinalLayer(nn.Module): r""" Final projection layer with adaptive LayerNorm modulation. 
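
For readers skimming this diff, the modulation pattern that PhotonBlock and the final layer share can be sketched in a few lines of plain PyTorch. This is an illustrative, self-contained example rather than the library code: the dimensions are arbitrary and a plain `nn.Linear` stands in for the attention/MLP branch.

```py
import torch
from torch import nn

hidden = 64
modulation = nn.Linear(hidden, 6 * hidden)      # one projection yields two (shift, scale, gate) triples
norm = nn.LayerNorm(hidden, elementwise_affine=False)
sub_layer = nn.Linear(hidden, hidden)           # stand-in for the attention or MLP branch

x = torch.randn(2, 16, hidden)                  # (batch, tokens, hidden)
vec = torch.randn(2, hidden)                    # conditioning vector, e.g. the timestep embedding

chunks = modulation(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
shift, scale, gate = chunks[:3]                 # first triple modulates one branch
x_mod = (1 + scale) * norm(x) + shift           # adaptive LayerNorm modulation
x = x + gate * sub_layer(x_mod)                 # gated residual update
print(x.shape)                                  # torch.Size([2, 16, 64])
```
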
@@ -634,7 +634,7 @@ def __init__( ] ) - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels) self.gradient_checkpointing = False From f8f45fdc4817bae419df2e78aef3999fdea4369c Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:45:45 +0000 Subject: [PATCH 20/52] remove lora related code --- .../models/transformers/transformer_photon.py | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index cd7e2cb3ef06..b5a89d642d43 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -19,7 +19,7 @@ from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import logging from ..attention import AttentionMixin from ..attention_processor import Attention from ..embeddings import get_timestep_embedding @@ -515,8 +515,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" - Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA - scaling. + Transformer-based 2D model for text to image generation. Parameters: in_channels (`int`, *optional*, defaults to 16): @@ -677,8 +676,7 @@ def forward( cross_attn_mask (`torch.Tensor`, *optional*): Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. attention_kwargs (`dict`, *optional*): - Additional arguments passed to attention layers. If using the PEFT backend, the key `"scale"` controls - LoRA scaling (default: 1.0). + Additional arguments passed to attention layers. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `Transformer2DModelOutput` or a tuple. @@ -687,21 +685,6 @@ def forward( - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. """ - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - # Process text conditioning txt = self.txt_in(cross_attn_conditioning) @@ -741,10 +724,6 @@ def forward( img = self.final_layer(img, vec) output = seq2img(img, self.patch_size, image_latent.shape) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) From e65de5fc55c136b8e5fe348d71da75effaf63a19 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 11:49:25 +0000 Subject: [PATCH 21/52] rename vae_spatial_compression_ratio for vae_scale_factor --- src/diffusers/pipelines/photon/pipeline_photon.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index c09f2b16081e..aa1d10070dd2 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -276,10 +276,10 @@ def __init__( self.register_to_config(default_sample_size=default_sample_size) - self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) @property - def vae_spatial_compression_ratio(self): + def vae_scale_factor(self): if hasattr(self.vae, "spatial_compression_ratio"): return self.vae.spatial_compression_ratio else: # Flux VAE @@ -303,7 +303,7 @@ def prepare_latents( ): """Prepare initial latents for the diffusion process.""" if latents is None: - spatial_compression = self.vae_spatial_compression_ratio + spatial_compression = self.vae_scale_factor latent_height, latent_width = ( height // spatial_compression, width // spatial_compression, @@ -387,7 +387,7 @@ def check_inputs( callback_on_step_end_tensor_inputs: Optional[List[str]] = None, ): """Check that all inputs are in correct format.""" - spatial_compression = self.vae_spatial_compression_ratio + spatial_compression = self.vae_scale_factor if height % spatial_compression != 0 or width % spatial_compression != 0: raise ValueError( f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}." 
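
To make the renamed `vae_scale_factor` property concrete, the sketch below mirrors the two branches above for the VAEs referenced elsewhere in this series. It is a hedged example: the checkpoints are the public ones used in the docs, and loading them downloads weights.

```py
import torch
from diffusers import AutoencoderDC, AutoencoderKL

# Flux-style AutoencoderKL: the scale factor is derived from the number of down blocks
flux_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)
print(2 ** (len(flux_vae.config.block_out_channels) - 1))   # 8

# DC-AE: exposes the compression ratio directly
dc_ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.bfloat16
)
print(dc_ae.spatial_compression_ratio)                      # 32
```

The latent grid the pipeline works on is then `height // scale` by `width // scale`, which is why `check_inputs` requires the requested resolution to be divisible by this factor.
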
From b765031ff52725174e29de2bd92aa3d30a292700 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 12:12:28 +0000 Subject: [PATCH 22/52] support prompt_embeds in call --- .../pipelines/photon/pipeline_photon.py | 77 +++++++++++++++++-- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index aa1d10070dd2..ea9844fee284 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -243,7 +243,7 @@ class PhotonPipeline( """ model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] _optional_components = ["vae"] def __init__( @@ -320,8 +320,21 @@ def encode_prompt( device: torch.device, do_classifier_free_guidance: bool = True, negative_prompt: str = "", + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.BoolTensor] = None, + negative_prompt_attention_mask: Optional[torch.BoolTensor] = None, ): - """Encode text prompt using standard text encoder and tokenizer.""" + """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" + if prompt_embeds is not None: + # Use precomputed embeddings + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds if do_classifier_free_guidance else None, + negative_prompt_attention_mask if do_classifier_free_guidance else None, + ) + if isinstance(prompt, str): prompt = [prompt] @@ -385,8 +398,30 @@ def check_inputs( width: int, guidance_scale: float, callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): """Check that all inputs are in correct format.""" + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided and `guidance_scale > 1.0`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." 
+ ) + spatial_compression = self.vae_scale_factor if height % spatial_compression != 0 or width % spatial_compression != 0: raise ValueError( @@ -401,6 +436,13 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -414,6 +456,10 @@ def __call__( num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.BoolTensor] = None, + negative_prompt_attention_mask: Optional[torch.BoolTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, use_resolution_binning: bool = True, @@ -425,7 +471,7 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. @@ -453,6 +499,19 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an + empty string. + prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated + from `prompt` input argument. + negative_prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`, + attention mask will be generated from an empty string. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
@@ -503,6 +562,8 @@ def __call__( width, guidance_scale, callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, ) if prompt is not None and isinstance(prompt, str): @@ -510,7 +571,7 @@ def __call__( elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError("prompt must be provided as a string or list of strings") + batch_size = prompt_embeds.shape[0] device = self._execution_device @@ -518,7 +579,13 @@ def __call__( # 2. Encode input prompt text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( - prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance + prompt, + device, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, ) # 3. Prepare timesteps From 87cd6d26f1358aa9a2f9d7d23da0f91c523c3b55 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 12:48:53 +0000 Subject: [PATCH 23/52] move xattention conditionning out computation out of the denoising loop --- .../pipelines/photon/pipeline_photon.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index ea9844fee284..9e47f8ebc0ed 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -265,6 +265,7 @@ def __init__( self.text_encoder = text_encoder self.tokenizer = tokenizer self.text_preprocessor = TextPreprocessor() + self.default_sample_size = default_sample_size self.register_modules( transformer=transformer, @@ -274,7 +275,7 @@ def __init__( vae=vae, ) - self.register_to_config(default_sample_size=default_sample_size) + self.register_to_config(default_sample_size=self.default_sample_size) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -539,13 +540,12 @@ def __call__( generated images. """ - # 0. Default height and width from config - default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION) - height = height or default_sample_size - width = width or default_sample_size + # 0. Set height and width + height = height or self.default_sample_size + width = width or self.default_sample_size if use_resolution_binning: - if default_sample_size <= 256: + if self.default_sample_size <= 256: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: aspect_ratio_bin = ASPECT_RATIO_512_BIN @@ -616,7 +616,17 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = 0.0 - # 6. Denoising loop + # 6. Prepare cross-attention embeddings and masks + if self.do_classifier_free_guidance: + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + else: + ca_embed = text_embeddings + ca_mask = cross_attn_mask + + # 7. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -624,17 +634,10 @@ def __call__( # Duplicate latents if using classifier-free guidance if self.do_classifier_free_guidance: latents_in = torch.cat([latents, latents], dim=0) - # Cross-attention batch (uncond, cond) - ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) - ca_mask = None - if cross_attn_mask is not None and uncond_cross_attn_mask is not None: - ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) # Normalize timestep for the transformer t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) else: latents_in = latents - ca_embed = text_embeddings - ca_mask = cross_attn_mask # Normalize timestep for the transformer t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) From 1b61bb26952625a6bf19c62f7617022aa9ed585c Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 12:57:44 +0000 Subject: [PATCH 24/52] add negative prompts --- src/diffusers/pipelines/photon/pipeline_photon.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 9e47f8ebc0ed..5164d576f829 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -449,6 +449,7 @@ def check_inputs( def __call__( self, prompt: Union[str, List[str]] = None, + negative_prompt: str = "", height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, @@ -474,6 +475,9 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. + negative_prompt (`str`, *optional*, defaults to `""`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. 
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): @@ -582,6 +586,7 @@ def __call__( prompt, device, do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, From 9ca25fa370336ca45356af7e4fa6c9f83c59576f Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 13:07:28 +0000 Subject: [PATCH 25/52] Use _import_structure for lazy loading --- src/diffusers/pipelines/photon/__init__.py | 63 +++++++++++++++++++--- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index 6f376e440fb2..38e85c7285fe 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -1,16 +1,65 @@ from typing import TYPE_CHECKING -from .pipeline_output import PhotonPipelineOutput -from .pipeline_photon import PhotonPipeline +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) -__all__ = ["PhotonPipeline", "PhotonPipelineOutput"] +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]} -# Make T5GemmaEncoder importable from this module for pipeline loading -if TYPE_CHECKING: - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - try: + _import_structure["pipeline_photon"] = ["PhotonPipeline"] + +# Import T5GemmaEncoder for pipeline loading compatibility +try: + if is_transformers_available(): from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + + _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder +except ImportError: + pass + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import PhotonPipelineOutput + from .pipeline_photon import PhotonPipeline + + try: + if is_transformers_available(): + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder except ImportError: pass + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) From 98ed747743088475d57e421cc771e4b4411e28e2 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 13 Oct 2025 13:09:00 +0000 Subject: [PATCH 26/52] make quality + style --- scripts/convert_photon_to_diffusers.py | 2 +- src/diffusers/models/transformers/transformer_photon.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index c9c07f191ff9..0f24e5036977 
100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -340,7 +340,7 @@ def main(args): default=DEFAULT_RESOLUTION, help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.", ) - + parser.add_argument( "--shift", type=float, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index b5a89d642d43..b77e2f8d6f1c 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -227,6 +227,7 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: k = self.key_norm(k) return q.to(v), k.to(v) + class Modulation(nn.Module): r""" Modulation network that generates scale, shift, and gating parameters. @@ -339,8 +340,6 @@ def __init__( self.modulation = Modulation(hidden_size) - - def forward( self, img: Tensor, From 9819ff1f6473232cca39711da60932fb1987def6 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Wed, 15 Oct 2025 12:31:25 +0000 Subject: [PATCH 27/52] add pipeline test + corresponding fixes --- .../models/transformers/transformer_photon.py | 5 +- src/diffusers/pipelines/photon/__init__.py | 10 +- .../pipelines/photon/pipeline_photon.py | 94 +++++-- tests/pipelines/photon/__init__.py | 0 .../pipelines/photon/test_pipeline_photon.py | 258 ++++++++++++++++++ 5 files changed, 344 insertions(+), 23 deletions(-) create mode 100644 tests/pipelines/photon/__init__.py create mode 100644 tests/pipelines/photon/test_pipeline_photon.py diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index b77e2f8d6f1c..a5b20619064c 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -74,6 +74,8 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: Tensor of the same shape as `xq` with rotary embeddings applied. 
""" xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading + freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq) @@ -409,7 +411,8 @@ def forward( device = img_q.device ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) - joint_mask = torch.cat([attention_mask.to(torch.bool), ones_img], dim=-1) + attention_mask = attention_mask.to(device=device, dtype=torch.bool) + joint_mask = torch.cat([attention_mask, ones_img], dim=-1) attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) kv_packed = torch.cat([k, v], dim=-1) diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index 38e85c7285fe..e21e31d4225f 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -27,9 +27,13 @@ # Import T5GemmaEncoder for pipeline loading compatibility try: if is_transformers_available(): + import transformers from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder + # Patch transformers module directly for serialization + if not hasattr(transformers, "T5GemmaEncoder"): + transformers.T5GemmaEncoder = T5GemmaEncoder except ImportError: pass @@ -43,12 +47,6 @@ from .pipeline_output import PhotonPipelineOutput from .pipeline_photon import PhotonPipeline - try: - if is_transformers_available(): - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder - except ImportError: - pass - else: import sys diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 5164d576f829..0c2384657988 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -266,6 +266,7 @@ def __init__( self.tokenizer = tokenizer self.text_preprocessor = TextPreprocessor() self.default_sample_size = default_sample_size + self._guidance_scale = 1.0 self.register_modules( transformer=transformer, @@ -277,10 +278,15 @@ def __init__( self.register_to_config(default_sample_size=self.default_sample_size) - self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + if vae is not None: + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + else: + self.image_processor = None @property def vae_scale_factor(self): + if self.vae is None: + return 8 if hasattr(self.vae, "spatial_compression_ratio"): return self.vae.spatial_compression_ratio else: # Flux VAE @@ -291,6 +297,10 @@ def do_classifier_free_guidance(self): """Check if classifier-free guidance is enabled based on guidance scale.""" return self._guidance_scale > 1.0 + @property + def guidance_scale(self): + return self._guidance_scale + def prepare_latents( self, batch_size: int, @@ -318,28 +328,58 @@ def prepare_latents( def encode_prompt( self, prompt: Union[str, List[str]], - device: torch.device, + device: Optional[torch.device] = None, do_classifier_free_guidance: bool = True, negative_prompt: str = "", + num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.BoolTensor] = None, negative_prompt_attention_mask: Optional[torch.BoolTensor] = None, ): 
"""Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" - if prompt_embeds is not None: - # Use precomputed embeddings - return ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds if do_classifier_free_guidance else None, - negative_prompt_attention_mask if do_classifier_free_guidance else None, + if device is None: + device = self._execution_device + + if prompt_embeds is None: + if isinstance(prompt, str): + prompt = [prompt] + # Encode the prompts + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = ( + self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) ) - - if isinstance(prompt, str): - prompt = [prompt] - - return self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) + prompt_embeds = text_embeddings + prompt_attention_mask = cross_attn_mask + negative_prompt_embeds = uncond_text_embeddings + negative_prompt_attention_mask = uncond_cross_attn_mask + + # Duplicate embeddings for each generation per prompt + if num_images_per_prompt > 1: + # Repeat prompt embeddings + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # Repeat negative embeddings if using CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds if do_classifier_free_guidance else None, + negative_prompt_attention_mask if do_classifier_free_guidance else None, + ) def _tokenize_prompts(self, prompts: List[str], device: torch.device): """Tokenize and clean prompts.""" @@ -549,6 +589,11 @@ def __call__( width = width or self.default_sample_size if use_resolution_binning: + if self.image_processor is None: + raise ValueError( + "Resolution binning requires a VAE with image_processor, but VAE is not available. " + "Set use_resolution_binning=False or provide a VAE." + ) if self.default_sample_size <= 256: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: @@ -570,6 +615,12 @@ def __call__( negative_prompt_embeds, ) + if self.vae is None and output_type not in ["latent", "pt"]: + raise ValueError( + f"VAE is required for output_type='{output_type}' but it is not available. " + "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs." 
+ ) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -577,6 +628,7 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + # Use execution device (handles offloading scenarios including group offloading) device = self._execution_device self._guidance_scale = guidance_scale @@ -587,11 +639,15 @@ def __call__( device, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, ) + # Expose standard names for callbacks parity + prompt_embeds = text_embeddings + negative_prompt_embeds = uncond_text_embeddings # 3. Prepare timesteps if timesteps is not None: @@ -602,8 +658,14 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self.num_timesteps = len(timesteps) + # 4. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels + if self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + else: + # When vae is None, get latent channels from transformer + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -675,7 +737,7 @@ def __call__( progress_bar.update() # 8. Post-processing - if output_type == "latent": + if output_type == "latent" or (output_type == "pt" and self.vae is None): image = latents else: # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) diff --git a/tests/pipelines/photon/__init__.py b/tests/pipelines/photon/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py new file mode 100644 index 000000000000..9ac361c75b2e --- /dev/null +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -0,0 +1,258 @@ +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig +from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PhotonPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + @classmethod + def setUpClass(cls): + # Ensure PhotonPipeline has an _execution_device property expected by __call__ + if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property): + try: + setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) + except Exception: + pass + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = 
PhotonTransformer2DModel( + patch_size=1, + in_channels=4, + context_in_dim=8, + hidden_size=8, + mlp_ratio=2.0, + num_heads=2, + depth=1, + axes_dim=[2, 2], + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0, + scaling_factor=1.0, + ).eval() + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + tokenizer.model_max_length = 64 + + torch.manual_seed(0) + + encoder_params = dict( + vocab_size=tokenizer.vocab_size, + hidden_size=8, + intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + max_position_embeddings=64, + layer_types=["full_attention"], + attention_bias=False, + attention_dropout=0.0, + dropout_rate=0.0, + hidden_activation="gelu_pytorch_tanh", + rms_norm_eps=1e-06, + attn_logit_softcapping=50.0, + final_logit_softcapping=30.0, + query_pre_attn_scalar=4, + rope_theta=10000.0, + sliding_window=4096, + ) + encoder_config = T5GemmaModuleConfig(**encoder_params) + text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params) + text_encoder = T5GemmaEncoder(text_encoder_config) + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + return { + "prompt": "", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "output_type": "pt", + "use_resolution_binning": False, + } + + def test_inference(self): + device = "cpu" + components = self.get_dummy_components() + pipe = PhotonPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + try: + pipe.register_to_config(_execution_device="cpu") + except Exception: + pass + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.zeros(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + components = self.get_dummy_components() + pipe = PhotonPipeline(**components) + pipe = pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + try: + pipe.register_to_config(_execution_device="cpu") + except Exception: + pass + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs("cpu") + + 
inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + def to_np_local(tensor): + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + return tensor + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max() + max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max() + self.assertLess(max(max_diff1, max_diff2), expected_max_diff) + + def test_inference_with_autoencoder_dc(self): + """Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" + device = "cpu" + + components = self.get_dummy_components() + + torch.manual_seed(0) + vae_dc = AutoencoderDC( + in_channels=3, + latent_channels=4, + attention_head_dim=2, + encoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + decoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + encoder_block_out_channels=(8, 8), + decoder_block_out_channels=(8, 8), + encoder_qkv_multiscales=((), (5,)), + decoder_qkv_multiscales=((), (5,)), + encoder_layers_per_block=(1, 1), + decoder_layers_per_block=(1, 1), + upsample_block_type="interpolate", + downsample_block_type="stride_conv", + decoder_norm_types="rms_norm", + decoder_act_fns="silu", + ).eval() + + components["vae"] = vae_dc + + pipe = PhotonPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + expected_scale_factor = vae_dc.spatial_compression_ratio + self.assertEqual(pipe.vae_scale_factor, expected_scale_factor) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.zeros(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) From 77b4f8fa06904cefdddcaf0e4795b28537e4bef4 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Wed, 15 Oct 2025 15:16:51 +0000 Subject: [PATCH 28/52] utility function that determines the default resolution given the VAE --- .../pipelines/photon/pipeline_photon.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 0c2384657988..1a6fbd194a1c 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -301,6 +301,18 @@ def do_classifier_free_guidance(self): def 
guidance_scale(self): return self._guidance_scale + def get_default_resolution(self): + """Determine the default resolution based on the loaded VAE and config. + + Returns: + int: The default sample size (height/width) to use for generation. + """ + default_from_config = getattr(self.config, "default_sample_size", None) + if default_from_config is not None: + return default_from_config + + return DEFAULT_RESOLUTION + def prepare_latents( self, batch_size: int, @@ -585,8 +597,9 @@ def __call__( """ # 0. Set height and width - height = height or self.default_sample_size - width = width or self.default_sample_size + default_resolution = self.get_default_resolution() + height = height or default_resolution + width = width or default_resolution if use_resolution_binning: if self.image_processor is None: From 19f9c47a5334583a934f002ddd0e685a39400075 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 07:49:41 +0000 Subject: [PATCH 29/52] Refactor PhotonAttention to match Flux pattern --- scripts/convert_photon_to_diffusers.py | 46 ++-- .../models/transformers/transformer_photon.py | 213 +++++++++++------- 2 files changed, 152 insertions(+), 107 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index 0f24e5036977..e30ba12ada4e 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -67,13 +67,23 @@ def create_parameter_mapping(depth: int) -> dict: # Key mappings for structural changes mapping = {} - # RMSNorm: scale -> weight + # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention) for i in range(depth): - mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" - mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" - mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" + # QKV projections moved to attention module + mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" + mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" - # Attention: attn_out -> attention.to_out.0 + # QK norm moved to attention module + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.qk_norm.key_norm.weight" + mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.qk_norm.key_norm.weight" + + # K norm moved to attention module + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.k_norm.weight" + mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.k_norm.weight" + + # Attention output projection mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" return mapping @@ -95,31 +105,7 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth new_key = mapping[key] print(f" Mapped: {key} -> {new_key}") - # Handle img_qkv_proj -> split to to_q, to_k, to_v - if "img_qkv_proj.weight" in key: - print(f" Found QKV projection: {key}") - # Split QKV weight into separate Q, K, V projections - qkv_weight = value - q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) - - # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) - parts = key.split(".") - layer_idx = None - for i, part in enumerate(parts): - if part == 
"blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): - layer_idx = parts[i + 1] - break - - if layer_idx is not None: - converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight - converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight - converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight - print(f" Split QKV for layer {layer_idx}") - - # Also keep the original img_qkv_proj for backward compatibility - converted_state_dict[new_key] = value - else: - converted_state_dict[new_key] = value + converted_state_dict[new_key] = value print(f"✓ Converted {len(converted_state_dict)} parameters") return converted_state_dict diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index a5b20619064c..b36c9c696aa6 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import AttentionMixin +from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_processor import Attention from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput @@ -86,43 +86,86 @@ class PhotonAttnProcessor2_0: diffusers Attention module while handling Photon-specific logic. """ + _attention_backend = None + def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") def __call__( self, - attn: "Attention", + attn: "PhotonAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ - Apply Photon attention using standard diffusers interface. + Apply Photon attention using PhotonAttention module. - Expected tensor formats from PhotonBlock.attn_forward(): - - hidden_states: Image queries with RoPE applied [B, H, L_img, D] - - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + - image + spatial conditioning) - - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + Parameters: + attn: PhotonAttention module containing projection layers + hidden_states: Image tokens [B, L_img, D] + encoder_hidden_states: Text tokens [B, L_txt, D] + attention_mask: Boolean mask for text tokens [B, L_txt] + image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2] """ if encoder_hidden_states is None: raise ValueError( - "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " - "This should be provided by PhotonBlock.attn_forward()." + "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens." 
) - # Unpack the combined key+value tensor - # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] - key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + # Project image tokens to Q, K, V + img_qkv = attn.img_qkv_proj(hidden_states) + B, L_img, _ = img_qkv.shape + img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim) + img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D] + img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] + + # Apply QK normalization to image tokens + img_q, img_k = attn.qk_norm(img_q, img_k, img_v) + + # Project text tokens to K, V + txt_kv = attn.txt_kv_proj(encoder_hidden_states) + B, L_txt, _ = txt_kv.shape + txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim) + txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D] + txt_k, txt_v = txt_kv[0], txt_kv[1] + + # Apply K normalization to text tokens + txt_k = attn.k_norm(txt_k) + + # Apply RoPE to image queries and keys + if image_rotary_emb is not None: + img_q = apply_rope(img_q, image_rotary_emb) + img_k = apply_rope(img_k, image_rotary_emb) + + # Concatenate text and image keys/values + k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D] + v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D] + + # Build attention mask if provided + attn_mask_tensor = None + if attention_mask is not None: + bs, _, l_img, _ = img_q.shape + l_txt = txt_k.shape[2] + + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") - # Apply scaled dot-product attention with Photon's processed tensors - # hidden_states is image queries [B, H, L_img, D] + device = img_q.device + ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) + attention_mask = attention_mask.to(device=device, dtype=torch.bool) + joint_mask = torch.cat([attention_mask, ones_img], dim=-1) + attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1) + + # Apply scaled dot-product attention attn_output = torch.nn.functional.scaled_dot_product_attention( - hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor ) # Reshape from [B, H, L_img, D] to [B, L_img, H*D] @@ -137,6 +180,67 @@ def __call__( return attn_output +class PhotonAttention(nn.Module, AttentionModuleMixin): + r""" + Photon-style attention module that handles multi-source tokens and RoPE. + Similar to FluxAttention but adapted for Photon's architecture. 
+ """ + + _default_processor_cls = PhotonAttnProcessor2_0 + _available_processors = [PhotonAttnProcessor2_0] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + bias: bool = False, + out_bias: bool = False, + eps: float = 1e-6, + processor=None, + ): + super().__init__() + + self.heads = heads + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.query_dim = query_dim + + # Image QKV projections + self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias) + self.qk_norm = QKNorm(self.head_dim) + + # Text KV projections + self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias) + self.k_norm = RMSNorm(self.head_dim, eps=eps) + + # Output projection + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(0.0)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + **kwargs, + ) + + # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py class PhotoEmbedND(nn.Module): r""" @@ -273,18 +377,10 @@ class PhotonBlock(nn.Module): Attributes: img_pre_norm (`nn.LayerNorm`): - Pre-normalization applied to image tokens before QKV projection. - img_qkv_proj (`nn.Linear`): - Linear projection to produce image queries, keys, and values. - qk_norm (`QKNorm`): - RMS normalization applied separately to image queries and keys. - txt_kv_proj (`nn.Linear`): - Linear projection to produce text keys and values. - k_norm (`RMSNorm`): - RMS normalization applied to text keys. - attention (`Attention`): - Multi-head attention module for cross-attention between image, text, and optional spatial conditioning - tokens. + Pre-normalization applied to image tokens before attention. + attention (`PhotonAttention`): + Multi-head attention module with built-in QKV projections and normalizations for cross-attention between + image and text tokens. post_attention_layernorm (`nn.LayerNorm`): Normalization applied after attention. gate_proj / up_proj / down_proj (`nn.Linear`): @@ -295,7 +391,7 @@ class PhotonBlock(nn.Module): Produces scale/shift/gating parameters for modulated layers. Methods: - The forward method performs cross-attention and the MLP inline with modulation. + The forward method performs cross-attention and the MLP with modulation. 
""" def __init__( @@ -315,21 +411,17 @@ def __init__( self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.hidden_size = hidden_size - # img qkv + # Pre-attention normalization for image tokens self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False) - self.qk_norm = QKNorm(self.head_dim) - - # txt kv - self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) - self.k_norm = RMSNorm(self.head_dim, eps=1e-6) - self.attention = Attention( + # PhotonAttention module with built-in projections and norms + self.attention = PhotonAttention( query_dim=hidden_size, heads=num_heads, dim_head=self.head_dim, bias=False, out_bias=False, + eps=1e-6, processor=PhotonAttnProcessor2_0(), ) @@ -378,48 +470,15 @@ def forward( attn_shift, attn_scale, attn_gate = mod_attn mlp_shift, mlp_scale, mlp_gate = mod_mlp - # Inline attention forward + # Apply modulation and pre-normalization to image tokens img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift - img_qkv = self.img_qkv_proj(img_mod) - B, L, _ = img_qkv.shape - img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim) - img_qkv = img_qkv.permute(2, 0, 3, 1, 4) - img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] - img_q, img_k = self.qk_norm(img_q, img_k, img_v) - - txt_kv = self.txt_kv_proj(txt) - B, L, _ = txt_kv.shape - txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim) - txt_kv = txt_kv.permute(2, 0, 3, 1, 4) - txt_k, txt_v = txt_kv[0], txt_kv[1] - txt_k = self.k_norm(txt_k) - - img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe) - k = torch.cat((txt_k, img_k), dim=2) - v = torch.cat((txt_v, img_v), dim=2) - - attn_mask_tensor: Tensor | None = None - if attention_mask is not None: - bs, _, l_img, _ = img_q.shape - l_txt = txt_k.shape[2] - - if attention_mask.dim() != 2: - raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") - if attention_mask.shape[-1] != l_txt: - raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") - - device = img_q.device - ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) - attention_mask = attention_mask.to(device=device, dtype=torch.bool) - joint_mask = torch.cat([attention_mask, ones_img], dim=-1) - attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) - - kv_packed = torch.cat([k, v], dim=-1) + # Forward through PhotonAttention module attn_out = self.attention( - hidden_states=img_q, - encoder_hidden_states=kv_packed, - attention_mask=attn_mask_tensor, + hidden_states=img_mod, + encoder_hidden_states=txt, + attention_mask=attention_mask, + image_rotary_emb=pe, ) img = img + attn_gate * attn_out From cd780320d7d51ddb12fe422ea269ee2d63b731a4 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 08:21:04 +0000 Subject: [PATCH 30/52] built-in RMSNorm --- .gitignore | 22 ++++++++++- scripts/convert_photon_to_diffusers.py | 20 +++++----- .../models/transformers/transformer_photon.py | 39 ++++--------------- 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index a55026febd5a..b8ec51a2dbac 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,24 @@ tags .ruff_cache # wandb -wandb \ No newline at end of file +wandb +converted_mirage_dcae_/ +converted_mirage_flux_/ +converted_photon_dcae +converted_photon_dcae/ +converted_photon_flux/ +diffusers_pipeline_checkpoints/ +*.png +convert_checkpoints.py 
+example_usage.py +plan.md +test_existing_checkpoints_with_timestep_change.py +test_timestep_embedding.py +test_updated_checkpoint.py +testhf.ipynb +update_checkpoint_parameters.py +verify_checkpoint_parameters.py +for_claude/mirage_layers.py +for_claude/mirage.py +for_claude/text_tower.py +for_claude/vae_tower.py diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index e30ba12ada4e..6e4a49de37cf 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -73,15 +73,17 @@ def create_parameter_mapping(depth: int) -> dict: mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" - # QK norm moved to attention module - mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.qk_norm.query_norm.weight" - mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.qk_norm.key_norm.weight" - mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.qk_norm.query_norm.weight" - mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.qk_norm.key_norm.weight" - - # K norm moved to attention module - mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.k_norm.weight" - mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.k_norm.weight" + # QK norm moved to attention module and renamed to match Attention's qk_norm structure + # Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight" + mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight" + + # K norm for text tokens moved to attention module + # Old: k_norm -> New: norm_added_k + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight" + mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight" # Attention output projection mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index b36c9c696aa6..b80f33e22265 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -125,7 +125,8 @@ def __call__( img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] # Apply QK normalization to image tokens - img_q, img_k = attn.qk_norm(img_q, img_k, img_v) + img_q = attn.norm_q(img_q) + img_k = attn.norm_k(img_k) # Project text tokens to K, V txt_kv = attn.txt_kv_proj(encoder_hidden_states) @@ -135,7 +136,7 @@ def __call__( txt_k, txt_v = txt_kv[0], txt_kv[1] # Apply K normalization to text tokens - txt_k = attn.k_norm(txt_k) + txt_k = attn.norm_added_k(txt_k) # Apply RoPE to image queries and keys if image_rotary_emb is not None: @@ -206,15 +207,14 @@ def __init__( self.inner_dim = dim_head * heads self.query_dim = query_dim - # Image QKV projections self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias) - self.qk_norm = QKNorm(self.head_dim) - # Text KV projections + self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True) + self.norm_k = RMSNorm(self.head_dim, eps=eps, 
elementwise_affine=True) + self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias) - self.k_norm = RMSNorm(self.head_dim, eps=eps) + self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True) - # Output projection self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(0.0)) @@ -309,31 +309,6 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) -class QKNorm(torch.nn.Module): - r""" - Applies RMS normalization to query and key tensors separately before attention which can help stabilize training - and improve numerical precision. - - Parameters: - dim (`int`): - Dimensionality of the query and key vectors. - - Returns: - (`torch.Tensor`, `torch.Tensor`): - A tuple `(q, k)` where both are normalized and cast to the same dtype as the value tensor `v`. - """ - - def __init__(self, dim: int): - super().__init__() - self.query_norm = RMSNorm(dim, eps=1e-6) - self.key_norm = RMSNorm(dim, eps=1e-6) - - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: - q = self.query_norm(q) - k = self.key_norm(k) - return q.to(v), k.to(v) - - class Modulation(nn.Module): r""" Modulation network that generates scale, shift, and gating parameters. From 69872404d1dce0af20aa74d8b09db77b4fe46dad Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 08:32:27 +0000 Subject: [PATCH 31/52] Revert accidental .gitignore change --- .gitignore | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index b8ec51a2dbac..a55026febd5a 100644 --- a/.gitignore +++ b/.gitignore @@ -178,24 +178,4 @@ tags .ruff_cache # wandb -wandb -converted_mirage_dcae_/ -converted_mirage_flux_/ -converted_photon_dcae -converted_photon_dcae/ -converted_photon_flux/ -diffusers_pipeline_checkpoints/ -*.png -convert_checkpoints.py -example_usage.py -plan.md -test_existing_checkpoints_with_timestep_change.py -test_timestep_embedding.py -test_updated_checkpoint.py -testhf.ipynb -update_checkpoint_parameters.py -verify_checkpoint_parameters.py -for_claude/mirage_layers.py -for_claude/mirage.py -for_claude/text_tower.py -for_claude/vae_tower.py +wandb \ No newline at end of file From c92ee55bbc4bb48ab193aad2147ef8ad2f21700f Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 08:34:52 +0000 Subject: [PATCH 32/52] parameter names match the standard diffusers conventions --- .../models/transformers/transformer_photon.py | 51 +++++++++---------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index b80f33e22265..6c94e9f67ab3 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -411,57 +411,54 @@ def __init__( def forward( self, - img: Tensor, - txt: Tensor, - vec: Tensor, - pe: Tensor, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + temb: Tensor, + image_rotary_emb: Tensor, attention_mask: Tensor | None = None, - **_: dict[str, Any], + **kwargs: dict[str, Any], ) -> Tensor: r""" Runs modulation-gated cross-attention and MLP, with residual connections. Parameters: - img (`torch.Tensor`): + hidden_states (`torch.Tensor`): Image tokens of shape `(B, L_img, hidden_size)`. - txt (`torch.Tensor`): + encoder_hidden_states (`torch.Tensor`): Text tokens of shape `(B, L_txt, hidden_size)`. 
- vec (`torch.Tensor`): + temb (`torch.Tensor`): Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or broadcastable). - pe (`torch.Tensor`): + image_rotary_emb (`torch.Tensor`): Rotary positional embeddings applied inside attention. attention_mask (`torch.Tensor`, *optional*): Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. - **_: - Ignored additional keyword arguments for API compatibility. + **kwargs: + Additional keyword arguments for API compatibility. Returns: `torch.Tensor`: Updated image tokens of shape `(B, L_img, hidden_size)`. """ - mod_attn, mod_mlp = self.modulation(vec) + mod_attn, mod_mlp = self.modulation(temb) attn_shift, attn_scale, attn_gate = mod_attn mlp_shift, mlp_scale, mlp_gate = mod_mlp - # Apply modulation and pre-normalization to image tokens - img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift + hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift - # Forward through PhotonAttention module attn_out = self.attention( - hidden_states=img_mod, - encoder_hidden_states=txt, + hidden_states=hidden_states_mod, + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, - image_rotary_emb=pe, + image_rotary_emb=image_rotary_emb, ) - img = img + attn_gate * attn_out + hidden_states = hidden_states + attn_gate * attn_out - # Inline FFN forward - x = (1 + mlp_scale) * self.post_attention_layernorm(img) + mlp_shift - img = img + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))) - return img + x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift + hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))) + return hidden_states class FinalLayer(nn.Module): @@ -749,10 +746,10 @@ def forward( ) else: img = block( - img=img, - txt=txt, - vec=vec, - pe=pe, + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=pe, attention_mask=cross_attn_mask, ) From 4f74d940609e5d1faa05af07f4742da9531745bd Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 08:43:54 +0000 Subject: [PATCH 33/52] renaming and remove unecessary attributes setting --- src/diffusers/pipelines/photon/pipeline_photon.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 1a6fbd194a1c..457bbd2223cb 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -262,8 +262,6 @@ def __init__( "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." 
) - self.text_encoder = text_encoder - self.tokenizer = tokenizer self.text_preprocessor = TextPreprocessor() self.default_sample_size = default_sample_size self._guidance_scale = 1.0 @@ -357,13 +355,9 @@ def encode_prompt( if isinstance(prompt, str): prompt = [prompt] # Encode the prompts - text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = ( + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) ) - prompt_embeds = text_embeddings - prompt_attention_mask = cross_attn_mask - negative_prompt_embeds = uncond_text_embeddings - negative_prompt_attention_mask = uncond_cross_attn_mask # Duplicate embeddings for each generation per prompt if num_images_per_prompt > 1: From d219e8cfa002ca5cb48093f66e706a2a465767d6 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Thu, 16 Oct 2025 10:48:24 +0200 Subject: [PATCH 34/52] Update docs/source/en/api/pipelines/photon.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/photon.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 62133f93c490..a46e9a4fc552 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -15,15 +15,7 @@ # PhotonPipeline -Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. - -Key features: - -- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks -- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling -- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) -- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support -- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality +Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. ## Available models: We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions. 
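The rewritten description above mentions "flow matching with discrete scheduling"; for readers unfamiliar with the term, the sketch below shows the plain Euler update such a sampler performs at each discrete step. This is a schematic only — it is not the `FlowMatchEulerDiscreteScheduler` source, and the function name, sigma values, and latent shape are illustrative (the 16-channel 64x64 latent corresponds to the Flux-VAE case described in the docs).

```py
import torch


def euler_flow_step(sample: torch.Tensor, velocity: torch.Tensor, sigma: float, sigma_next: float) -> torch.Tensor:
    # One discrete Euler update: the transformer predicts a velocity field and the
    # sample moves along it by the gap between consecutive noise levels.
    return sample + (sigma_next - sigma) * velocity


# Toy usage with made-up values
latent = torch.randn(1, 16, 64, 64)      # (B, C, H, W) latent, Flux-VAE-like shape
velocity = torch.randn_like(latent)      # stand-in for the transformer output
latent = euler_flow_step(latent, velocity, sigma=1.0, sigma_next=0.96)
```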
From 5d57f44f5fce8539017b6473ccc428fb85a7cff2 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 09:23:43 +0000 Subject: [PATCH 35/52] quantization example --- docs/source/en/api/pipelines/photon.md | 70 ++++++++------------------ 1 file changed, 20 insertions(+), 50 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index a46e9a4fc552..737760aabd11 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -54,36 +54,46 @@ image.save("photon_output.png") ### Manual Component Loading -You can also load components individually: +Load components individually to customize the pipeline for instance to use quantized models. ```py import torch -from diffusers import PhotonPipeline +from diffusers.pipelines.photon import PhotonPipeline from diffusers.models import AutoencoderKL, AutoencoderDC from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast +from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig +from transformers import BitsAndBytesConfig as BitsAndBytesConfig +quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) # Load transformer transformer = PhotonTransformer2DModel.from_pretrained( - "Photoroom/photon-512-t2i-sft", subfolder="transformer" -).to(dtype=torch.bfloat16) + "checkpoints/photon-512-t2i-sft", + subfolder="transformer", + quantization_config=quant_config, + torch_dtype=torch.bfloat16, +) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "Photoroom/photon-512-t2i-sft", subfolder="scheduler" + "checkpoints/photon-512-t2i-sft", subfolder="scheduler" ) # Load T5Gemma text encoder -t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") +t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2", + quantization_config=quant_config, + torch_dtype=torch.bfloat16) text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16) tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") tokenizer.model_max_length = 256 + # Load VAE - choose either Flux VAE or DC-AE -# Flux VAE (16 latent channels): -vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(dtype=torch.bfloat16) -# Or DC-AE (32 latent channels): -# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") +# Flux VAE +vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", + subfolder="vae", + quantization_config=quant_config, + torch_dtype=torch.bfloat16) pipe = PhotonPipeline( transformer=transformer, @@ -95,46 +105,6 @@ pipe = PhotonPipeline( pipe.to("cuda") ``` -## VAE Variants - -Photon supports two VAE configurations: - -### Flux VAE (AutoencoderKL) -- **Compression**: 8x spatial compression -- **Latent channels**: 16 -- **Model**: `black-forest-labs/FLUX.1-dev` (subfolder: "vae") -- **Use case**: Balanced quality and speed - -### DC-AE (AutoencoderDC) -- **Compression**: 32x spatial compression -- **Latent channels**: 32 -- **Model**: `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers` -- **Use case**: Higher compression for faster processing - -The VAE type is automatically determined from the checkpoint's `model_index.json` configuration. - -## Generation Parameters - -Key parameters for image generation: - -- **num_inference_steps**: Number of denoising steps (default: 28). 
More steps generally improve quality at the cost of speed. -- **guidance_scale**: Classifier-free guidance strength (default: 4.0). Higher values produce images more closely aligned with the prompt. -- **height/width**: Output image dimensions (default: 512x512). Can be customized in the checkpoint configuration. - -```py -# Example with custom parameters -import torch -from diffusers.pipelines.photon import PhotonPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) -pipe = pipe( - prompt = "A front-facing portrait of a lion the golden savanna at sunset." - num_inference_steps=28, - guidance_scale=4.0, - height=512, - width=512, - generator=torch.Generator("cuda").manual_seed(42) -).images[0] -``` ## Memory Optimization From 5270316d01489069cc9c7cfee039610d81cc5e5a Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 16 Oct 2025 09:47:32 +0000 Subject: [PATCH 36/52] added doc to toctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 85e996a1b772..3abe89437fa5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -541,6 +541,8 @@ title: PAG - local: api/pipelines/paint_by_example title: Paint by Example + - local: api/pipelines/photon + title: Photon - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma From f7f516fa8280e42345b0625a010469a173a12c14 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Thu, 16 Oct 2025 10:50:31 +0200 Subject: [PATCH 37/52] Update docs/source/en/api/pipelines/photon.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/photon.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 737760aabd11..52b5c63c8f0e 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -17,12 +17,9 @@ Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. -## Available models: -We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions. -Both **fine-tuned** and **non-fine-tuned** versions are available: +## Available models -- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions. -- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**. +Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. 
Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | From fb98a3ac4cfc742e3aab83ed582df84af4e55b67 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Thu, 16 Oct 2025 10:50:47 +0200 Subject: [PATCH 38/52] Update docs/source/en/api/pipelines/photon.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/photon.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 52b5c63c8f0e..b65d6cb42911 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# PhotonPipeline +# Photon Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. From f89132300ce1874c6e8d6b8f2778a219bf4cce9c Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Thu, 16 Oct 2025 10:51:09 +0200 Subject: [PATCH 39/52] Update docs/source/en/api/pipelines/photon.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/photon.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index b65d6cb42911..f9d6ba5a1792 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -35,7 +35,9 @@ Photon offers multiple variants with different VAE configurations, each optimize Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information. -## Loading the Pipeline +## Loading the pipeline + +Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. 
```py from diffusers.pipelines.photon import PhotonPipeline From 5fcb8e64edb56fb938d6c04c1a720deaebecb72e Mon Sep 17 00:00:00 2001 From: DavidBert Date: Fri, 17 Oct 2025 15:21:24 +0000 Subject: [PATCH 40/52] use dispatch_attention_fn for multiple attention backend support --- scripts/convert_photon_to_diffusers.py | 7 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 1 + .../models/transformers/transformer_photon.py | 100 ++++++++++-------- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/photon/pipeline_photon.py | 11 +- .../pipelines/photon/test_pipeline_photon.py | 44 ++++---- 7 files changed, 85 insertions(+), 81 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index 6e4a49de37cf..c66bc314181f 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -13,9 +13,6 @@ import torch from safetensors.torch import save_file - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon import PhotonPipeline @@ -74,14 +71,12 @@ def create_parameter_mapping(depth: int) -> dict: mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" # QK norm moved to attention module and renamed to match Attention's qk_norm structure - # Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight" mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight" # K norm for text tokens moved to attention module - # Old: k_norm -> New: norm_added_k mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight" mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight" @@ -306,7 +301,7 @@ def main(args): parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )" ) parser.add_argument( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c2528bc50fe5..28b2ae25499a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -516,6 +516,7 @@ "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", + "PhotonPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", @@ -1180,6 +1181,7 @@ MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, + PhotonPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f3164e48cfbf..2151e602b2e2 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -191,6 +191,7 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, + PhotonTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 6c94e9f67ab3..1a40a829719e 100644 --- 
a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor, nn @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ..attention import AttentionMixin, AttentionModuleMixin -from ..attention_processor import Attention +from ..attention_dispatch import dispatch_attention_fn from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev r""" Generates 2D patch coordinate indices for a batch of images. - Parameters: + Args: batch_size (`int`): Number of images in the batch. height (`int`): @@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: r""" Applies rotary positional embeddings (RoPE) to a query tensor. - Parameters: + Args: xq (`torch.Tensor`): Input tensor of shape `(..., dim)` representing the queries. freqs_cis (`torch.Tensor`): @@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: class PhotonAttnProcessor2_0: r""" - Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with - diffusers Attention module while handling Photon-specific logic. + Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention + backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn. """ _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): @@ -104,7 +105,7 @@ def __call__( """ Apply Photon attention using PhotonAttention module. - Parameters: + Args: attn: PhotonAttention module containing projection layers hidden_states: Image tokens [B, L_img, D] encoder_hidden_states: Text tokens [B, L_txt, D] @@ -113,9 +114,7 @@ def __call__( """ if encoder_hidden_states is None: - raise ValueError( - "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens." 
- ) + raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") # Project image tokens to Q, K, V img_qkv = attn.img_qkv_proj(hidden_states) @@ -164,14 +163,24 @@ def __call__( joint_mask = torch.cat([attention_mask, ones_img], dim=-1) attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1) - # Apply scaled dot-product attention - attn_output = torch.nn.functional.scaled_dot_product_attention( - img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor + # Apply attention using dispatch_attention_fn for backend support + # Reshape to match dispatch_attention_fn expectations: [B, L, H, D] + query = img_q.transpose(1, 2) # [B, L_img, H, D] + key = k.transpose(1, 2) # [B, L_txt + L_img, H, D] + value = v.transpose(1, 2) # [B, L_txt + L_img, H, D] + + attn_output = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask_tensor, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - # Reshape from [B, H, L_img, D] to [B, L_img, H*D] - batch_size, num_heads, seq_len, head_dim = attn_output.shape - attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + # Reshape from [B, L_img, H, D] to [B, L_img, H*D] + batch_size, seq_len, num_heads, head_dim = attn_output.shape + attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim) # Apply output projection attn_output = attn.to_out[0](attn_output) @@ -183,8 +192,8 @@ def __call__( class PhotonAttention(nn.Module, AttentionModuleMixin): r""" - Photon-style attention module that handles multi-source tokens and RoPE. - Similar to FluxAttention but adapted for Photon's architecture. + Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for + Photon's architecture. """ _default_processor_cls = PhotonAttnProcessor2_0 @@ -242,14 +251,14 @@ def forward( # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class PhotoEmbedND(nn.Module): +class PhotonEmbedND(nn.Module): r""" N-dimensional rotary positional embedding. This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding dimension. The embeddings are combined and returned as a single tensor - Parameters: + Args: dim (int): Base embedding dimension (must be even). theta (int): @@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module): List of embedding dimensions for each axis (each must be even). """ - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): super().__init__() self.dim = dim self.theta = theta @@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module): r""" A simple 2-layer MLP used for embedding inputs. - Parameters: + Args: in_dim (`int`): Dimensionality of the input features. hidden_dim (`int`): @@ -316,7 +325,7 @@ class Modulation(nn.Module): Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into two tuples `(shift, scale, gate)`. - Parameters: + Args: dim (`int`): Dimensionality of the input vector. The output will have `6 * dim` features internally. @@ -340,7 +349,7 @@ class PhotonBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. - Parameters: + Args: hidden_size (`int`): Dimension of the hidden representations. 
num_heads (`int`): @@ -421,7 +430,7 @@ def forward( r""" Runs modulation-gated cross-attention and MLP, with residual connections. - Parameters: + Args: hidden_states (`torch.Tensor`): Image tokens of shape `(B, L_img, hidden_size)`. encoder_hidden_states (`torch.Tensor`): @@ -468,7 +477,7 @@ class FinalLayer(nn.Module): This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level outputs. - Parameters: + Args: hidden_size (`int`): Dimensionality of the input tokens. patch_size (`int`): @@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor: r""" Flattens an image tensor into a sequence of non-overlapping patches. - Parameters: + Args: img (`torch.Tensor`): Input image tensor of shape `(B, C, H, W)`. patch_size (`int`): @@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: r""" Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). - Parameters: + Args: seq (`torch.Tensor`): Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)`. @@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" Transformer-based 2D model for text to image generation. - Parameters: + Args: in_channels (`int`, *optional*, defaults to 16): Number of input channels in the latent image. patch_size (`int`, *optional*, defaults to 2): @@ -650,7 +659,7 @@ def __init__( self.hidden_size = hidden_size self.num_heads = num_heads - self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) @@ -683,11 +692,10 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T def forward( self, - image_latent: Tensor, + hidden_states: Tensor, timestep: Tensor, - cross_attn_conditioning: Tensor, - micro_conditioning: Tensor, - cross_attn_mask: None | Tensor = None, + encoder_hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: @@ -697,16 +705,14 @@ def forward( The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. - Parameters: - image_latent (`torch.Tensor`): + Args: + hidden_states (`torch.Tensor`): Input latent image tensor of shape `(B, C, H, W)`. timestep (`torch.Tensor`): Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. - cross_attn_conditioning (`torch.Tensor`): + encoder_hidden_states (`torch.Tensor`): Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. - micro_conditioning (`torch.Tensor`): - Extra conditioning vector (currently unused, reserved for future use). - cross_attn_mask (`torch.Tensor`, *optional*): + attention_mask (`torch.Tensor`, *optional*): Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. attention_kwargs (`dict`, *optional*): Additional arguments passed to attention layers. @@ -719,15 +725,15 @@ def forward( - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. 
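Since this patch renames the transformer's public arguments, a minimal call sketch may help map the old names onto the new ones. The helper below is illustrative only — `denoise_once` is not part of the PR; it simply forwards to the renamed signature (`image_latent` → `hidden_states`, `cross_attn_conditioning` → `encoder_hidden_states`, `cross_attn_mask` → `attention_mask`, with `micro_conditioning` dropped), and the commented shapes are examples rather than fixed requirements.

```py
import torch

from diffusers import PhotonTransformer2DModel  # exported at the top level by this PR


def denoise_once(
    transformer: PhotonTransformer2DModel,
    latents: torch.Tensor,        # (B, C, H, W) latent image
    timestep: torch.Tensor,       # (B,)
    text_embeds: torch.Tensor,    # (B, L_txt, context_in_dim)
    text_mask: torch.Tensor,      # (B, L_txt), 0 marks padded text tokens
) -> torch.Tensor:
    # Same forward call as before, using the argument names introduced by this patch.
    return transformer(
        hidden_states=latents,               # was `image_latent`
        timestep=timestep,
        encoder_hidden_states=text_embeds,   # was `cross_attn_conditioning`
        attention_mask=text_mask,            # was `cross_attn_mask`
        return_dict=False,
    )[0]                                     # (B, C, H, W), same spatial size as the input latents
```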
""" # Process text conditioning - txt = self.txt_in(cross_attn_conditioning) + txt = self.txt_in(encoder_hidden_states) # Convert image to sequence and embed - img = img2seq(image_latent, self.patch_size) + img = img2seq(hidden_states, self.patch_size) img = self.img_in(img) # Generate positional embeddings - bs, _, h, w = image_latent.shape - img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + bs, _, h, w = hidden_states.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device) pe = self.pe_embedder(img_ids) # Compute time embedding @@ -742,7 +748,7 @@ def forward( txt, vec, pe, - cross_attn_mask, + attention_mask, ) else: img = block( @@ -750,12 +756,12 @@ def forward( encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe, - attention_mask=cross_attn_mask, + attention_mask=attention_mask, ) # Final layer and convert back to image img = self.final_layer(img, vec) - output = seq2img(img, self.patch_size, image_latent.shape) + output = seq2img(img, self.patch_size, hidden_states.shape) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1fa8dcf0c8b8..a44c92a834b2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -718,6 +718,7 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline + from .photon import PhotonPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 457bbd2223cb..b394b12d83f4 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -206,11 +206,11 @@ def clean_text(self, text: str) -> str: >>> from diffusers import PhotonPipeline >>> # Load pipeline with from_pretrained - >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") + >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft") >>> pipe.to("cuda") >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" - >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] >>> image.save("photon_output.png") ``` """ @@ -717,11 +717,10 @@ def __call__( # Forward through transformer noise_pred = self.transformer( - image_latent=latents_in, + hidden_states=latents_in, timestep=t_cont, - cross_attn_conditioning=ca_embed, - micro_conditioning=None, - cross_attn_mask=ca_mask, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, return_dict=False, )[0] diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 9ac361c75b2e..9c5803b5d0f4 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -68,28 +68,28 @@ def get_dummy_components(self): tokenizer.model_max_length = 64 torch.manual_seed(0) - - encoder_params = dict( - vocab_size=tokenizer.vocab_size, - hidden_size=8, - intermediate_size=16, - num_hidden_layers=1, - num_attention_heads=2, - num_key_value_heads=1, - head_dim=4, - max_position_embeddings=64, - layer_types=["full_attention"], - attention_bias=False, - attention_dropout=0.0, - dropout_rate=0.0, - hidden_activation="gelu_pytorch_tanh", - rms_norm_eps=1e-06, - 
attn_logit_softcapping=50.0, - final_logit_softcapping=30.0, - query_pre_attn_scalar=4, - rope_theta=10000.0, - sliding_window=4096, - ) + + encoder_params = { + "vocab_size": tokenizer.vocab_size, + "hidden_size": 8, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 4, + "max_position_embeddings": 64, + "layer_types": ["full_attention"], + "attention_bias": False, + "attention_dropout": 0.0, + "dropout_rate": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "rms_norm_eps": 1e-06, + "attn_logit_softcapping": 50.0, + "final_logit_softcapping": 30.0, + "query_pre_attn_scalar": 4, + "rope_theta": 10000.0, + "sliding_window": 4096, + } encoder_config = T5GemmaModuleConfig(**encoder_params) text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params) text_encoder = T5GemmaEncoder(text_encoder_config) From 23392b02f81e71b1e5deac96a426a1addc563f08 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Sat, 18 Oct 2025 05:08:17 +0000 Subject: [PATCH 41/52] naming changes --- .../transformers/test_models_transformer_photon.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_photon.py index 1491b83bf65c..f5185245d399 100644 --- a/tests/models/transformers/test_models_transformer_photon.py +++ b/tests/models/transformers/test_models_transformer_photon.py @@ -28,7 +28,7 @@ class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = PhotonTransformer2DModel - main_input_name = "image_latent" + main_input_name = "hidden_states" uses_custom_attn_processor = True @property @@ -49,16 +49,14 @@ def prepare_dummy_input(self, height=16, width=16): sequence_length = 16 embedding_dim = 1792 - image_latent = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) - cross_attn_conditioning = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - micro_conditioning = torch.randn((batch_size, embedding_dim)).to(torch_device) + hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) return { - "image_latent": image_latent, + "hidden_states": hidden_states, "timestep": timestep, - "cross_attn_conditioning": cross_attn_conditioning, - "micro_conditioning": micro_conditioning, + "encoder_hidden_states": encoder_hidden_states, } def prepare_init_args_and_inputs_for_common(self): From 1f88313aef47e07052096749aabce47023470b3f Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 20 Oct 2025 07:19:45 +0000 Subject: [PATCH 42/52] make fix copy --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3244ef12ef87..52c72579cd20 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1847,6 +1847,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PhotonPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class PIAPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 167030cd127d468123cea1ac024bae2a96ba64ee Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 20 Oct 2025 09:21:14 +0200 Subject: [PATCH 43/52] Update docs/source/en/api/pipelines/photon.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- docs/source/en/api/pipelines/photon.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index f9d6ba5a1792..293e05f0fdef 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -27,7 +27,7 @@ Photon offers multiple variants with different VAE configurations, each optimize | [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | | [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | From 8b551d24764d235dcd1a13cacd4632ef67c54a08 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Mon, 20 
Oct 2025 07:27:03 +0000 Subject: [PATCH 44/52] Add PhotonTransformer2DModel to TYPE_CHECKING imports --- src/diffusers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 28b2ae25499a..b7086d2e0c44 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -928,6 +928,7 @@ MultiControlNetModel, OmniGenTransformer2DModel, ParallelConfig, + PhotonTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageControlNetModel, From 7f6bb8a4bfa2509f8dd97b30bfa64e0b5c28a43f Mon Sep 17 00:00:00 2001 From: DavidBert Date: Mon, 20 Oct 2025 23:10:51 +0000 Subject: [PATCH 45/52] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5d62709c28fd..d379a5d4a77c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1098,6 +1098,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class PhotonTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PixArtTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] From 4264606609011bbb50b15ae60c224be6763b6ca4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 21 Oct 2025 09:18:04 +0200 Subject: [PATCH 46/52] Use Tuple instead of tuple Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_photon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 1a40a829719e..c5809bc2c094 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -340,7 +340,7 @@ def __init__(self, dim: int): nn.init.constant_(self.lin.weight, 0) nn.init.constant_(self.lin.bias, 0) - def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: + def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) return tuple(out[:3]), tuple(out[3:]) From 756fe953053e55b373a3868098f572343b8a32b9 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 21 Oct 2025 09:19:10 +0200 Subject: [PATCH 47/52] restrict the version of transformers Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/photon/test_pipeline_photon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 9c5803b5d0f4..24d96b9b420a 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -15,6 +15,7 @@ from ..test_pipelines_common import PipelineTesterMixin +@pytest.mark.xfail(condition=is_transformers_version(">", "4.57.1"), reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False) class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): 
pipeline_class = PhotonPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} From 4402b41c50f7abf912b94ae74fae76492a6fb3de Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 21 Oct 2025 09:19:46 +0200 Subject: [PATCH 48/52] Update tests/pipelines/photon/test_pipeline_photon.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/photon/test_pipeline_photon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 24d96b9b420a..7910d6064be8 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig From 3b953ff04973554e47f379aaf1fa92f644f0db70 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 21 Oct 2025 09:20:14 +0200 Subject: [PATCH 49/52] Update tests/pipelines/photon/test_pipeline_photon.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/photon/test_pipeline_photon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 7910d6064be8..0267ebeda23c 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -11,6 +11,7 @@ from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_transformers_version from ..pipeline_params import TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin From fda1f78291144dd2c34b09e199d081f036fe9dab Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 14:33:51 +0000 Subject: [PATCH 50/52] change | for Optional --- src/diffusers/models/transformers/transformer_photon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index c5809bc2c094..afa525b0443e 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -383,7 +383,7 @@ def __init__( hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, - qk_scale: float | None = None, + qk_scale: Optional[float] = None, ): super().__init__() @@ -424,7 +424,7 @@ def forward( encoder_hidden_states: Tensor, temb: Tensor, image_rotary_emb: Tensor, - attention_mask: Tensor | None = None, + attention_mask: Optional[Tensor] = None, **kwargs: dict[str, Any], ) -> Tensor: r""" From 3e2d292ea5e3ab044892322da0651b85a268f44b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 21 Oct 2025 04:31:48 -1000 Subject: [PATCH 51/52] fix nits. 
--- .../models/transformers/transformer_photon.py | 46 ++++++++++--------- .../pipelines/photon/pipeline_photon.py | 2 +- .../pipelines/photon/test_pipeline_photon.py | 6 ++- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index afa525b0443e..f5c1291e4640 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from torch import Tensor, nn +from torch import nn from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config @@ -31,7 +31,7 @@ logger = logging.get_logger(__name__) -def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor: +def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor: r""" Generates 2D patch coordinate indices for a batch of images. @@ -59,7 +59,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1) -def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: +def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: r""" Applies rotary positional embeddings (RoPE) to a query tensor. 
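The `apply_rope` helper whose annotations are touched here rotates pairs of query channels by position-dependent angles. Below is a generic sketch of that idea, not the module's exact implementation: the function name `rotate_query_pairs` and all tensor sizes are illustrative, though the angle schedule mirrors the `rope` helper shown further down (`omega = 1 / theta**(2i/dim)`).

```py
import torch


def rotate_query_pairs(xq: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # xq: (..., L, D) with D even; angles: (L, D // 2) rotation angles per position/frequency.
    # Each consecutive channel pair (x0, x1) is rotated by its angle -- the core idea behind RoPE.
    x = xq.reshape(*xq.shape[:-1], -1, 2)
    cos, sin = angles.cos(), angles.sin()
    x0, x1 = x[..., 0], x[..., 1]
    out = torch.stack((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
    return out.reshape(*xq.shape).to(xq.dtype)


# Toy usage: 4 tokens, head_dim 8
dim, theta = 8, 10_000
pos = torch.arange(4, dtype=torch.float32)
omega = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
angles = pos[:, None] * omega[None, :]   # (L, dim // 2)
q = torch.randn(1, 4, dim)               # (batch, L, head_dim)
q_rot = rotate_query_pairs(q, angles)
```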
@@ -273,7 +273,7 @@ def __init__(self, dim: int, theta: int, axes_dim: List[int]): self.theta = theta self.axes_dim = axes_dim - def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: + def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta**scale) @@ -284,7 +284,7 @@ def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: out = out.reshape(*out.shape[:-1], 2, 2) return out.float() - def forward(self, ids: Tensor) -> Tensor: + def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] emb = torch.cat( [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)], @@ -314,7 +314,7 @@ def __init__(self, in_dim: int, hidden_dim: int): self.silu = nn.SiLU() self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -340,7 +340,9 @@ def __init__(self, dim: int): nn.init.constant_(self.lin.weight, 0) nn.init.constant_(self.lin.bias, 0) - def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]: + def forward( + self, vec: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) return tuple(out[:3]), tuple(out[3:]) @@ -420,13 +422,13 @@ def __init__( def forward( self, - hidden_states: Tensor, - encoder_hidden_states: Tensor, - temb: Tensor, - image_rotary_emb: Tensor, - attention_mask: Optional[Tensor] = None, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, **kwargs: dict[str, Any], - ) -> Tensor: + ) -> torch.Tensor: r""" Runs modulation-gated cross-attention and MLP, with residual connections. @@ -503,14 +505,14 @@ def __init__(self, hidden_size: int, patch_size: int, out_channels: int): self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) - def forward(self, x: Tensor, vec: Tensor) -> Tensor: + def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] x = self.linear(x) return x -def img2seq(img: Tensor, patch_size: int) -> Tensor: +def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor: r""" Flattens an image tensor into a sequence of non-overlapping patches. @@ -528,7 +530,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor: return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) -def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: +def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor: r""" Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). 
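`img2seq` flattens the latent into non-overlapping patches with `unfold`, and `seq2img` is documented as its inverse. A tiny round-trip check makes the shape contract concrete; the `fold` call below is a sketch of the inverse under the non-overlapping-patch assumption, not necessarily `seq2img`'s exact code, and the tensor sizes are illustrative.

```py
import torch
import torch.nn.functional as F

patch_size = 2
img = torch.randn(1, 16, 8, 8)   # (B, C, H, W)

# (B, L, C * p * p) with L = (H // p) * (W // p), as in `img2seq`
seq = F.unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

# Inverse: scatter the patches back onto the image grid
rec = F.fold(seq.transpose(1, 2), output_size=img.shape[-2:], kernel_size=patch_size, stride=patch_size)

assert torch.equal(img, rec)     # non-overlapping patches reconstruct exactly
```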
@@ -679,7 +681,7 @@ def __init__( self.gradient_checkpointing = False - def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return self.time_in( get_timestep_embedding( timesteps=timestep, @@ -692,10 +694,10 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T def forward( self, - hidden_states: Tensor, - timestep: Tensor, - encoder_hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index b394b12d83f4..4a10899ede61 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. +# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 0267ebeda23c..c29c6ce0b0dd 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -17,7 +17,11 @@ from ..test_pipelines_common import PipelineTesterMixin -@pytest.mark.xfail(condition=is_transformers_version(">", "4.57.1"), reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False) +@pytest.mark.xfail( + condition=is_transformers_version(">", "4.57.1"), + reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", + strict=False, +) class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = PhotonPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} From 803d0d1e7efb4a7e998f20ce963ad633e3fb4cab Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 15:02:34 +0000 Subject: [PATCH 52/52] use typing Dict --- src/diffusers/models/transformers/transformer_photon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index f5c1291e4640..6314020c1c74 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -427,7 +427,7 @@ def forward( temb: torch.Tensor, image_rotary_emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - **kwargs: dict[str, Any], + **kwargs: Dict[str, Any], ) -> torch.Tensor: r""" Runs modulation-gated cross-attention and MLP, with residual connections.
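The final two patches above clean up the type annotations of `Modulation.forward` and `PhotonBlock.forward`. For readers unfamiliar with the modulation pattern those methods implement — a conditioning vector produces shift/scale/gate values that wrap a pre-normalized residual update — here is a self-contained sketch. It is an illustration only: the class name `AdaLNGatedBlock` and all sizes are invented, and the real `PhotonBlock` produces two such triples, one gating the cross-attention and one gating the MLP.

```py
import torch
from torch import nn


class AdaLNGatedBlock(nn.Module):
    """Minimal sketch of the modulation-gated residual pattern (illustration only)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mod = nn.Linear(dim, 3 * dim)  # produces (shift, scale, gate), as in `Modulation`
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        shift, scale, gate = self.mod(nn.functional.silu(temb))[:, None, :].chunk(3, dim=-1)
        x = (1 + scale) * self.norm(hidden_states) + shift   # modulated pre-norm
        return hidden_states + gate * self.mlp(x)            # gated residual update


# Toy usage
block = AdaLNGatedBlock(32)
tokens = torch.randn(2, 10, 32)   # (B, L, hidden_size)
temb = torch.randn(2, 32)         # timestep conditioning vector
out = block(tokens, temb)         # (2, 10, 32)
```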