[Tencent Hunyuan Team] Add LoRA Inference Support for Hunyuan-DiT #8468
Changes from all commits
Changes to the Hunyuan-DiT transformer model (`HunyuanDiT2DModel`):
@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn.functional as F
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
@@ -179,6 +180,7 @@ def forward(
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb=None,
         skip=None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Long Skip Connection
@@ -209,7 +211,7 @@ def forward(
         return hidden_states


-class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     HunYuanDiT: Diffusion model with a Transformer backbone.

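Because `HunyuanDiT2DModel` now inherits from `PeftAdapterMixin`, a LoRA adapter can also be attached to the transformer directly via PEFT. A minimal sketch (the rank, alpha, and target module names below are illustrative assumptions, not values from this PR):

```python
import torch
from peft import LoraConfig

from diffusers import HunyuanDiT2DModel

# Load the transformer on its own; PeftAdapterMixin provides `add_adapter`.
transformer = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.float16
)

# Hypothetical LoRA config targeting the attention projections.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
transformer.add_adapter(lora_config)
```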
@@ -434,6 +436,7 @@ def forward(
         text_embedding_mask=None,
         encoder_hidden_states_t5=None,
         text_embedding_mask_t5=None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         image_meta_size=None,
         style=None,
         image_rotary_emb=None,
@@ -457,6 +460,10 @@ def forward(
             text_embedding_mask_t5: torch.Tensor
                 An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
                 of T5 Text Encoder.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             image_meta_size (torch.Tensor):
                 Conditional embedding indicate the image sizes
             style: torch.Tensor:
@@ -487,7 +494,21 @@ def forward(
         text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

         encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+        # lora related
+        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+        if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+
         # main forward network
         skips = []
         for layer, block in enumerate(self.blocks):
             if layer > self.config.num_layers // 2:
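For context, this mirrors the PEFT-backend pattern used by other diffusers models: `scale` is popped from `cross_attention_kwargs`, applied to every LoRA layer before the blocks run, and (given the `unscale_lora_layers` import added above) undone again afterwards. A self-contained sketch of the pattern, not the actual module code:

```python
from typing import Any, Dict, Optional

import torch

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def run_with_lora_scale(
    model: torch.nn.Module,
    hidden_states: torch.Tensor,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    # Copy and pop so `scale` never reaches the attention processors,
    # which would otherwise emit deprecation warnings.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(model, lora_scale)  # weight every PEFT LoRA layer

    output = model(hidden_states, **(cross_attention_kwargs or {}))

    if USE_PEFT_BACKEND:
        unscale_lora_layers(model, lora_scale)  # restore the original LoRA weights

    return output
```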
@@ -498,13 +519,15 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     image_rotary_emb=image_rotary_emb,
                     skip=skip,
+                    cross_attention_kwargs=cross_attention_kwargs,
                 )  # (N, L, D)
             else:
                 hidden_states = block(
                     hidden_states,
                     temb=temb,
                     encoder_hidden_states=encoder_hidden_states,
                     image_rotary_emb=image_rotary_emb,
+                    cross_attention_kwargs=cross_attention_kwargs,
                 )  # (N, L, D)

             if layer < (self.config.num_layers // 2 - 1):

Review note: a collaborator left suggested changes on the two added `cross_attention_kwargs=cross_attention_kwargs` lines (the suggestion bodies were not captured here); a member agreed with the suggestion.
Changes to the HunyuanDiT pipeline (`HunyuanDiTPipeline`):
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -23,6 +23,7 @@

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin
 from ...models import AutoencoderKL, HunyuanDiT2DModel
 from ...models.embeddings import get_2d_rotary_pos_embed
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -138,7 +139,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class HunyuanDiTPipeline(DiffusionPipeline):
+class HunyuanDiTPipeline(DiffusionPipeline, LoraLoaderMixin):
     r"""
     Pipeline for English/Chinese-to-image generation using HunyuanDiT.

@@ -239,6 +240,8 @@ def __init__(
             else 128
         )

+        self.unet_name = 'transformer'  # to support load_lora_weights
+
     def encode_prompt(
         self,
         prompt: str,

Member (on `self.unet_name = 'transformer'`): Hmm, this feels a little weird to me. But okay for now. I will attempt to refactor this to harmonize. Cc: @yiyixuxu
@@ -558,6 +561,10 @@ def num_timesteps(self):
     @property
     def interrupt(self):
         return self._interrupt

+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -592,6 +599,7 @@ def __call__(
         target_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         use_resolution_binning: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         The call function to the pipeline for generation with HunyuanDiT.
@@ -663,7 +671,11 @@ def __call__(
                 Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
                 standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
                 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

         Examples:

         Returns:
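For reference, a user-facing call with the new argument might look like the sketch below (the LoRA path is a placeholder); the `"scale"` entry is what ends up weighting the PEFT layers inside the transformer's forward:

```python
import torch

from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt")  # placeholder path

image = pipe(
    "a person dancing in the rain, jade painting style",
    num_inference_steps=50,
    cross_attention_kwargs={"scale": 0.6},  # LoRA strength for this call
).images[0]
```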
@@ -707,6 +719,7 @@ def __call__(
         )
         self._guidance_scale = guidance_scale
         self._guidance_rescale = guidance_rescale
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -720,7 +733,6 @@ def __call__(
         device = self._execution_device

         # 3. Encode input prompt
-
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -780,7 +792,7 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7 create image_rotary_emb, style embedding & time ids
+        # 7. Create image_rotary_emb, style embedding & time ids
         grid_height = height // 8 // self.transformer.config.patch_size
         grid_width = width // 8 // self.transformer.config.patch_size
         base_size = 512 // 8 // self.transformer.config.patch_size
@@ -837,6 +849,7 @@ def __call__(
                     text_embedding_mask=prompt_attention_mask,
                     encoder_hidden_states_t5=prompt_embeds_2,
                     text_embedding_mask_t5=prompt_attention_mask_2,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     image_meta_size=add_time_ids,
                     style=style,
                     image_rotary_emb=image_rotary_emb,
New file: an inference example demonstrating LoRA loading with `HunyuanDiTPipeline`:

@@ -0,0 +1,13 @@
+import torch
+
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt", adapter_name="yushi")
+prompt = "玉石绘画风格,一个人在雨中跳舞"  # "Jade painting style, a person dancing in the rain"
+image = pipe(
+    prompt, num_inference_steps=50, generator=torch.manual_seed(0)
+).images[0]
+image.save('img.png')

Collaborator (on the new file): we will remove this file before merge, no?

Member: Yes. Additionally, I think this should be turned into a proper test suite. Here is an example: diffusers/tests/lora/test_lora_layers_sd.py, line 205 at d457bee.
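Picking up the suggestion above, a rough sketch of a slow, pipeline-level test might look like this (hypothetical test name and placeholder LoRA path, loosely modeled on the existing SD LoRA tests rather than copied from them):

```python
import numpy as np
import torch

from diffusers import HunyuanDiTPipeline


def test_hunyuandit_lora_changes_output():
    pipe = HunyuanDiTPipeline.from_pretrained(
        "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
    ).to("cuda")

    def generate():
        # Fixed seed so the only difference between runs is the LoRA state.
        return pipe(
            "a person dancing in the rain",
            num_inference_steps=2,
            output_type="np",
            generator=torch.manual_seed(0),
        ).images[0]

    image_base = generate()

    # Loading the LoRA should change the output for the same seed.
    pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt")
    image_lora = generate()
    assert np.abs(image_lora - image_base).max() > 1e-3

    # Unloading should restore the original behavior.
    pipe.unload_lora_weights()
    image_unloaded = generate()
    assert np.abs(image_unloaded - image_base).max() < 1e-3
```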