
Conversation


@DavidBert commented on Oct 9, 2025:

This commit adds support for the Photon image generation model:

  • PhotonTransformer2DModel: Core transformer architecture
  • PhotonPipeline: Text-to-image generation pipeline
  • Attention processor updates for Photon-specific attention mechanism
  • Conversion script for loading Photon checkpoints
  • Documentation and tests

Some examples below with the 512 model fine-tuned on the Alchemist dataset and distilled with PAG:

[Sample images: image_10, image_4, image_0, image_1]


print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Author:
I'm not sure about this one: I'm saving the VAE weights even though they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead reference the original ones directly?
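(For reference, loading the already-published VAEs directly from the Hub would look roughly like the sketch below; the repo ids are assumptions based on the public Flux and DC-AE checkpoints, not necessarily the weights Photon was trained with.)

```python
from diffusers import AutoencoderDC, AutoencoderKL


def load_existing_vae(vae_type: str):
    # Hypothetical sketch: reference the already-published VAE weights instead of re-saving them.
    # Repo ids are assumptions based on the public Flux and DC-AE checkpoints.
    if vae_type == "flux":
        return AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae")
    return AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
```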

Member:
For now, it's okay to keep this as is. This way, everything is under the same model repo.

print(f"✓ Saved VAE to {vae_path}")


def download_and_save_text_encoder(output_path: str):
Author:
Same here for the Text Encoder.

print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

@sayakpaul (Member) left a comment:
Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.

Also, it would be great to see some samples of Photon!

@sayakpaul (Member) left a comment:
Thanks! Left a couple more comments. Let's also add the pipeline-level tests.

<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
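(A minimal usage sketch for the docs; the repo id, step count, and guidance scale are taken from the model table further down in this PR, so treat this as an illustration rather than the official example.)

```python
import torch
from diffusers import PhotonPipeline

# Repo id and sampling parameters follow the model table in this PR.
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-256-t2i", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    "a cozy cabin in a snowy forest at dusk, warm light in the windows",
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
image.save("photon_sample.png")
```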
Member:
Cc: @stevhliu for a review on the docs.

return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
Member:
Could we write it in a fashion similar to the existing attention processors (e.g., the Flux one)?

Collaborator:
I second this suggestion; in particular, I think it would be more in line with other diffusers model implementations to reuse the layers defined in Attention, such as to_q/to_k/to_v, instead of defining them in PhotonBlock (e.g. PhotonBlock.img_qkv_proj), and to keep the entire attention implementation in the PhotonAttnProcessor2_0 class.

Attention supports stuff like QK norms and fusing projections, so that could potentially be reused as well. If you need some custom logic not found in Attention, you could potentially add it in there or create a new Attention-style class like Flux does:

class FluxAttention(torch.nn.Module, AttentionModuleMixin):

Author:
I made the change and updated both the conversion script and the checkpoints on the hub.

def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
Member:
We support passing prompt embeddings too in case users want to supply them precomputed:

prompt_embeds: Optional[torch.FloatTensor] = None,
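(The usual diffusers pattern is to skip text encoding whenever embeddings are supplied. A minimal, hypothetical sketch of that branching, with an assumed encode_prompt callable standing in for the pipeline's own method:)

```python
from typing import Callable, List, Optional, Union

import torch


def resolve_prompt_embeds(
    prompt: Optional[Union[str, List[str]]],
    prompt_embeds: Optional[torch.Tensor],
    encode_prompt: Callable[[Union[str, List[str]]], torch.Tensor],
) -> torch.Tensor:
    # Use precomputed embeddings when provided; otherwise encode the prompt.
    if prompt_embeds is not None:
        return prompt_embeds
    if prompt is None:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
    return encode_prompt(prompt)
```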

Comment on lines 484 to 486
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Member:
Prefer this pattern:

height = height or self.default_sample_size * self.vae_scale_factor

Author:
I did it this way because the model works with two different VAEs that have different scale factors.
Is it okay not to make it depend on self.vae_scale_factor? Otherwise it's hard to define a default value.

@sayakpaul (Member), Oct 15, 2025:
Oh good point! I think we could make a small utility function in the pipeline class that determines the default resolution given the VAE that's loaded into it? WDYT?

Author:
Sure, way cleaner! I did it.
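(A minimal sketch of what such a utility could look like; the 8x and 32x scale factors reflect the Flux VAE and DC-AE compression ratios, while the base latent size and the function name are assumptions for illustration.)

```python
from diffusers import AutoencoderDC, AutoencoderKL


def default_sample_size(vae, base_latent_size: int = 32) -> int:
    # Hypothetical helper: pick the default pixel resolution based on the VAE
    # actually loaded into the pipeline.
    if isinstance(vae, AutoencoderDC):
        vae_scale_factor = 32  # DC-AE spatial compression
    elif isinstance(vae, AutoencoderKL):
        vae_scale_factor = 8  # Flux VAE spatial compression
    else:
        raise ValueError(f"Unsupported VAE type: {type(vae)}")
    return base_latent_size * vae_scale_factor
```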

@DavidBert (Author):

Thanks @dg845 and @stevhliu for your last reviews! I updated the PR and hopefully addressed all your suggestions.

@DavidBert requested review from dg845 and stevhliu on October 16, 2025 at 09:51.
@stevhliu (Member) left a comment:
Thanks, docs LGTM

Comment on lines 308 to 310
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
)
Collaborator:
Would it be possible to set a meaningful default argument for checkpoint_path (for example, if the model checkpoint has been open-sourced and is available on e.g. HF hub, we could set it as a default)?

Author:
We have not open-sourced the original code and model weights yet, but plan to do so soon.
Is it okay to update it later once that's done?
What's the common practice here? Store the original weights and corresponding code in a model repo? I don't see a default path in the other conversion scripts.

Collaborator:
Yeah, that's totally ok. I thought other conversion scripts would have it set but you're right that it's usually not the case.

Comment on lines 167 to 170
# Apply scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
)
Collaborator:
Just curious, have you tested Photon with any other attention backends (e.g. Flash Attention, Sage Attention, etc.)? Not a blocker, but if so you could consider refactoring to use dispatch_attention_fn to add support for these backends.

You can look at the Flux attention processor for an example:

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)

See PR #11916 and the attention backend docs for more info.

Author:
Thanks for the suggestion, I tried and it works!
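(A minimal sketch of what the refactored processor could look like, combining the Attention-style to_q/to_k/to_v projections suggested earlier with dispatch_attention_fn; the class name, module attributes, and the absence of RoPE/QK-norm handling are simplifications, not the actual Photon implementation.)

```python
from diffusers.models.attention_dispatch import dispatch_attention_fn


class PhotonAttnProcessorSketch:
    """Hypothetical sketch of routing Photon attention through dispatch_attention_fn."""

    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Reuse the projections defined on the Attention-style module (to_q/to_k/to_v)
        # instead of fused projections on the block.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # Reshape to (batch, seq_len, heads, head_dim), the layout dispatch_attention_fn expects.
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # dispatch_attention_fn selects the configured backend (SDPA, Flash, Sage, ...).
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)
        return attn.to_out[0](hidden_states)
```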

from ..test_pipelines_common import PipelineTesterMixin


class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
Collaborator:
Would it be possible to add a corresponding PhotonPipelineSlowTests class where we test whether inference on a full checkpoint is consistent between diffusers and the original code? You can refer to FluxPipelineSlowTests as a reference:

@nightly
@require_big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"

Member:
Okay to skip it for now IMO since we also don't add it for Qwen.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
Collaborator:
Are these model links expected to be broken for now? I get a 404 for https://huggingface.co/Photoroom/photon-256-t2i-sft currently and see that only the Photoroom/photon-256-t2i model is currently in the Photon collection.

Author:
They were on a private repo. I made it public.

"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PhotonTransformer2DModel",
Collaborator:
Could you also add PhotonPipeline to the main __init__? As an example, here is how FluxPipeline is added:

"FluxPipeline",

FluxPipeline,

Also, could you add PhotonTransformer2DModel to the TYPE_CHECKING section of __init__? Here is how FluxTransformer2DModel is added:

FluxTransformer2DModel,

Author:
Done!

Collaborator:
I see that PhotonPipeline has been added in both places, but PhotonTransformer2DModel is still only added to the _import_structure part of the __init__ file. Could you add it to the other (TYPE_CHECKING) section as well? See e.g. FluxTransformer2DModel:

_import_structure:

"FluxTransformer2DModel",

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

FluxTransformer2DModel,
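(A minimal sketch of the corresponding Photon entries; this is a fragment of diffusers/__init__.py, where _import_structure, TYPE_CHECKING, and DIFFUSERS_SLOW_IMPORT are already defined, and the keys simply mirror the Flux entries referenced above.)

```python
# Fragment of diffusers/__init__.py (hypothetical sketch mirroring the Flux entries).
_import_structure["models"].append("PhotonTransformer2DModel")
_import_structure["pipelines"].append("PhotonPipeline")

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .models import PhotonTransformer2DModel
    from .pipelines import PhotonPipeline
```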

@dg845 (Collaborator) left a comment:
Thanks for the changes! The PR is close to merge; I think the most important things left are to fix the imports (e.g. #12456 (comment)) and make the other changes needed to get the CI green :).

Comment on lines 72 to 92
encoder_params = dict(
vocab_size=tokenizer.vocab_size,
hidden_size=8,
intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=1,
head_dim=4,
max_position_embeddings=64,
layer_types=["full_attention"],
attention_bias=False,
attention_dropout=0.0,
dropout_rate=0.0,
hidden_activation="gelu_pytorch_tanh",
rms_norm_eps=1e-06,
attn_logit_softcapping=50.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=4,
rope_theta=10000.0,
sliding_window=4096,
)
Collaborator:
Suggested change
encoder_params = dict(
vocab_size=tokenizer.vocab_size,
hidden_size=8,
intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=1,
head_dim=4,
max_position_embeddings=64,
layer_types=["full_attention"],
attention_bias=False,
attention_dropout=0.0,
dropout_rate=0.0,
hidden_activation="gelu_pytorch_tanh",
rms_norm_eps=1e-06,
attn_logit_softcapping=50.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=4,
rope_theta=10000.0,
sliding_window=4096,
)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}

make style/make quality complain about the dict(...) call here, and I think it will be happier if a dict literal is used instead.

@DavidBert (Author), Oct 17, 2025:

Hi @dg845! Thanks for your new review.
I addressed all your new comments except for the one about a default path, because we do not currently have an open-source implementation of our original model.
I also prepared a second PR based on this one where we rename Photon (a name already used in the community) to PRX.
I did not include it here to keep this PR easier to review, but I can if you prefer.
Have a nice weekend!

Collaborator:
I think it would be easier to merge this PR first, then do the renaming as a follow-up PR. CC @sayakpaul

timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

return {
"image_latent": image_latent,
Collaborator:
Suggested change
"image_latent": image_latent,
"hidden_states": image_latent,

To be consistent with suggested naming change in #12456 (comment)

return {
"image_latent": image_latent,
"timestep": timestep,
"cross_attn_conditioning": cross_attn_conditioning,
Collaborator:
Suggested change
"cross_attn_conditioning": cross_attn_conditioning,
"encoder_hidden_states": cross_attn_conditioning,

To be consistent with suggested naming change in #12456 (comment)

Comment on lines 707 to 708
micro_conditioning (`torch.Tensor`):
Extra conditioning vector (currently unused, reserved for future use).
@dg845 (Collaborator), Oct 17, 2025:
Was removing micro_conditioning here (in bef0845) intentional? I think it would be fine to retain it, and the transformer tests (specifically PhotonTransformerTests.prepare_dummy_input) also use this argument.

Author:
Yes, it was intentional; I removed it from the tests too.


class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PhotonTransformer2DModel
main_input_name = "image_latent"
Collaborator:
Suggested change
main_input_name = "image_latent"
main_input_name = "hidden_states"

To be consistent with the naming change suggested in #12456 (comment)

@dg845 (Collaborator) left a comment:

Thanks! Can you confirm that the tests are working as expected after the new changes?

@DavidBert (Author):

> Thanks! Can you confirm that the tests are working as expected after the new changes?

Very sorry, I forgot to verify these tests.
Both test_pipeline_photon.py and test_models_transformer_photon.py are working now.
