29 changes: 26 additions & 3 deletions src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
@@ -179,6 +180,7 @@ def forward(
temb: Optional[torch.Tensor] = None,
image_rotary_emb=None,
skip=None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Collaborator suggested change (drop this line):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Long Skip Connection
@@ -209,7 +211,7 @@ def forward(
return hidden_states


class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
class HunyuanDiT2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.

@@ -434,6 +436,7 @@ def forward(
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
@@ -457,6 +460,10 @@ def forward(
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
@@ -487,7 +494,21 @@
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

# lora related
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)

# main forward network
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2:
@@ -498,13 +519,15 @@
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
cross_attention_kwargs=cross_attention_kwargs,
Collaborator suggested change (drop this line):
cross_attention_kwargs=cross_attention_kwargs,

Member: Agree with this!
) # (N, L, D)
else:
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
cross_attention_kwargs=cross_attention_kwargs,
Collaborator suggested change (drop this line):
cross_attention_kwargs=cross_attention_kwargs,
) # (N, L, D)

if layer < (self.config.num_layers // 2 - 1):
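For readers skimming the hunk above, here is a small self-contained sketch of the scale-handling pattern it introduces (illustrative only, not the actual diffusers source): the LoRA `scale` is popped out of `cross_attention_kwargs` so it never reaches the attention processors, and the PEFT layers are scaled around the block forward. The `run_blocks` callable and the `try/finally` unscaling are assumptions for illustration; the hunks only show the scaling half, although the new import of `unscale_lora_layers` suggests the forward unscales afterwards.

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def forward_with_lora_scale(model, run_blocks, cross_attention_kwargs=None):
    # Pop `scale` instead of reading it so it is not propagated to the internal
    # blocks (where it would trigger deprecation warnings), as in the hunk above.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Weight every PEFT LoRA layer of `model` by `lora_scale`.
        scale_lora_layers(model, lora_scale)
    try:
        return run_blocks(cross_attention_kwargs=cross_attention_kwargs)
    finally:
        if USE_PEFT_BACKEND:
            # Restore the original LoRA weighting after the forward pass.
            unscale_lora_layers(model, lora_scale)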
23 changes: 18 additions & 5 deletions src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -13,7 +13,7 @@
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -23,6 +23,7 @@

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin
from ...models import AutoencoderKL, HunyuanDiT2DModel
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -138,7 +139,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg


class HunyuanDiTPipeline(DiffusionPipeline):
class HunyuanDiTPipeline(DiffusionPipeline, LoraLoaderMixin):
r"""
Pipeline for English/Chinese-to-image generation using HunyuanDiT.

@@ -239,6 +240,8 @@ def __init__(
else 128
)

self.unet_name = 'transformer' # to support load_lora_weights
Member: Hmm, this feels a little weird to me. But okay for now. I will attempt to refactor this to harmonize. Cc: @yiyixuxu

def encode_prompt(
self,
prompt: str,
@@ -558,6 +561,10 @@ def num_timesteps(self):
@property
def interrupt(self):
return self._interrupt

@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -592,6 +599,7 @@ def __call__(
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
@@ -663,7 +671,11 @@ def __call__(
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.

cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

Examples:

Returns:
@@ -707,6 +719,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
@@ -720,7 +733,6 @@
device = self._execution_device

# 3. Encode input prompt

(
prompt_embeds,
negative_prompt_embeds,
@@ -780,7 +792,7 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7 create image_rotary_emb, style embedding & time ids
# 7. Create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
@@ -837,6 +849,7 @@ def __call__(
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
cross_attention_kwargs=self.cross_attention_kwargs,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
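Taken together, the two diffs above let a LoRA be loaded through `load_lora_weights` and its strength controlled per call via `cross_attention_kwargs["scale"]`. A hedged usage sketch follows; the LoRA path and weight name are placeholders mirroring the test script below, and the prompt is arbitrary.

import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt")

image = pipe(
    "an ink painting of a crane standing in a mountain stream",
    num_inference_steps=50,
    # Popped inside HunyuanDiT2DModel.forward and applied via scale_lora_layers.
    cross_attention_kwargs={"scale": 0.7},
    generator=torch.manual_seed(0),
).images[0]
image.save("lora_scaled.png")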
13 changes: 13 additions & 0 deletions test_hunyuan_dit_lora.py
@@ -0,0 +1,13 @@
import torch
Collaborator: We will remove this file before merge, no? In the future, maybe it's easier to just post the testing example in the PR description. :)

Member: Yes. Additionally, I think this should be turned into a proper test suite: test_lora_layers_hunyuan_dit.py. Just including a SLOW test is sufficient for the time being. Here is an example:

class LoraIntegrationTests(unittest.TestCase):
    ...

from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16)
pipe.to("cuda")

pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt", adapter_name="yushi")

prompt = "玉石绘画风格,一个人在雨中跳舞"
image = pipe(
prompt, num_inference_steps=50, generator=torch.manual_seed(0)
).images[0]
image.save('img.png')
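Following the review comment above, here is a minimal sketch of what the suggested SLOW suite (e.g. test_lora_layers_hunyuan_dit.py) could look like. The LoRA path, weight name, adapter name, and the final assertion are placeholders taken from the script; a real test would compare an output slice against a reference generated with the LoRA applied.

import gc
import unittest

import torch

from diffusers import HunyuanDiTPipeline
from diffusers.utils.testing_utils import require_torch_gpu, slow


@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def test_hunyuan_dit_lora(self):
        pipe = HunyuanDiTPipeline.from_pretrained(
            "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
        ).to("cuda")
        pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt", adapter_name="yushi")

        image = pipe(
            "玉石绘画风格,一个人在雨中跳舞",  # "jade painting style, a person dancing in the rain"
            num_inference_steps=2,
            output_type="np",
            generator=torch.manual_seed(0),
        ).images[0]

        # Placeholder check; replace with a comparison against a known-good slice.
        self.assertEqual(image.shape[-1], 3)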