6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
@@ -366,6 +366,8 @@
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/qwenimage_transformer2d
title: QwenImageTransformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/sd3_transformer2d
@@ -418,6 +420,8 @@
title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/autoencoderkl_qwenimage
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/consistency_decoder_vae
@@ -554,6 +558,8 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
35 changes: 35 additions & 0 deletions docs/source/en/api/models/autoencoderkl_qwenimage.md
@@ -0,0 +1,35 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLQwenImage

The model can be loaded with the following code snippet.

```python
from diffusers import AutoencoderKLQwenImage

vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")
```
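
A minimal encode/decode round trip is sketched below. The 5D `(batch, channels, frames, height, width)` input layout, the 16-channel latent space, and the 8x spatial downsampling are assumptions inferred from the Wan-style architecture and the `latents_mean`/`latents_std` defaults elsewhere in this PR, not guaranteed API behavior.

```python
import torch

from diffusers import AutoencoderKLQwenImage

vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")

# Hypothetical input: a single RGB frame as a 5D (batch, channels, frames, height, width) tensor.
image = torch.randn(1, 3, 1, 256, 256)

with torch.no_grad():
    # `encode` returns an AutoencoderKLOutput; its `latent_dist` can be sampled to get latents.
    latents = vae.encode(image).latent_dist.sample()
    # `decode` returns a DecoderOutput; its `sample` attribute holds the reconstruction.
    reconstruction = vae.decode(latents).sample
```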

## AutoencoderKLQwenImage

[[autodoc]] AutoencoderKLQwenImage
- decode
- encode
- all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
28 changes: 28 additions & 0 deletions docs/source/en/api/models/qwenimage_transformer2d.md
@@ -0,0 +1,28 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# QwenImageTransformer2DModel

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import QwenImageTransformer2DModel

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16
)
```
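
A transformer loaded this way can be handed directly to the pipeline, following the standard diffusers component-override pattern; the checkpoint id mirrors the snippet above and is an assumption until the weights are published.

```python
import torch

from diffusers import QwenImagePipeline, QwenImageTransformer2DModel

# Load the transformer on its own, e.g. in bfloat16.
transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Pass it to the pipeline so `from_pretrained` reuses this instance instead of loading its own copy.
pipe = QwenImagePipeline.from_pretrained(
    "Qwen/QwenImage-20B", transformer=transformer, torch_dtype=torch.bfloat16
)
```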

## QwenImageTransformer2DModel

[[autodoc]] QwenImageTransformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
33 changes: 33 additions & 0 deletions docs/source/en/api/pipelines/qwenimage.md
@@ -0,0 +1,33 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# QwenImage

<!-- TODO: update this section when model is out -->

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
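
Until the checkpoint is released, the snippet below is only a provisional sketch of the usual diffusers text-to-image flow: the repository id mirrors the model docs in this PR and the call arguments beyond `prompt` assume the standard pipeline interface, so they may change once the model is out.

```python
import torch

from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A serene mountain lake at sunrise, ultra detailed"
# Standard text-to-image call; argument names beyond `prompt` are assumptions.
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("qwenimage.png")
```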

## QwenImagePipeline

[[autodoc]] QwenImagePipeline
- all
- __call__

## QwenImagePipelineOutput

[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
40 changes: 4 additions & 36 deletions src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -668,6 +668,7 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):

_supports_gradient_checkpointing = False

# fmt: off
@register_to_config
def __init__(
self,
@@ -678,43 +679,10 @@ def __init__(
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
],
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
],
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
) -> None:
# fmt: on
super().__init__()

self.z_dim = z_dim
28 changes: 11 additions & 17 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -140,7 +140,7 @@ def apply_rotary_emb_qwen(


class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
def __init__(self, embedding_dim):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
@@ -473,8 +473,6 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
joint_attention_dim (`int`, defaults to `3584`):
The number of dimensions to use for the joint attention (embedding/channel dimension of
`encoder_hidden_states`).
pooled_projection_dim (`int`, defaults to `768`):
The number of dimensions to use for the pooled projection.
guidance_embeds (`bool`, defaults to `False`):
Whether to use guidance embeddings for guidance-distilled variant of the model.
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -495,8 +493,7 @@ def __init__(
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
guidance_embeds: bool = False, # TODO: this should probably be removed
Collaborator review comment: cc @naykun can you confirm if we need this config?

axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
@@ -505,9 +502,7 @@

self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

Expand Down Expand Up @@ -538,10 +533,9 @@ def forward(
timestep: torch.LongTensor = None,
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance: torch.Tensor = None, # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`QwenImageTransformer2DModel`] forward method.
@@ -555,7 +549,7 @@
Mask of the input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -567,17 +561,17 @@
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
@@ -617,7 +611,7 @@ def forward(
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
joint_attention_kwargs=attention_kwargs,
)

# Use only the image part (hidden_states) from the dual-stream blocks