diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 14dbfe3ea1d3..6ddb2765d9a8 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -366,6 +366,8 @@
       title: LatteTransformer3DModel
     - local: api/models/ltx_video_transformer3d
       title: LTXVideoTransformer3DModel
+    - local: api/models/lumina2_accessory_transformer2d
+      title: Lumina2AccessoryTransformer2DModel
     - local: api/models/lumina2_transformer2d
       title: Lumina2Transformer2DModel
     - local: api/models/lumina_nextdit2d
diff --git a/docs/source/en/api/models/lumina2_accessory_transformer2d.md b/docs/source/en/api/models/lumina2_accessory_transformer2d.md
new file mode 100644
index 000000000000..49aaf97d9906
--- /dev/null
+++ b/docs/source/en/api/models/lumina2_accessory_transformer2d.md
@@ -0,0 +1,31 @@
+
+
+# Lumina2AccessoryTransformer2DModel
+
+A Diffusion Transformer model for 2D data from [Lumina-Accessory](https://github.com/Alpha-VLLM/Lumina-Accessory) by Alpha-VLLM.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import Lumina2AccessoryTransformer2DModel
+
+ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Accessory/blob/main/consolidated.00-of-01.pth"
+transformer = Lumina2AccessoryTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+```
+
+## Lumina2AccessoryTransformer2DModel
+
+[[autodoc]] Lumina2AccessoryTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md
index 092d7cde2ebb..fac2055de2fd 100644
--- a/docs/source/en/api/pipelines/lumina2.md
+++ b/docs/source/en/api/pipelines/lumina2.md
@@ -80,8 +80,53 @@ image = pipe(
 image.save("lumina-gguf.png")
 ```
 
+## Lumina Accessory
+
+Lumina-Accessory is a multi-task instruction fine-tuning framework designed for the Lumina series. The official repository is [Alpha-VLLM/Lumina-Accessory](https://github.com/Alpha-VLLM/Lumina-Accessory).
+
+```python
+import torch
+from diffusers import Lumina2AccessoryPipeline, Lumina2AccessoryTransformer2DModel
+from diffusers.utils import load_image
+
+ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Accessory/blob/main/consolidated.00-of-01.pth"
+transformer = Lumina2AccessoryTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+pipe = Lumina2AccessoryPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
+)
+
+# Enable memory optimizations.
+pipe.enable_model_cpu_offload()
+
+img = load_image("https://github.com/Alpha-VLLM/Lumina-Accessory/blob/main/examples/case_1_condition.jpg?raw=true")
+prompt = "A classical oil painting of a young woman dressed in a modern DARK BLACK leather jacket."
+system_prompt = "You are an assistant designed to generate superior images with the highest degree of image-text alignment based on textual prompts and a partially masked image."
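+# The system prompt is optional; if it is omitted, Lumina2AccessoryPipeline falls back to its built-in default system prompt.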
+image = pipe(
+    image=img,
+    prompt=prompt,
+    system_prompt=system_prompt,
+    width=img.size[0],
+    height=img.size[1],
+    negative_prompt=" ",
+    num_inference_steps=25,
+    num_images_per_prompt=1,
+    guidance_scale=4.0,
+    cfg_trunc_ratio=1.0,
+    cfg_normalization=True,
+    generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("lumina2_accessory_image_infilling.png")
+```
+
 ## Lumina2Pipeline
 
 [[autodoc]] Lumina2Pipeline
   - all
   - __call__
+
+
+## Lumina2AccessoryPipeline
+
+[[autodoc]] Lumina2AccessoryPipeline
+  - all
+  - __call__
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 167d39c6e8df..1d7794aa5df8 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -221,6 +221,7 @@
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LTXVideoTransformer3DModel",
+        "Lumina2AccessoryTransformer2DModel",
         "Lumina2Transformer2DModel",
         "LuminaNextDiT2DModel",
         "MochiTransformer3DModel",
@@ -496,6 +497,7 @@
         "LTXLatentUpsamplePipeline",
         "LTXPipeline",
         "LucyEditPipeline",
+        "Lumina2AccessoryPipeline",
         "Lumina2Pipeline",
         "Lumina2Text2ImgPipeline",
         "LuminaPipeline",
@@ -906,6 +908,7 @@
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2AccessoryTransformer2DModel,
         Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
@@ -1151,6 +1154,7 @@
         LTXLatentUpsamplePipeline,
         LTXPipeline,
         LucyEditPipeline,
+        Lumina2AccessoryPipeline,
         Lumina2Pipeline,
         Lumina2Text2ImgPipeline,
         LuminaPipeline,
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b53647d47630..8ba1def24265 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -134,6 +134,10 @@
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "Lumina2AccessoryTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "SanaTransformer2DModel": {
         "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
         "default_subfolder": "transformer",
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 49ac2a1c56fd..98b0de34debd 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -92,6 +92,7 @@
     _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
+    _import_structure["transformers.transformer_lumina2_accessory"] = ["Lumina2AccessoryTransformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
     _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
@@ -182,6 +183,7 @@
         HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2AccessoryTransformer2DModel,
         Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index b60f0636e6dc..a65a38954784 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -29,6 +29,7 @@
     from .transformer_hunyuan_video_framepack import
HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_lumina2_accessory import Lumina2AccessoryTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_lumina2_accessory.py b/src/diffusers/models/transformers/transformer_lumina2_accessory.py new file mode 100644 index 000000000000..653825f20d9c --- /dev/null +++ b/src/diffusers/models/transformers/transformer_lumina2_accessory.py @@ -0,0 +1,623 @@ +# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.modeling_utils import ModelMixin +from ...models.normalization import LuminaLayerNormContinuous, RMSNorm +from ...models.transformers.transformer_lumina2 import ( + Lumina2AttnProcessor2_0, + LuminaFeedForward, + LuminaRMSNormZero, +) +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Lumina2AccessoryCombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + cap_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True) + ) + + def forward( + self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).type_as(hidden_states) + cond_timestep_proj = self.time_proj(timestep * 0 + 1).type_as(hidden_states) + time_embed = self.timestep_embedder(timestep_proj) + cond_time_embed = self.timestep_embedder(cond_timestep_proj) + caption_embed = self.caption_embedder(encoder_hidden_states) + return time_embed, cond_time_embed, caption_embed + + +class Lumina2AccessoryRotaryPosEmbed(nn.Module): + def __init__( + self, + theta: int, + axes_dim: 
List[int],
+        axes_lens: List[int] = (300, 512, 512),
+        patch_size: int = 2,
+    ):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.axes_lens = axes_lens
+        self.patch_size = patch_size
+
+        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
+
+    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
+        freqs_cis = []
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
+            freqs_cis.append(emb)
+        return freqs_cis
+
+    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
+        result = []
+        for i in range(len(self.axes_dim)):
+            freqs = self.freqs_cis[i].to(ids.device)
+            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+        return torch.cat(result, dim=-1).to(device)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cond_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cond_position_type: Literal["aligned", "offset"],
+    ):
+        batch_size, channels, height, width = hidden_states.shape
+        _, _, cond_height, cond_width = cond_hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        post_patch_cond_height, post_patch_cond_width = cond_height // p, cond_width // p
+        image_seq_len = post_patch_height * post_patch_width
+        cond_image_seq_len = post_patch_cond_height * post_patch_cond_width
+        device = hidden_states.device
+
+        encoder_seq_len = attention_mask.shape[1]
+        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+        seq_lengths = [cap_seq_len + image_seq_len + cond_image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
+
+        # Create position IDs
+        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+
+            # add condition image position ids
+            position_ids[i, cap_seq_len : cap_seq_len + cond_image_seq_len, 0] = cap_seq_len
+            cond_row_ids = (
+                torch.arange(post_patch_cond_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_cond_width)
+                .flatten()
+            )
+            cond_col_ids = (
+                torch.arange(post_patch_cond_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_cond_height, 1)
+                .flatten()
+            )
+            if cond_position_type == "aligned":
+                position_ids[i, cap_seq_len : cap_seq_len + cond_image_seq_len, 1] = cond_row_ids
+                position_ids[i, cap_seq_len : cap_seq_len + cond_image_seq_len, 2] = cond_col_ids
+            elif cond_position_type == "offset":
+                position_ids[i, cap_seq_len : cap_seq_len + cond_image_seq_len, 1] = cond_row_ids + post_patch_height
+                position_ids[i, cap_seq_len : cap_seq_len + cond_image_seq_len, 2] = cond_col_ids + post_patch_width
+            else:
+                raise ValueError(
+                    f"Unknown cond_position_type: {cond_position_type}, must be one of ['aligned', 'offset']"
+                )
+
+            # add image position ids
+            position_ids[i, cap_seq_len + cond_image_seq_len : seq_len, 0] = cap_seq_len + 1
+
+            row_ids = (
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1,
1) + .repeat(1, post_patch_width) + .flatten() + ) + col_ids = ( + torch.arange(post_patch_width, dtype=torch.int32, device=device) + .view(1, -1) + .repeat(post_patch_height, 1) + .flatten() + ) + position_ids[i, cap_seq_len + cond_image_seq_len : seq_len, 1] = row_ids + position_ids[i, cap_seq_len + cond_image_seq_len : seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self._get_freqs_cis(position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + cond_freqs_cis = torch.zeros( + batch_size, cond_image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + img_freqs_cis = torch.zeros( + batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + cond_freqs_cis[i, :cond_image_seq_len] = freqs_cis[i, cap_seq_len : cap_seq_len + cond_image_seq_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len + cond_image_seq_len : seq_len] + + # image patch embeddings + hidden_states = ( + hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) + ) + + cond_hidden_states = ( + cond_hidden_states.view(batch_size, channels, post_patch_cond_height, p, post_patch_cond_width, p) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) + ) + + return ( + hidden_states, + cond_hidden_states, + cap_freqs_cis, + img_freqs_cis, + cond_freqs_cis, + freqs_cis, + l_effective_cap_len, + seq_lengths, + ) + + +class Lumina2AccessoryTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=Lumina2AttnProcessor2_0(), + ) + + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=True, + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + temb_cond: Optional[torch.Tensor] = None, + cond_seq_len: Optional[int] = None, + encoder_seq_lengths: Optional[List[int]] = None, + ) -> torch.Tensor: + if self.modulation: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if temb_cond is not None: + assert cond_seq_len is not None and encoder_seq_lengths is not None, ( + "cond_seq_len and encoder_seq_lengths must be provided with temb_cond" + ) + + norm_hidden_states_cond, gate_msa_cond, scale_mlp_cond, gate_mlp_cond = self.norm1( + 
hidden_states, temb_cond
+                )
+
+                _, seq_len, _ = hidden_states.shape
+                device = hidden_states.device
+                seq_indices = torch.arange(seq_len, device=device)
+                start_indices = torch.tensor(encoder_seq_lengths, device=device, dtype=torch.long)
+                end_indices = start_indices + cond_seq_len
+                cond_mask = (seq_indices >= start_indices.unsqueeze(1)) & (seq_indices < end_indices.unsqueeze(1))
+                cond_mask = cond_mask.unsqueeze(-1)
+
+                norm_hidden_states = torch.where(cond_mask, norm_hidden_states_cond, norm_hidden_states)
+
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            norm_attention = self.norm2(attn_output)
+
+            gate_norm_attention = gate_msa.unsqueeze(1).tanh() * norm_attention
+            if temb_cond is not None:
+                gate_norm_attention_cond = gate_msa_cond.unsqueeze(1).tanh() * norm_attention
+                gate_norm_attention = torch.where(cond_mask, gate_norm_attention_cond, gate_norm_attention)
+
+            hidden_states = hidden_states + gate_norm_attention
+
+            norm_hidden_states = self.ffn_norm1(hidden_states)
+
+            modulated_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1))
+            if temb_cond is not None:
+                modulated_hidden_states_cond = norm_hidden_states * (1 + scale_mlp_cond.unsqueeze(1))
+                modulated_hidden_states = torch.where(cond_mask, modulated_hidden_states_cond, modulated_hidden_states)
+
+            ffn_out = self.feed_forward(modulated_hidden_states)
+            norm_ffn = self.ffn_norm2(ffn_out)
+
+            gate_norm_mlp = gate_mlp.unsqueeze(1).tanh() * norm_ffn
+            if temb_cond is not None:
+                gate_norm_mlp_cond = gate_mlp_cond.unsqueeze(1).tanh() * norm_ffn
+                gate_norm_mlp = torch.where(cond_mask, gate_norm_mlp_cond, gate_norm_mlp)
+
+            hidden_states = hidden_states + gate_norm_mlp
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            hidden_states = hidden_states + self.norm2(attn_output)
+            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+        return hidden_states
+
+
+class Lumina2AccessoryTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    Lumina2AccessoryTransformer2DModel: Diffusion model with a Transformer backbone for Lumina-Accessory.
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`, *optional*, defaults to 2):
+            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+        in_channels (`int`, *optional*, defaults to 16):
+            The number of input channels for the model. Typically, this matches the number of channels in the input
+            images.
+        hidden_size (`int`, *optional*, defaults to 2304):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        num_layers (`int`, *optional*, defaults to 26):
+            The number of layers in the model. This defines the depth of the neural network.
+        num_attention_heads (`int`, *optional*, defaults to 24):
+            The number of attention heads in each attention layer. This parameter specifies how many separate attention
+            mechanisms are used.
+ num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. + norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Lumina2AccessoryTransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + scaling_factor: float = 1.0, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + cap_feat_dim: int = 1024, + ) -> None: + super().__init__() + self.out_channels = out_channels or in_channels + + # 1. Positional, patch & conditional embeddings + self.rope_embedder = Lumina2AccessoryRotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size) + + self.cond_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size) + + self.time_caption_embed = Lumina2AccessoryCombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps + ) + + # 2. Noise, image condition and context refinement blocks + self.noise_refiner = nn.ModuleList( + [ + Lumina2AccessoryTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.cond_refiner = nn.ModuleList( + [ + Lumina2AccessoryTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + Lumina2AccessoryTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + Lumina2AccessoryTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + cond_hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + cond_position_type: Literal["aligned", "offset"] = "aligned", + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. Condition, positional & patch embedding + batch_size, _, height, width = hidden_states.shape + + temb, temb_cond, encoder_hidden_states = self.time_caption_embed( + hidden_states, timestep, encoder_hidden_states + ) + + ( + hidden_states, + cond_hidden_states, + context_rotary_emb, + noise_rotary_emb, + cond_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder(hidden_states, cond_hidden_states, encoder_attention_mask, cond_position_type) + + hidden_states = self.x_embedder(hidden_states) + + cond_hidden_states = self.cond_embedder(cond_hidden_states) + + # 2. Context & noise refinement + for layer in self.context_refiner: + encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb) + + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, None, noise_rotary_emb, temb) + + for layer in self.cond_refiner: + cond_hidden_states = layer(cond_hidden_states, None, cond_rotary_emb, temb_cond) + + # 3. Joint Transformer blocks + max_seq_len = max(seq_lengths) + use_mask = len(set(seq_lengths)) > 1 + + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + cond_seq_len = cond_hidden_states.shape[1] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len : encoder_seq_len + cond_seq_len] = cond_hidden_states[i] + joint_hidden_states[i, encoder_seq_len + cond_seq_len : seq_len] = hidden_states[i] + + hidden_states = joint_hidden_states + + for layer in self.layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, + hidden_states, + attention_mask if use_mask else None, + rotary_emb, + temb, + temb_cond, + cond_seq_len, + encoder_seq_lengths, + ) + else: + hidden_states = layer( + hidden_states, + attention_mask if use_mask else None, + rotary_emb, + temb, + temb_cond, + cond_seq_len, + encoder_seq_lengths, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + + # 5. 
Unpatchify + p = self.config.patch_size + output = [] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + output.append( + hidden_states[i][encoder_seq_len + cond_seq_len : seq_len] + .view(height // p, width // p, p, p, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 17f3fc909e4d..2bc529c0b989 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -284,7 +284,7 @@ "LTXLatentUpsamplePipeline", ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] - _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] + _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline", "Lumina2AccessoryPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] _import_structure["marigold"].extend( [ @@ -685,7 +685,7 @@ from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline - from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .lumina2 import Lumina2AccessoryPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py index b1d6bfeb0d58..f59afe374e2e 100644 --- a/src/diffusers/pipelines/lumina2/__init__.py +++ b/src/diffusers/pipelines/lumina2/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] + _import_structure["pipeline_lumina2_accessory"] = ["Lumina2AccessoryPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .pipeline_lumina2_accessory import Lumina2AccessoryPipeline else: import sys diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2_accessory.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2_accessory.py new file mode 100644 index 000000000000..55f94cbca627 --- /dev/null +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2_accessory.py @@ -0,0 +1,870 @@ +# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import Lumina2LoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers.transformer_lumina2_accessory import Lumina2AccessoryTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm # type: ignore + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ``` + >>> import torch + >>> from diffusers import Lumina2AccessoryPipeline, Lumina2AccessoryTransformer2DModel + >>> from diffusers.utils import load_image + + >>> ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Accessory/blob/main/consolidated.00-of-01.pth" + >>> transformer = Lumina2AccessoryTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) + >>> pipe = Lumina2AccessoryPipeline.from_pretrained( + "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16 + ) + + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> img = load_image("https://github.com/Alpha-VLLM/Lumina-Accessory/blob/main/examples/case_1_condition.jpg?raw=true") + >>> prompt = "A classical oil painting of a young woman dressed in a modern DARK BLACK leather jacket." + >>> system_prompt = "You are an assistant designed to generate superior images with the highest degree of image-text alignment based on textual prompts and a partially masked image." + >>> image = pipe( + image=img, + prompt=prompt, + system_prompt=system_prompt, + width=img.size[0], + height=img.size[1], + negative_prompt=" ", + num_inference_steps=25, + num_images_per_prompt=1, + guidance_scale=4.0, + cfg_trunc_ratio=1.0, + cfg_normalization=True, + generator=torch.Generator().manual_seed(42), + ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Lumina2AccessoryPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): + r""" + The Lumina Accessory pipeline uses Lumina-Image-2.0 for text-to-image and image-to-image generation. + + Reference https://github.com/Alpha-VLLM/Lumina-Accessory + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Gemma2PreTrainedModel`]): + Frozen Gemma2 text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. 
+ transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + transformer: Lumina2AccessoryTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Gemma2PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts." + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, 
+ prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + system_prompt: Optional[str] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if system_prompt is None: + system_prompt = self.system_prompt + if prompt is not None: + prompt = [system_prompt + " " + p for p in prompt] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + num_inference_steps: int = 30, + guidance_scale: float = 4.0, + negative_prompt: Union[str, List[str]] = None, + sigmas: List[float] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + system_prompt: Optional[str] = None, + cfg_trunc_ratio: float = 1.0, + cfg_normalization: bool = True, + max_sequence_length: int = 256, + cond_position_type: Literal["aligned", "offset"] = "aligned", + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ system_prompt (`str`, *optional*):
+ The system prompt to use for the image generation.
+ cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`):
+ The fraction of the denoising steps over which classifier-free guidance (and its normalization) is
+ applied; for the remaining steps only the conditional prediction is used.
+ cfg_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to apply normalization-based guidance scale.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use with the `prompt`.
+ cond_position_type (`str`, *optional*, defaults to `"aligned"`): + The position type to use for the conditional embeddings. Can be `"aligned"` or `"offset"`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt / Preprocess input image + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + system_prompt=system_prompt, + ) + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + image = self.image_processor.preprocess(image, image_height, image_width) + + # 4. Prepare latents. + num_channels_latents = self.transformer.config.in_channels + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. 
Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # compute whether to apply classifier-free truncation on this timestep
+ do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
+ current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latents.shape[0])
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latents,
+ cond_hidden_states=image_latents,
+ timestep=current_timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ cond_position_type=cond_position_type,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+
+ # perform normalization-based guidance scale on a truncated timestep interval
+ if self.do_classifier_free_guidance and not do_classifier_free_truncation:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ cond_hidden_states=image_latents,
+ timestep=current_timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ cond_position_type=cond_position_type,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ # apply normalization after classifier-free guidance
+ if cfg_normalization:
+ cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
+ noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_pred = noise_pred * (cond_norm / noise_norm)
+ else:
+ noise_pred = noise_pred_cond
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ noise_pred = -noise_pred
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g.
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bbb971249604..64044eba44cf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -933,6 +933,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Lumina2AccessoryTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Lumina2Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e29be174f02e..0b0ac9f7e066 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1607,6 +1607,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Lumina2AccessoryPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Lumina2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_lumina2_accessory.py b/tests/models/transformers/test_models_transformer_lumina2_accessory.py new file mode 100644 index 000000000000..67799391fd78 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_lumina2_accessory.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import Lumina2AccessoryTransformer2DModel + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class Lumina2AccessoryTransformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = Lumina2AccessoryTransformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 # N + num_channels = 4 # C + height = width = 16 # H, W + embedding_dim = 32 # D + sequence_length = 16 # L + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + cond_hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.rand(size=(batch_size,)).to(torch_device) + attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device) + + return { + "hidden_states": hidden_states, + "cond_hidden_states": cond_hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": attention_mask, + } + + @property + def input_shape(self): + return (4, 16, 16) + + @property + def output_shape(self): + return (4, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 16, + "patch_size": 2, + "in_channels": 4, + "hidden_size": 24, + "num_layers": 2, + "num_refiner_layers": 1, + "num_attention_heads": 3, + "num_kv_heads": 1, + "multiple_of": 2, + "ffn_dim_multiplier": None, + "norm_eps": 1e-5, + "scaling_factor": 1.0, + "axes_dim_rope": (4, 2, 2), + "axes_lens": (128, 128, 128), + "cap_feat_dim": 32, + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Lumina2AccessoryTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2_accessory.py b/tests/pipelines/lumina2/test_pipeline_lumina2_accessory.py new file mode 100644 index 000000000000..91558eb68908 --- /dev/null +++ b/tests/pipelines/lumina2/test_pipeline_lumina2_accessory.py @@ -0,0 +1,177 @@ +import unittest + +import torch +from PIL import Image +from transformers import AutoTokenizer, Gemma2Config, Gemma2Model + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + Lumina2AccessoryPipeline, + Lumina2AccessoryTransformer2DModel, +) + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Lumina2AccessoryPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = Lumina2AccessoryPipeline + params = frozenset( + [ + "prompt", + "image", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + 
batch_params = frozenset( + [ + "prompt", + "image", + "negative_prompt", + ] + ) + image_params = frozenset(["image"]) + image_latents_params = frozenset(["latents"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = Lumina2AccessoryTransformer2DModel( + sample_size=4, + patch_size=2, + in_channels=4, + hidden_size=8, + num_layers=2, + num_attention_heads=1, + num_kv_heads=1, + multiple_of=16, + ffn_dim_multiplier=None, + norm_eps=1e-5, + scaling_factor=1.0, + axes_dim_rope=[4, 2, 2], + cap_feat_dim=8, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + torch.manual_seed(0) + config = Gemma2Config( + head_dim=4, + hidden_size=8, + intermediate_size=8, + num_attention_heads=2, + num_hidden_layers=2, + num_key_value_heads=2, + sliding_window=2, + ) + text_encoder = Gemma2Model(config) + + components = { + "transformer": transformer, + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "negative_prompt": "bad quality", + "image": Image.new("RGB", (32, 32)), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "np", + } + return inputs + + def test_lumina2_accessory_batch_inputs(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + + inputs = self.get_dummy_inputs(device=torch_device) + + inputs["prompt"] = ["A squirrel", "A cat"] + inputs["image"] = [Image.new("RGB", (32, 32)), Image.new("RGB", (32, 32))] + + output = pipe(**inputs) + assert len(output.images) == 2 + + def test_lumina2_accessory_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_lumina2_accessory_guidance_scale_effect(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + + inputs = self.get_dummy_inputs(device=torch_device) + # run with default guidance_scale + output1 = pipe(**inputs) + + # run with zero guidance_scale + inputs["guidance_scale"] = 0.0 + output2 = pipe(**inputs) + + # outputs should 
not be exactly equal + assert not (output1.images[0] == output2.images[0]).all()