From b7fb0fe9d63bf766bbe3c42ac154a043796dd370 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 15:38:55 +0000 Subject: [PATCH 1/7] rename photon to prx --- .gitignore | 24 ++++++++- docs/source/en/_toctree.yml | 4 +- docs/source/en/api/pipelines/photon.md | 54 +++++++++---------- scripts/convert_photon_to_diffusers.py | 38 ++++++------- src/diffusers/__init__.py | 8 +-- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/transformers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/photon/__init__.py | 8 +-- .../pipelines/photon/pipeline_output.py | 4 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../dummy_torch_and_transformers_objects.py | 2 +- 12 files changed, 88 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index a55026febd5a..c158ffe1378b 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,26 @@ tags .ruff_cache # wandb -wandb \ No newline at end of file +wandb +convert_checkpoints.py +dcae_mirage_generated_image_.png +dcae_prx_generated_image.png +example_usage.py +META_TENSOR_FIX.md +mirage_generated_image__.png +mirage_generated_image_.png +prx_generated_image.png +plan.md +test_existing_checkpoints_with_timestep_change.py +test_timestep_embedding.py +test_updated_checkpoint.png +test_updated_checkpoint.py +testhf.ipynb +update_checkpoint_parameters.py +verify_checkpoint_parameters.py +for_claude/mirage_layers.py +for_claude/mirage.py +for_claude/text_tower.py +for_claude/vae_tower.py +prx_/prx_layers.py +prx_/prx.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3abe89437fa5..2e81200c216c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -541,8 +541,8 @@ title: PAG - local: api/pipelines/paint_by_example title: Paint by Example - - local: api/pipelines/photon - title: Photon + - local: api/pipelines/prx + title: PRX - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 293e05f0fdef..16670f4bfc86 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -12,43 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# Photon +# PRX -Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. +PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. 
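+
+The VAE choice sets the size of the latent grid the transformer actually processes. Here is a back-of-the-envelope sketch (a rough illustration only: the compression factors and channel counts are the ones stated above, and transformer patch sizes of 2 for the Flux VAE and 1 for DC-AE are assumed from this PR's conversion-script defaults):
+
+```py
+# Latent shape and image-token count for a 512x512 generation.
+height = width = 512
+
+# Flux VAE: 8x spatial compression, 16 latent channels, transformer patch_size=2.
+flux_latent = (16, height // 8, width // 8)           # (16, 64, 64)
+flux_tokens = (height // 8 // 2) * (width // 8 // 2)  # 1024 image tokens
+
+# DC-AE: 32x spatial compression, 32 latent channels, transformer patch_size=1.
+dcae_latent = (32, height // 32, width // 32)         # (32, 16, 16)
+dcae_tokens = (height // 32) * (width // 32)          # 256 image tokens
+
+# 4x fewer tokens per step is why the DC-AE variants are faster.
+print(flux_latent, flux_tokens, dcae_latent, dcae_tokens)
+```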
## Available models -Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. +PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | |:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| -| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s +| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | 
Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |

-Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information.
+Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.

## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

```py
-from diffusers.pipelines.photon import PhotonPipeline
+import torch
+from diffusers.pipelines.prx import PRXPipeline

# Load pipeline - VAE and text encoder will be loaded from HuggingFace
-pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

-prompt = "A front-facing portrait of a lion the golden savanna at sunset."
+prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
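+# For reproducible outputs, a torch.Generator can be passed through (a sketch assuming the
+# standard diffusers `generator` argument; the seed value is illustrative):
+# generator = torch.Generator("cuda").manual_seed(0)
+# image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0, generator=generator).images[0]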
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] -image.save("photon_output.png") +image.save("prx_output.png") ``` ### Manual Component Loading @@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant ```py import torch -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.pipelines.prx import PRXPipeline from diffusers.models import AutoencoderKL, AutoencoderDC -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig @@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) # Load transformer -transformer = PhotonTransformer2DModel.from_pretrained( - "checkpoints/photon-512-t2i-sft", +transformer = PRXTransformer2DModel.from_pretrained( + "checkpoints/prx-512-t2i-sft", subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16, @@ -76,7 +76,7 @@ transformer = PhotonTransformer2DModel.from_pretrained( # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "checkpoints/photon-512-t2i-sft", subfolder="scheduler" + "checkpoints/prx-512-t2i-sft", subfolder="scheduler" ) # Load T5Gemma text encoder @@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", quantization_config=quant_config, torch_dtype=torch.bfloat16) -pipe = PhotonPipeline( +pipe = PRXPipeline( transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, @@ -111,21 +111,21 @@ For memory-constrained environments: ```py import torch -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.pipelines.prx import PRXPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) +pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## PhotonPipeline +## PRXPipeline -[[autodoc]] PhotonPipeline +[[autodoc]] PRXPipeline - all - __call__ -## PhotonPipelineOutput +## PRXPipelineOutput -[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput +[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index c66bc314181f..d9bde2f34d56 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to convert Photon checkpoint from original codebase to diffusers format. +Script to convert PRX checkpoint from original codebase to diffusers format. 
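+
+Example invocation (a sketch: `--output_path` and `--vae_type` are assumed from the argparse
+definitions at the bottom of this script, and the checkpoint path is hypothetical):
+
+    python scripts/convert_prx_to_diffusers.py \
+        --checkpoint_path ./prx_checkpoint.pth \
+        --output_path ./prx-512-t2i-diffusers \
+        --vae_type flux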
""" import argparse @@ -13,15 +13,15 @@ import torch from safetensors.torch import save_file -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx import PRXPipeline DEFAULT_RESOLUTION = 512 @dataclass(frozen=True) -class PhotonBase: +class PRXBase: context_in_dim: int = 2304 hidden_size: int = 1792 mlp_ratio: float = 3.5 @@ -34,22 +34,22 @@ class PhotonBase: @dataclass(frozen=True) -class PhotonFlux(PhotonBase): +class PRXFlux(PRXBase): in_channels: int = 16 patch_size: int = 2 @dataclass(frozen=True) -class PhotonDCAE(PhotonBase): +class PRXDCAE(PRXBase): in_channels: int = 32 patch_size: int = 1 def build_config(vae_type: str) -> Tuple[dict, int]: if vae_type == "flux": - cfg = PhotonFlux() + cfg = PRXFlux() elif vae_type == "dc-ae": - cfg = PhotonDCAE() + cfg = PRXDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'") @@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict: # Key mappings for structural changes mapping = {} - # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention) + # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention) for i in range(depth): # QKV projections moved to attention module mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" @@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth return converted_state_dict -def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: - """Create and load PhotonTransformer2DModel from old checkpoint.""" +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel: + """Create and load PRXTransformer2DModel from old checkpoint.""" print(f"Loading checkpoint from: {checkpoint_path}") @@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config - print("Creating PhotonTransformer2DModel...") - transformer = PhotonTransformer2DModel(**config) + print("Creating PRXTransformer2DModel...") + transformer = PRXTransformer2DModel(**config) # Load state dict print("Loading converted parameters...") @@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str) vae_class = "AutoencoderDC" model_index = { - "_class_name": "PhotonPipeline", + "_class_name": "PRXPipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "default_sample_size": default_image_size, "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["photon", "T5GemmaEncoder"], + "text_encoder": ["prx", "T5GemmaEncoder"], "tokenizer": ["transformers", "GemmaTokenizerFast"], - "transformer": ["diffusers", "PhotonTransformer2DModel"], + "transformer": ["diffusers", "PRXTransformer2DModel"], "vae": ["diffusers", vae_class], } @@ -275,7 +275,7 @@ def main(args): # Verify the pipeline can be loaded try: - pipeline = PhotonPipeline.from_pretrained(args.output_path) + pipeline = PRXPipeline.from_pretrained(args.output_path) print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: 
{type(pipeline.vae).__name__}") @@ -298,10 +298,10 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") + parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )" + "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )" ) parser.add_argument( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b7086d2e0c44..47285f37d91a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -232,7 +232,7 @@ "MultiControlNetModel", "OmniGenTransformer2DModel", "ParallelConfig", - "PhotonTransformer2DModel", + "PRXTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", "QwenImageControlNetModel", @@ -516,7 +516,7 @@ "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", - "PhotonPipeline", + "PRXPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", @@ -928,7 +928,7 @@ MultiControlNetModel, OmniGenTransformer2DModel, ParallelConfig, - PhotonTransformer2DModel, + PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageControlNetModel, @@ -1182,7 +1182,7 @@ MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, - PhotonPipeline, + PRXPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 2151e602b2e2..6d08c2f2e23b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,7 +96,7 @@ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] - _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] + _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] @@ -191,7 +191,7 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, - PhotonTransformer2DModel, + PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ab5311518ba7..d8c3d9b57273 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -32,7 +32,7 @@ from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel - from .transformer_photon import PhotonTransformer2DModel + from .transformer_prx import PRXTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a44c92a834b2..953f307fe1de 100644 --- a/src/diffusers/pipelines/__init__.py +++ 
b/src/diffusers/pipelines/__init__.py @@ -144,7 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["photon"] = ["PhotonPipeline"] + _import_structure["prx"] = ["PRXPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -718,7 +718,7 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline - from .photon import PhotonPipeline + from .prx import PRXPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index e21e31d4225f..87aaefbd1368 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -12,7 +12,7 @@ _dummy_objects = {} _additional_imports = {} -_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]} +_import_structure = {"pipeline_output": ["PRXPipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_photon"] = ["PhotonPipeline"] + _import_structure["pipeline_prx"] = ["PRXPipeline"] # Import T5GemmaEncoder for pipeline loading compatibility try: @@ -44,8 +44,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_output import PhotonPipelineOutput - from .pipeline_photon import PhotonPipeline + from .pipeline_output import PRXPipelineOutput + from .pipeline_prx import PRXPipeline else: import sys diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py index d4b0ff462983..ea1bc9bf418a 100644 --- a/src/diffusers/pipelines/photon/pipeline_output.py +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -22,9 +22,9 @@ @dataclass -class PhotonPipelineOutput(BaseOutput): +class PRXPipelineOutput(BaseOutput): """ - Output class for Photon pipelines. + Output class for PRX pipelines. 
Args: images (`List[PIL.Image.Image]` or `np.ndarray`) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d379a5d4a77c..417461ea459b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1098,7 +1098,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PhotonTransformer2DModel(metaclass=DummyObject): +class PRXTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 52c72579cd20..225cae9a1e21 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1847,7 +1847,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PhotonPipeline(metaclass=DummyObject): +class PRXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 9fd52e87ca9b75ceffc30b4ef98df85d04de43a1 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 20:58:33 +0000 Subject: [PATCH 2/7] rename photon into prx --- docs/source/en/_toctree.yml | 4 +- .../en/api/pipelines/{photon.md => prx.md} | 0 ...ffusers.py => convert_prx_to_diffusers.py} | 0 src/diffusers/__init__.py | 8 ++-- src/diffusers/models/__init__.py | 2 +- ...ansformer_photon.py => transformer_prx.py} | 46 +++++++++---------- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/{photon => prx}/__init__.py | 0 .../{photon => prx}/pipeline_output.py | 0 .../pipeline_prx.py} | 35 +++++++------- ...oton.py => test_models_transformer_prx.py} | 8 ++-- tests/pipelines/{photon => prx}/__init__.py | 0 .../test_pipeline_prx.py} | 26 +++++------ 13 files changed, 65 insertions(+), 66 deletions(-) rename docs/source/en/api/pipelines/{photon.md => prx.md} (100%) rename scripts/{convert_photon_to_diffusers.py => convert_prx_to_diffusers.py} (100%) rename src/diffusers/models/transformers/{transformer_photon.py => transformer_prx.py} (95%) rename src/diffusers/pipelines/{photon => prx}/__init__.py (100%) rename src/diffusers/pipelines/{photon => prx}/pipeline_output.py (100%) rename src/diffusers/pipelines/{photon/pipeline_photon.py => prx/pipeline_prx.py} (96%) rename tests/models/transformers/{test_models_transformer_photon.py => test_models_transformer_prx.py} (90%) rename tests/pipelines/{photon => prx}/__init__.py (100%) rename tests/pipelines/{photon/test_pipeline_photon.py => prx/test_pipeline_prx.py} (90%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2e81200c216c..1e5c4fe5501a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -541,12 +541,12 @@ title: PAG - local: api/pipelines/paint_by_example title: Paint by Example - - local: api/pipelines/prx - title: PRX - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma title: PixArt-Σ + - local: api/pipelines/prx + title: PRX - local: api/pipelines/qwenimage title: QwenImage - local: api/pipelines/sana diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/prx.md similarity index 100% rename from docs/source/en/api/pipelines/photon.md rename to docs/source/en/api/pipelines/prx.md diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_prx_to_diffusers.py similarity index 100% rename from 
scripts/convert_photon_to_diffusers.py rename to scripts/convert_prx_to_diffusers.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 47285f37d91a..d9c6ba9bbf5d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -232,9 +232,9 @@ "MultiControlNetModel", "OmniGenTransformer2DModel", "ParallelConfig", - "PRXTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", + "PRXTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", @@ -516,11 +516,11 @@ "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", - "PRXPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "PRXPipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", @@ -928,9 +928,9 @@ MultiControlNetModel, OmniGenTransformer2DModel, ParallelConfig, - PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + PRXTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, @@ -1182,11 +1182,11 @@ MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, - PRXPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + PRXPipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 6d08c2f2e23b..532db76a09e0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -191,9 +191,9 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, - PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + PRXTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_prx.py similarity index 95% rename from src/diffusers/models/transformers/transformer_photon.py rename to src/diffusers/models/transformers/transformer_prx.py index 6314020c1c74..9b2664b9cb26 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: return xq_out.reshape(*xq.shape).type_as(xq) -class PhotonAttnProcessor2_0: +class PRXAttnProcessor2_0: r""" - Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention + Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn. """ @@ -91,11 +91,11 @@ class PhotonAttnProcessor2_0: def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): - raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") def __call__( self, - attn: "PhotonAttention", + attn: "PRXAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -103,10 +103,10 @@ def __call__( **kwargs, ) -> torch.Tensor: """ - Apply Photon attention using PhotonAttention module. + Apply PRX attention using PRXAttention module. 
Args: - attn: PhotonAttention module containing projection layers + attn: PRXAttention module containing projection layers hidden_states: Image tokens [B, L_img, D] encoder_hidden_states: Text tokens [B, L_txt, D] attention_mask: Boolean mask for text tokens [B, L_txt] @@ -114,7 +114,7 @@ def __call__( """ if encoder_hidden_states is None: - raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") + raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") # Project image tokens to Q, K, V img_qkv = attn.img_qkv_proj(hidden_states) @@ -190,14 +190,14 @@ def __call__( return attn_output -class PhotonAttention(nn.Module, AttentionModuleMixin): +class PRXAttention(nn.Module, AttentionModuleMixin): r""" - Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for - Photon's architecture. + PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for + PRX's architecture. """ - _default_processor_cls = PhotonAttnProcessor2_0 - _available_processors = [PhotonAttnProcessor2_0] + _default_processor_cls = PRXAttnProcessor2_0 + _available_processors = [PRXAttnProcessor2_0] def __init__( self, @@ -251,7 +251,7 @@ def forward( # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class PhotonEmbedND(nn.Module): +class PRXEmbedND(nn.Module): r""" N-dimensional rotary positional embedding. @@ -347,7 +347,7 @@ def forward( return tuple(out[:3]), tuple(out[3:]) -class PhotonBlock(nn.Module): +class PRXBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. @@ -364,7 +364,7 @@ class PhotonBlock(nn.Module): Attributes: img_pre_norm (`nn.LayerNorm`): Pre-normalization applied to image tokens before attention. - attention (`PhotonAttention`): + attention (`PRXAttention`): Multi-head attention module with built-in QKV projections and normalizations for cross-attention between image and text tokens. post_attention_layernorm (`nn.LayerNorm`): @@ -400,15 +400,15 @@ def __init__( # Pre-attention normalization for image tokens self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - # PhotonAttention module with built-in projections and norms - self.attention = PhotonAttention( + # PRXAttention module with built-in projections and norms + self.attention = PRXAttention( query_dim=hidden_size, heads=num_heads, dim_head=self.head_dim, bias=False, out_bias=False, eps=1e-6, - processor=PhotonAttnProcessor2_0(), + processor=PRXAttnProcessor2_0(), ) # mlp @@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) -class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): +class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" Transformer-based 2D model for text to image generation. @@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): txt_in (`nn.Linear`): Projection layer for text conditioning. blocks (`nn.ModuleList`): - Stack of transformer blocks (`PhotonBlock`). + Stack of transformer blocks (`PRXBlock`). final_layer (`LastLayer`): Projection layer mapping hidden tokens back to patch outputs. 
@@ -661,14 +661,14 @@ def __init__( self.hidden_size = hidden_size self.num_heads = num_heads - self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) self.blocks = nn.ModuleList( [ - PhotonBlock( + PRXBlock( self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, @@ -702,7 +702,7 @@ def forward( return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: r""" - Forward pass of the PhotonTransformer2DModel. + Forward pass of the PRXTransformer2DModel. The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 953f307fe1de..ff64958d4699 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -718,9 +718,9 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline - from .prx import PRXPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline + from .prx import PRXPipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/prx/__init__.py similarity index 100% rename from src/diffusers/pipelines/photon/__init__.py rename to src/diffusers/pipelines/prx/__init__.py diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/prx/pipeline_output.py similarity index 100% rename from src/diffusers/pipelines/photon/pipeline_output.py rename to src/diffusers/pipelines/prx/pipeline_output.py diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/prx/pipeline_prx.py similarity index 96% rename from src/diffusers/pipelines/photon/pipeline_photon.py rename to src/diffusers/pipelines/prx/pipeline_prx.py index 4a10899ede61..a3bd3e6b45e7 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -30,9 +30,9 @@ from diffusers.image_processor import PixArtImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( logging, @@ -73,7 +73,7 @@ class TextPreprocessor: - """Text preprocessing utility for PhotonPipeline.""" + """Text preprocessing utility for PRXPipeline.""" def __init__(self): """Initialize text preprocessor.""" @@ -203,34 +203,34 @@ def clean_text(self, text: str) -> str: Examples: ```py >>> import torch - >>> from diffusers import PhotonPipeline + >>> from diffusers import PRXPipeline >>> # 
Load pipeline with from_pretrained - >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft") + >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft") >>> pipe.to("cuda") >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] - >>> image.save("photon_output.png") + >>> image.save("prx_output.png") ``` """ -class PhotonPipeline( +class PRXPipeline( DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, ): r""" - Pipeline for text-to-image generation using Photon Transformer. + Pipeline for text-to-image generation using PRX Transformer. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - transformer ([`PhotonTransformer2DModel`]): - The Photon transformer model to denoise the encoded image latents. + transformer ([`PRXTransformer2DModel`]): + The PRX transformer model to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. text_encoder ([`T5GemmaEncoder`]): @@ -248,7 +248,7 @@ class PhotonPipeline( def __init__( self, - transformer: PhotonTransformer2DModel, + transformer: PRXTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, text_encoder: T5GemmaEncoder, tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], @@ -257,9 +257,9 @@ def __init__( ): super().__init__() - if PhotonTransformer2DModel is None: + if PRXTransformer2DModel is None: raise ImportError( - "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." + "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed." ) self.text_preprocessor = TextPreprocessor() @@ -567,7 +567,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple. use_resolution_binning (`bool`, *optional*, defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back @@ -585,9 +585,8 @@ def __call__( Examples: Returns: - [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. 
Set height and width @@ -765,4 +764,4 @@ def __call__( if not return_dict: return (image,) - return PhotonPipelineOutput(images=image) + return PRXPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_prx.py similarity index 90% rename from tests/models/transformers/test_models_transformer_photon.py rename to tests/models/transformers/test_models_transformer_prx.py index f5185245d399..1387625d5ea0 100644 --- a/tests/models/transformers/test_models_transformer_photon.py +++ b/tests/models/transformers/test_models_transformer_prx.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,8 +26,8 @@ enable_full_determinism() -class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = PhotonTransformer2DModel +class PRXTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PRXTransformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -75,7 +75,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_gradient_checkpointing_is_applied(self): - expected_set = {"PhotonTransformer2DModel"} + expected_set = {"PRXTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/photon/__init__.py b/tests/pipelines/prx/__init__.py similarity index 100% rename from tests/pipelines/photon/__init__.py rename to tests/pipelines/prx/__init__.py diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/prx/test_pipeline_prx.py similarity index 90% rename from tests/pipelines/photon/test_pipeline_photon.py rename to tests/pipelines/prx/test_pipeline_prx.py index c29c6ce0b0dd..46c6a5760e22 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/prx/test_pipeline_prx.py @@ -8,8 +8,8 @@ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx.pipeline_prx import PRXPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_transformers_version @@ -22,8 +22,8 @@ reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False, ) -class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = PhotonPipeline +class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PRXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) test_xformers_attention = False @@ -32,16 +32,16 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @classmethod def setUpClass(cls): - # Ensure PhotonPipeline has an _execution_device property expected by __call__ - if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property): + # Ensure PRXPipeline has an 
_execution_device property expected by __call__ + if not isinstance(getattr(PRXPipeline, "_execution_device", None), property): try: - setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) + setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) except Exception: pass def get_dummy_components(self): torch.manual_seed(0) - transformer = PhotonTransformer2DModel( + transformer = PRXTransformer2DModel( patch_size=1, in_channels=4, context_in_dim=8, @@ -129,7 +129,7 @@ def get_dummy_inputs(self, device, seed=0): def test_inference(self): device = "cpu" components = self.get_dummy_components() - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe.to(device) pipe.set_progress_bar_config(disable=None) try: @@ -148,7 +148,7 @@ def test_inference(self): def test_callback_inputs(self): components = self.get_dummy_components() - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe = pipe.to("cpu") pipe.set_progress_bar_config(disable=None) try: @@ -157,7 +157,7 @@ def test_callback_inputs(self): pass self.assertTrue( hasattr(pipe, "_callback_tensor_inputs"), - f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", ) def callback_inputs_subset(pipe, i, t, callback_kwargs): @@ -216,7 +216,7 @@ def to_np_local(tensor): self.assertLess(max(max_diff1, max_diff2), expected_max_diff) def test_inference_with_autoencoder_dc(self): - """Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" + """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" device = "cpu" components = self.get_dummy_components() @@ -248,7 +248,7 @@ def test_inference_with_autoencoder_dc(self): components["vae"] = vae_dc - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe.to(device) pipe.set_progress_bar_config(disable=None) From 163521b5a5b113bfdc61ddf96fa19ecb97a3c6c7 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 21:12:52 +0000 Subject: [PATCH 3/7] Revert .gitignore to state before commit b7fb0fe9d63bf766bbe3c42ac154a043796dd370 --- .gitignore | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index c158ffe1378b..a55026febd5a 100644 --- a/.gitignore +++ b/.gitignore @@ -178,26 +178,4 @@ tags .ruff_cache # wandb -wandb -convert_checkpoints.py -dcae_mirage_generated_image_.png -dcae_prx_generated_image.png -example_usage.py -META_TENSOR_FIX.md -mirage_generated_image__.png -mirage_generated_image_.png -prx_generated_image.png -plan.md -test_existing_checkpoints_with_timestep_change.py -test_timestep_embedding.py -test_updated_checkpoint.png -test_updated_checkpoint.py -testhf.ipynb -update_checkpoint_parameters.py -verify_checkpoint_parameters.py -for_claude/mirage_layers.py -for_claude/mirage.py -for_claude/text_tower.py -for_claude/vae_tower.py -prx_/prx_layers.py -prx_/prx.py +wandb \ No newline at end of file From 8417ff57dd0c14942e0bed87bd49f5d84adc1150 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 15:38:55 +0000 Subject: [PATCH 4/7] rename photon to prx --- .gitignore | 24 ++++++++- docs/source/en/_toctree.yml | 4 +- docs/source/en/api/pipelines/photon.md | 54 +++++++++---------- scripts/convert_photon_to_diffusers.py | 38 ++++++------- 
src/diffusers/__init__.py | 8 +-- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/transformers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/photon/__init__.py | 8 +-- .../pipelines/photon/pipeline_output.py | 4 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../dummy_torch_and_transformers_objects.py | 2 +- 12 files changed, 88 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index a55026febd5a..c158ffe1378b 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,26 @@ tags .ruff_cache # wandb -wandb \ No newline at end of file +wandb +convert_checkpoints.py +dcae_mirage_generated_image_.png +dcae_prx_generated_image.png +example_usage.py +META_TENSOR_FIX.md +mirage_generated_image__.png +mirage_generated_image_.png +prx_generated_image.png +plan.md +test_existing_checkpoints_with_timestep_change.py +test_timestep_embedding.py +test_updated_checkpoint.png +test_updated_checkpoint.py +testhf.ipynb +update_checkpoint_parameters.py +verify_checkpoint_parameters.py +for_claude/mirage_layers.py +for_claude/mirage.py +for_claude/text_tower.py +for_claude/vae_tower.py +prx_/prx_layers.py +prx_/prx.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3abe89437fa5..2e81200c216c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -541,8 +541,8 @@ title: PAG - local: api/pipelines/paint_by_example title: Paint by Example - - local: api/pipelines/photon - title: Photon + - local: api/pipelines/prx + title: PRX - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index 293e05f0fdef..16670f4bfc86 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -12,43 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# Photon +# PRX -Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. +PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. ## Available models -Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. 
+PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | |:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| -| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s +| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 
steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s -Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information. +Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information. ## Loading the pipeline Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. ```py -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.pipelines.prx import PRXPipeline # Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) +pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.to("cuda") prompt = "A front-facing portrait of a lion the golden savanna at sunset." 
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] -image.save("photon_output.png") +image.save("prx_output.png") ``` ### Manual Component Loading @@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant ```py import torch -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.pipelines.prx import PRXPipeline from diffusers.models import AutoencoderKL, AutoencoderDC -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig @@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) # Load transformer -transformer = PhotonTransformer2DModel.from_pretrained( - "checkpoints/photon-512-t2i-sft", +transformer = PRXTransformer2DModel.from_pretrained( + "checkpoints/prx-512-t2i-sft", subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16, @@ -76,7 +76,7 @@ transformer = PhotonTransformer2DModel.from_pretrained( # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "checkpoints/photon-512-t2i-sft", subfolder="scheduler" + "checkpoints/prx-512-t2i-sft", subfolder="scheduler" ) # Load T5Gemma text encoder @@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", quantization_config=quant_config, torch_dtype=torch.bfloat16) -pipe = PhotonPipeline( +pipe = PRXPipeline( transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, @@ -111,21 +111,21 @@ For memory-constrained environments: ```py import torch -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.pipelines.prx import PRXPipeline -pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16) +pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## PhotonPipeline +## PRXPipeline -[[autodoc]] PhotonPipeline +[[autodoc]] PRXPipeline - all - __call__ -## PhotonPipelineOutput +## PRXPipelineOutput -[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput +[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index c66bc314181f..d9bde2f34d56 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to convert Photon checkpoint from original codebase to diffusers format. +Script to convert PRX checkpoint from original codebase to diffusers format. 
""" import argparse @@ -13,15 +13,15 @@ import torch from safetensors.torch import save_file -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon import PhotonPipeline +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx import PRXPipeline DEFAULT_RESOLUTION = 512 @dataclass(frozen=True) -class PhotonBase: +class PRXBase: context_in_dim: int = 2304 hidden_size: int = 1792 mlp_ratio: float = 3.5 @@ -34,22 +34,22 @@ class PhotonBase: @dataclass(frozen=True) -class PhotonFlux(PhotonBase): +class PRXFlux(PRXBase): in_channels: int = 16 patch_size: int = 2 @dataclass(frozen=True) -class PhotonDCAE(PhotonBase): +class PRXDCAE(PRXBase): in_channels: int = 32 patch_size: int = 1 def build_config(vae_type: str) -> Tuple[dict, int]: if vae_type == "flux": - cfg = PhotonFlux() + cfg = PRXFlux() elif vae_type == "dc-ae": - cfg = PhotonDCAE() + cfg = PRXDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'") @@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict: # Key mappings for structural changes mapping = {} - # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention) + # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention) for i in range(depth): # QKV projections moved to attention module mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" @@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth return converted_state_dict -def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: - """Create and load PhotonTransformer2DModel from old checkpoint.""" +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel: + """Create and load PRXTransformer2DModel from old checkpoint.""" print(f"Loading checkpoint from: {checkpoint_path}") @@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config - print("Creating PhotonTransformer2DModel...") - transformer = PhotonTransformer2DModel(**config) + print("Creating PRXTransformer2DModel...") + transformer = PRXTransformer2DModel(**config) # Load state dict print("Loading converted parameters...") @@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str) vae_class = "AutoencoderDC" model_index = { - "_class_name": "PhotonPipeline", + "_class_name": "PRXPipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "default_sample_size": default_image_size, "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["photon", "T5GemmaEncoder"], + "text_encoder": ["prx", "T5GemmaEncoder"], "tokenizer": ["transformers", "GemmaTokenizerFast"], - "transformer": ["diffusers", "PhotonTransformer2DModel"], + "transformer": ["diffusers", "PRXTransformer2DModel"], "vae": ["diffusers", vae_class], } @@ -275,7 +275,7 @@ def main(args): # Verify the pipeline can be loaded try: - pipeline = PhotonPipeline.from_pretrained(args.output_path) + pipeline = PRXPipeline.from_pretrained(args.output_path) print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: 
{type(pipeline.vae).__name__}")
@@ -298,10 +298,10 @@ def main(args):

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
+    parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")

     parser.add_argument(
-        "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
+        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file)"
     )

     parser.add_argument(
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index b7086d2e0c44..47285f37d91a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -232,7 +232,7 @@
         "MultiControlNetModel",
         "OmniGenTransformer2DModel",
         "ParallelConfig",
-        "PhotonTransformer2DModel",
+        "PRXTransformer2DModel",
         "PixArtTransformer2DModel",
         "PriorTransformer",
         "QwenImageControlNetModel",
@@ -516,7 +516,7 @@
         "MusicLDMPipeline",
         "OmniGenPipeline",
         "PaintByExamplePipeline",
-        "PhotonPipeline",
+        "PRXPipeline",
         "PIAPipeline",
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",
@@ -928,7 +928,7 @@
         MultiControlNetModel,
         OmniGenTransformer2DModel,
         ParallelConfig,
-        PhotonTransformer2DModel,
+        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageControlNetModel,
@@ -1182,7 +1182,7 @@
         MusicLDMPipeline,
         OmniGenPipeline,
         PaintByExamplePipeline,
-        PhotonPipeline,
+        PRXPipeline,
         PIAPipeline,
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 2151e602b2e2..6d08c2f2e23b 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -96,7 +96,7 @@
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
-    _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
+    _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
     _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -191,7 +191,7 @@
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
-        PhotonTransformer2DModel,
+        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index ab5311518ba7..d8c3d9b57273 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -32,7 +32,7 @@
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
-    from .transformer_photon import PhotonTransformer2DModel
+    from .transformer_prx import PRXTransformer2DModel
    from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index a44c92a834b2..953f307fe1de 100644
--- a/src/diffusers/pipelines/__init__.py
+++ 
b/src/diffusers/pipelines/__init__.py @@ -144,7 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["photon"] = ["PhotonPipeline"] + _import_structure["prx"] = ["PRXPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -718,7 +718,7 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline - from .photon import PhotonPipeline + from .prx import PRXPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index e21e31d4225f..87aaefbd1368 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -12,7 +12,7 @@ _dummy_objects = {} _additional_imports = {} -_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]} +_import_structure = {"pipeline_output": ["PRXPipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_photon"] = ["PhotonPipeline"] + _import_structure["pipeline_prx"] = ["PRXPipeline"] # Import T5GemmaEncoder for pipeline loading compatibility try: @@ -44,8 +44,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_output import PhotonPipelineOutput - from .pipeline_photon import PhotonPipeline + from .pipeline_output import PRXPipelineOutput + from .pipeline_prx import PRXPipeline else: import sys diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py index d4b0ff462983..ea1bc9bf418a 100644 --- a/src/diffusers/pipelines/photon/pipeline_output.py +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -22,9 +22,9 @@ @dataclass -class PhotonPipelineOutput(BaseOutput): +class PRXPipelineOutput(BaseOutput): """ - Output class for Photon pipelines. + Output class for PRX pipelines. 
Args: images (`List[PIL.Image.Image]` or `np.ndarray`) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d379a5d4a77c..417461ea459b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1098,7 +1098,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PhotonTransformer2DModel(metaclass=DummyObject): +class PRXTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 52c72579cd20..225cae9a1e21 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1847,7 +1847,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PhotonPipeline(metaclass=DummyObject): +class PRXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From c44717b08051d2ecbaf04c5f8c9985874e102e6d Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 20:58:33 +0000 Subject: [PATCH 5/7] rename photon into prx --- docs/source/en/_toctree.yml | 4 +- .../en/api/pipelines/{photon.md => prx.md} | 0 ...ffusers.py => convert_prx_to_diffusers.py} | 0 src/diffusers/__init__.py | 8 ++-- src/diffusers/models/__init__.py | 2 +- ...ansformer_photon.py => transformer_prx.py} | 46 +++++++++---------- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/{photon => prx}/__init__.py | 0 .../{photon => prx}/pipeline_output.py | 0 .../pipeline_prx.py} | 35 +++++++------- ...oton.py => test_models_transformer_prx.py} | 8 ++-- tests/pipelines/{photon => prx}/__init__.py | 0 .../test_pipeline_prx.py} | 26 +++++------ 13 files changed, 65 insertions(+), 66 deletions(-) rename docs/source/en/api/pipelines/{photon.md => prx.md} (100%) rename scripts/{convert_photon_to_diffusers.py => convert_prx_to_diffusers.py} (100%) rename src/diffusers/models/transformers/{transformer_photon.py => transformer_prx.py} (95%) rename src/diffusers/pipelines/{photon => prx}/__init__.py (100%) rename src/diffusers/pipelines/{photon => prx}/pipeline_output.py (100%) rename src/diffusers/pipelines/{photon/pipeline_photon.py => prx/pipeline_prx.py} (96%) rename tests/models/transformers/{test_models_transformer_photon.py => test_models_transformer_prx.py} (90%) rename tests/pipelines/{photon => prx}/__init__.py (100%) rename tests/pipelines/{photon/test_pipeline_photon.py => prx/test_pipeline_prx.py} (90%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2e81200c216c..1e5c4fe5501a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -541,12 +541,12 @@ title: PAG - local: api/pipelines/paint_by_example title: Paint by Example - - local: api/pipelines/prx - title: PRX - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma title: PixArt-Σ + - local: api/pipelines/prx + title: PRX - local: api/pipelines/qwenimage title: QwenImage - local: api/pipelines/sana diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/prx.md similarity index 100% rename from docs/source/en/api/pipelines/photon.md rename to docs/source/en/api/pipelines/prx.md diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_prx_to_diffusers.py similarity index 100% rename from 
scripts/convert_photon_to_diffusers.py rename to scripts/convert_prx_to_diffusers.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 47285f37d91a..d9c6ba9bbf5d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -232,9 +232,9 @@ "MultiControlNetModel", "OmniGenTransformer2DModel", "ParallelConfig", - "PRXTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", + "PRXTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", @@ -516,11 +516,11 @@ "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", - "PRXPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "PRXPipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", @@ -928,9 +928,9 @@ MultiControlNetModel, OmniGenTransformer2DModel, ParallelConfig, - PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + PRXTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, @@ -1182,11 +1182,11 @@ MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, - PRXPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + PRXPipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 6d08c2f2e23b..532db76a09e0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -191,9 +191,9 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, - PRXTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + PRXTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_prx.py similarity index 95% rename from src/diffusers/models/transformers/transformer_photon.py rename to src/diffusers/models/transformers/transformer_prx.py index 6314020c1c74..9b2664b9cb26 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: return xq_out.reshape(*xq.shape).type_as(xq) -class PhotonAttnProcessor2_0: +class PRXAttnProcessor2_0: r""" - Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention + Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn. """ @@ -91,11 +91,11 @@ class PhotonAttnProcessor2_0: def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): - raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") def __call__( self, - attn: "PhotonAttention", + attn: "PRXAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -103,10 +103,10 @@ def __call__( **kwargs, ) -> torch.Tensor: """ - Apply Photon attention using PhotonAttention module. + Apply PRX attention using PRXAttention module. 
Args: - attn: PhotonAttention module containing projection layers + attn: PRXAttention module containing projection layers hidden_states: Image tokens [B, L_img, D] encoder_hidden_states: Text tokens [B, L_txt, D] attention_mask: Boolean mask for text tokens [B, L_txt] @@ -114,7 +114,7 @@ def __call__( """ if encoder_hidden_states is None: - raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") + raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") # Project image tokens to Q, K, V img_qkv = attn.img_qkv_proj(hidden_states) @@ -190,14 +190,14 @@ def __call__( return attn_output -class PhotonAttention(nn.Module, AttentionModuleMixin): +class PRXAttention(nn.Module, AttentionModuleMixin): r""" - Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for - Photon's architecture. + PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for + PRX's architecture. """ - _default_processor_cls = PhotonAttnProcessor2_0 - _available_processors = [PhotonAttnProcessor2_0] + _default_processor_cls = PRXAttnProcessor2_0 + _available_processors = [PRXAttnProcessor2_0] def __init__( self, @@ -251,7 +251,7 @@ def forward( # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class PhotonEmbedND(nn.Module): +class PRXEmbedND(nn.Module): r""" N-dimensional rotary positional embedding. @@ -347,7 +347,7 @@ def forward( return tuple(out[:3]), tuple(out[3:]) -class PhotonBlock(nn.Module): +class PRXBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. @@ -364,7 +364,7 @@ class PhotonBlock(nn.Module): Attributes: img_pre_norm (`nn.LayerNorm`): Pre-normalization applied to image tokens before attention. - attention (`PhotonAttention`): + attention (`PRXAttention`): Multi-head attention module with built-in QKV projections and normalizations for cross-attention between image and text tokens. post_attention_layernorm (`nn.LayerNorm`): @@ -400,15 +400,15 @@ def __init__( # Pre-attention normalization for image tokens self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - # PhotonAttention module with built-in projections and norms - self.attention = PhotonAttention( + # PRXAttention module with built-in projections and norms + self.attention = PRXAttention( query_dim=hidden_size, heads=num_heads, dim_head=self.head_dim, bias=False, out_bias=False, eps=1e-6, - processor=PhotonAttnProcessor2_0(), + processor=PRXAttnProcessor2_0(), ) # mlp @@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) -class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): +class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" Transformer-based 2D model for text to image generation. @@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): txt_in (`nn.Linear`): Projection layer for text conditioning. blocks (`nn.ModuleList`): - Stack of transformer blocks (`PhotonBlock`). + Stack of transformer blocks (`PRXBlock`). final_layer (`LastLayer`): Projection layer mapping hidden tokens back to patch outputs. 
@@ -661,14 +661,14 @@ def __init__( self.hidden_size = hidden_size self.num_heads = num_heads - self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) self.blocks = nn.ModuleList( [ - PhotonBlock( + PRXBlock( self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, @@ -702,7 +702,7 @@ def forward( return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: r""" - Forward pass of the PhotonTransformer2DModel. + Forward pass of the PRXTransformer2DModel. The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 953f307fe1de..ff64958d4699 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -718,9 +718,9 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline - from .prx import PRXPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline + from .prx import PRXPipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/prx/__init__.py similarity index 100% rename from src/diffusers/pipelines/photon/__init__.py rename to src/diffusers/pipelines/prx/__init__.py diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/prx/pipeline_output.py similarity index 100% rename from src/diffusers/pipelines/photon/pipeline_output.py rename to src/diffusers/pipelines/prx/pipeline_output.py diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/prx/pipeline_prx.py similarity index 96% rename from src/diffusers/pipelines/photon/pipeline_photon.py rename to src/diffusers/pipelines/prx/pipeline_prx.py index 4a10899ede61..a3bd3e6b45e7 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -30,9 +30,9 @@ from diffusers.image_processor import PixArtImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( logging, @@ -73,7 +73,7 @@ class TextPreprocessor: - """Text preprocessing utility for PhotonPipeline.""" + """Text preprocessing utility for PRXPipeline.""" def __init__(self): """Initialize text preprocessor.""" @@ -203,34 +203,34 @@ def clean_text(self, text: str) -> str: Examples: ```py >>> import torch - >>> from diffusers import PhotonPipeline + >>> from diffusers import PRXPipeline >>> # 
Load pipeline with from_pretrained
-        >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
+        >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
         >>> pipe.to("cuda")
         >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
         >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
-        >>> image.save("photon_output.png")
+        >>> image.save("prx_output.png")
         ```
 """

-class PhotonPipeline(
+class PRXPipeline(
     DiffusionPipeline,
     LoraLoaderMixin,
     FromSingleFileMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
-    Pipeline for text-to-image generation using Photon Transformer.
+    Pipeline for text-to-image generation using PRX Transformer.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     Args:
-        transformer ([`PhotonTransformer2DModel`]):
-            The Photon transformer model to denoise the encoded image latents.
+        transformer ([`PRXTransformer2DModel`]):
+            The PRX transformer model to denoise the encoded image latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         text_encoder ([`T5GemmaEncoder`]):
@@ -248,7 +248,7 @@ class PhotonPipeline(

     def __init__(
         self,
-        transformer: PhotonTransformer2DModel,
+        transformer: PRXTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
@@ -257,9 +257,9 @@ def __init__(
     ):
         super().__init__()

-        if PhotonTransformer2DModel is None:
+        if PRXTransformer2DModel is None:
             raise ImportError(
-                "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
+                "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
             )

         self.text_preprocessor = TextPreprocessor()
@@ -567,7 +567,7 @@ def __call__(
                 The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):
                 `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
             use_resolution_binning (`bool`, *optional*, defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
@@ -585,9 +585,8 @@ def __call__(
         Examples:

         Returns:
-            [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         # 0.
Set height and width @@ -765,4 +764,4 @@ def __call__( if not return_dict: return (image,) - return PhotonPipelineOutput(images=image) + return PRXPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_prx.py similarity index 90% rename from tests/models/transformers/test_models_transformer_photon.py rename to tests/models/transformers/test_models_transformer_prx.py index f5185245d399..1387625d5ea0 100644 --- a/tests/models/transformers/test_models_transformer_photon.py +++ b/tests/models/transformers/test_models_transformer_prx.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,8 +26,8 @@ enable_full_determinism() -class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = PhotonTransformer2DModel +class PRXTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PRXTransformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -75,7 +75,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_gradient_checkpointing_is_applied(self): - expected_set = {"PhotonTransformer2DModel"} + expected_set = {"PRXTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/photon/__init__.py b/tests/pipelines/prx/__init__.py similarity index 100% rename from tests/pipelines/photon/__init__.py rename to tests/pipelines/prx/__init__.py diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/prx/test_pipeline_prx.py similarity index 90% rename from tests/pipelines/photon/test_pipeline_photon.py rename to tests/pipelines/prx/test_pipeline_prx.py index c29c6ce0b0dd..46c6a5760e22 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/prx/test_pipeline_prx.py @@ -8,8 +8,8 @@ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel -from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx.pipeline_prx import PRXPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_transformers_version @@ -22,8 +22,8 @@ reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False, ) -class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = PhotonPipeline +class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PRXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) test_xformers_attention = False @@ -32,16 +32,16 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @classmethod def setUpClass(cls): - # Ensure PhotonPipeline has an _execution_device property expected by __call__ - if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property): + # Ensure PRXPipeline has an 
_execution_device property expected by __call__ + if not isinstance(getattr(PRXPipeline, "_execution_device", None), property): try: - setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) + setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) except Exception: pass def get_dummy_components(self): torch.manual_seed(0) - transformer = PhotonTransformer2DModel( + transformer = PRXTransformer2DModel( patch_size=1, in_channels=4, context_in_dim=8, @@ -129,7 +129,7 @@ def get_dummy_inputs(self, device, seed=0): def test_inference(self): device = "cpu" components = self.get_dummy_components() - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe.to(device) pipe.set_progress_bar_config(disable=None) try: @@ -148,7 +148,7 @@ def test_inference(self): def test_callback_inputs(self): components = self.get_dummy_components() - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe = pipe.to("cpu") pipe.set_progress_bar_config(disable=None) try: @@ -157,7 +157,7 @@ def test_callback_inputs(self): pass self.assertTrue( hasattr(pipe, "_callback_tensor_inputs"), - f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", ) def callback_inputs_subset(pipe, i, t, callback_kwargs): @@ -216,7 +216,7 @@ def to_np_local(tensor): self.assertLess(max(max_diff1, max_diff2), expected_max_diff) def test_inference_with_autoencoder_dc(self): - """Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" + """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" device = "cpu" components = self.get_dummy_components() @@ -248,7 +248,7 @@ def test_inference_with_autoencoder_dc(self): components["vae"] = vae_dc - pipe = PhotonPipeline(**components) + pipe = PRXPipeline(**components) pipe.to(device) pipe.set_progress_bar_config(disable=None) From 1b34a3c1a8200c860d361f2aebb628cb5fff5319 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 21:12:52 +0000 Subject: [PATCH 6/7] Revert .gitignore to state before commit b7fb0fe9d63bf766bbe3c42ac154a043796dd370 --- .gitignore | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index c158ffe1378b..a55026febd5a 100644 --- a/.gitignore +++ b/.gitignore @@ -178,26 +178,4 @@ tags .ruff_cache # wandb -wandb -convert_checkpoints.py -dcae_mirage_generated_image_.png -dcae_prx_generated_image.png -example_usage.py -META_TENSOR_FIX.md -mirage_generated_image__.png -mirage_generated_image_.png -prx_generated_image.png -plan.md -test_existing_checkpoints_with_timestep_change.py -test_timestep_embedding.py -test_updated_checkpoint.png -test_updated_checkpoint.py -testhf.ipynb -update_checkpoint_parameters.py -verify_checkpoint_parameters.py -for_claude/mirage_layers.py -for_claude/mirage.py -for_claude/text_tower.py -for_claude/vae_tower.py -prx_/prx_layers.py -prx_/prx.py +wandb \ No newline at end of file From 3b80dcd3f482ad2af6e74818ce55c52dd6e0fd3e Mon Sep 17 00:00:00 2001 From: DavidBert Date: Tue, 21 Oct 2025 21:39:28 +0000 Subject: [PATCH 7/7] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 6 +++--- .../utils/dummy_torch_and_transformers_objects.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git 
a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 417461ea459b..ecf2d7957ad4 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1098,7 +1098,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PRXTransformer2DModel(metaclass=DummyObject): +class PixArtTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1113,7 +1113,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PixArtTransformer2DModel(metaclass=DummyObject): +class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1128,7 +1128,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class PriorTransformer(metaclass=DummyObject): +class PRXTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 225cae9a1e21..3a106c1b83a2 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1847,7 +1847,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PRXPipeline(metaclass=DummyObject): +class PIAPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1862,7 +1862,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PIAPipeline(metaclass=DummyObject): +class PixArtAlphaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1877,7 +1877,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PixArtAlphaPipeline(metaclass=DummyObject): +class PixArtSigmaPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1892,7 +1892,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PixArtSigmaPAGPipeline(metaclass=DummyObject): +class PixArtSigmaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1907,7 +1907,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PixArtSigmaPipeline(metaclass=DummyObject): +class PRXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs):