From 77faaf59ad778d5a87c9b6e7535d9600d1c1255b Mon Sep 17 00:00:00 2001
From: XCLiu
Date: Tue, 11 Jun 2024 11:53:11 +0800
Subject: [PATCH] add LoRA support and test

---
 .../transformers/hunyuan_transformer_2d.py | 29 +++++++++++++++++--
 .../hunyuandit/pipeline_hunyuandit.py      | 23 +++++++++++----
 test_hunyuan_dit_lora.py                   | 13 +++++++++
 3 files changed, 57 insertions(+), 8 deletions(-)
 create mode 100644 test_hunyuan_dit_lora.py

diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index fdc3410e2454..591624752bd7 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
@@ -179,6 +180,7 @@ def forward(
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb=None,
         skip=None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Long Skip Connection
@@ -209,7 +211,7 @@ def forward(
         return hidden_states
 
 
-class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     HunYuanDiT: Diffusion model with a Transformer backbone.
 
@@ -434,6 +436,7 @@ def forward(
         text_embedding_mask=None,
         encoder_hidden_states_t5=None,
         text_embedding_mask_t5=None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         image_meta_size=None,
         style=None,
         image_rotary_emb=None,
@@ -457,6 +460,10 @@ def forward(
         text_embedding_mask_t5: torch.Tensor
             An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the
             output of T5 Text Encoder.
+        cross_attention_kwargs (`dict`, *optional*):
+            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+            `self.processor` in
+            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
         image_meta_size (torch.Tensor):
             Conditional embedding indicate the image sizes
         style: torch.Tensor:
@@ -487,7 +494,21 @@ def forward(
         text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
 
         encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+        # LoRA-related:
+        # We pop `scale` instead of reading it because otherwise `scale` would be propagated
+        # to the internal blocks and raise deprecation warnings. This would be confusing for our users.
+        if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+
         # main forward network
         skips = []
         for layer, block in enumerate(self.blocks):
             if layer > self.config.num_layers // 2:
@@ -498,6 +519,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     image_rotary_emb=image_rotary_emb,
                     skip=skip,
+                    cross_attention_kwargs=cross_attention_kwargs,
                 )  # (N, L, D)
             else:
                 hidden_states = block(
@@ -505,6 +527,7 @@ def forward(
                     temb=temb,
                     encoder_hidden_states=encoder_hidden_states,
                     image_rotary_emb=image_rotary_emb,
+                    cross_attention_kwargs=cross_attention_kwargs,
                 )  # (N, L, D)
 
             if layer < (self.config.num_layers // 2 - 1):
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index 86089abc07b4..da61b2415280 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -23,6 +23,7 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin
 from ...models import AutoencoderKL, HunyuanDiT2DModel
 from ...models.embeddings import get_2d_rotary_pos_embed
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -138,7 +139,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
-class HunyuanDiTPipeline(DiffusionPipeline):
+class HunyuanDiTPipeline(DiffusionPipeline, LoraLoaderMixin):
     r"""
     Pipeline for English/Chinese-to-image generation using HunyuanDiT.
 
@@ -239,6 +240,8 @@ def __init__(
             else 128
         )
 
+        self.unet_name = "transformer"  # to support load_lora_weights
+
     def encode_prompt(
         self,
         prompt: str,
@@ -558,6 +561,10 @@ def num_timesteps(self):
     @property
     def interrupt(self):
         return self._interrupt
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -592,6 +599,7 @@ def __call__(
         target_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         use_resolution_binning: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         The call function to the pipeline for generation with HunyuanDiT.
@@ -663,7 +671,11 @@ def __call__(
                 Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
                 standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
                 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
-
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
         Examples:
 
         Returns:
@@ -707,6 +719,7 @@ def __call__(
         )
         self._guidance_scale = guidance_scale
         self._guidance_rescale = guidance_rescale
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
 
         # 2. Define call parameters
@@ -720,7 +733,6 @@ def __call__(
         device = self._execution_device
 
         # 3. Encode input prompt
-
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -780,7 +792,7 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7 create image_rotary_emb, style embedding & time ids
+        # 7. Create image_rotary_emb, style embedding & time ids
         grid_height = height // 8 // self.transformer.config.patch_size
         grid_width = width // 8 // self.transformer.config.patch_size
         base_size = 512 // 8 // self.transformer.config.patch_size
@@ -837,6 +849,7 @@ def __call__(
                     text_embedding_mask=prompt_attention_mask,
                     encoder_hidden_states_t5=prompt_embeds_2,
                     text_embedding_mask_t5=prompt_attention_mask_2,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     image_meta_size=add_time_ids,
                     style=style,
                     image_rotary_emb=image_rotary_emb,
diff --git a/test_hunyuan_dit_lora.py b/test_hunyuan_dit_lora.py
new file mode 100644
index 000000000000..e80afdfee566
--- /dev/null
+++ b/test_hunyuan_dit_lora.py
@@ -0,0 +1,13 @@
+import torch
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt", adapter_name="yushi")
+
+prompt = "玉石绘画风格,一个人在雨中跳舞"  # "jade painting style, a person dancing in the rain"
+image = pipe(
+    prompt, num_inference_steps=50, generator=torch.manual_seed(0)
+).images[0]
+image.save("img.png")
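
A minimal usage sketch of the plumbing this patch adds (not part of the patch itself): once `cross_attention_kwargs` reaches `HunyuanDiT2DModel.forward`, the `scale` entry is popped and applied to the PEFT layers via `scale_lora_layers`, so a LoRA strength can be passed at inference time. The model ID and adapter name below come from the test script above; the LoRA path and weight name remain placeholders, and the English prompt and the `scale` value of 0.7 are illustrative.

import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")

# Placeholder path and weight name, as in the test script above.
pipe.load_lora_weights("YOUR_LORA_PATH", weight_name="lora_weights.pt", adapter_name="yushi")

# `scale` is popped inside HunyuanDiT2DModel.forward and applied to the PEFT
# layers via scale_lora_layers(): 1.0 applies the full LoRA effect, 0.0 disables it.
image = pipe(
    "jade painting style, a person dancing in the rain",
    num_inference_steps=50,
    cross_attention_kwargs={"scale": 0.7},
    generator=torch.manual_seed(0),
).images[0]
image.save("img_scaled.png")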