diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index dfb31c7ef87a..3c4400eca073 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -89,6 +89,40 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
   --seed=1337
 ```
 
+### Finetuning the text encoder and UNet
+
+The script also allows you to finetune the `text_encoder` along with the `unet`.
+
+<Tip warning={true}>
+
+Training the text encoder requires additional memory, so the training run won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option.
+
+</Tip>
+
+Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$DATASET_NAME \
+  --dataloader_num_workers=8 \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --max_train_steps=15000 \
+  --learning_rate=1e-04 \
+  --max_grad_norm=1 \
+  --lr_scheduler="cosine" --lr_warmup_steps=0 \
+  --output_dir=${OUTPUT_DIR} \
+  --push_to_hub \
+  --hub_model_id=${HUB_MODEL_ID} \
+  --report_to=wandb \
+  --checkpointing_steps=500 \
+  --validation_prompt="A pokemon with blue eyes." \
+  --train_text_encoder \
+  --seed=1337
+```
+
 ### Inference[[text-to-image-inference]]
 
 Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]:
@@ -144,6 +178,22 @@ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.
 
 </Tip>
 
+If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
+weights. For example:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+lora_model_id = "takuoko/classic-anime-expressions-lora"
+base_model_id = "stablediffusionapi/anything-v5"
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id, weight_name="pytorch_lora_weights.bin")
+image = pipe("1girl, >_<", num_inference_steps=50).images[0]
+```
+
 ## DreamBooth
 
diff --git a/examples/test_examples.py b/examples/test_examples.py
index cc3b3fbf7478..e65c3dcc4899 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -849,6 +849,47 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip
             {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
         )
 
+    def test_text_to_image_lora_with_text_encoder(self):
+        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora.py
+                --pretrained_model_name_or_path {pretrained_model_name_or_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --center_crop
+                --random_flip
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                --seed=0
+                --train_text_encoder
+                --num_validation_images=0
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # check that the `text_encoder` LoRA keys are present at all.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            keys = lora_state_dict.keys()
+            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_text_encoder_present)
+
+            # the names of the keys of the state dict should either start with `unet`
+            # or `text_encoder`.
+            is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_correct_naming)
+
     def test_unconditional_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             initial_run_args = f"""
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 4fb4f9e13e72..8034ad83cc8a 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -15,12 +15,14 @@
 """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
 
 import argparse
+import itertools
 import logging
 import math
 import os
 import random
 import shutil
 from pathlib import Path
+from typing import Dict
 
 import datasets
 import numpy as np
@@ -40,8 +42,8 @@
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
+from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -195,6 +197,11 @@ def parse_args():
         action="store_true",
         help="whether to randomly flip images horizontally",
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
@@ -367,6 +374,22 @@ def parse_args():
 }
 
 
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
 def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -444,6 +467,19 @@ def main():
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
     # now we will add new LoRA weights to the attention layers
     # It's important to realize here how many attention weights will be added and of which sizes
     # The sizes of the attention layers consist only of two different variables:
@@ -458,7 +494,8 @@ def main():
     # => 32 layers
 
     # Set correct lora layers
-    lora_attn_procs = {}
+    unet_lora_attn_procs = {}
+    unet_lora_parameters = []
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -470,26 +507,17 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
 
-    unet.set_attn_processor(lora_attn_procs)
-
-    if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            import xformers
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())
 
-            xformers_version = version.parse(xformers.__version__)
-            if xformers_version == version.parse("0.0.16"):
-                logger.warn(
-                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
-                )
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
+    unet.set_attn_processor(unet_lora_attn_procs)
 
     def compute_snr(timesteps):
         """
@@ -515,7 +543,56 @@ def compute_snr(timesteps):
         snr = (alpha / sigma) ** 2
         return snr
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    if args.train_text_encoder:
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        # there are only two options here: either there are just the unet attn processor layers
+        # or there are the unet and text encoder attention layers
+        unet_lora_layers_to_save = None
+        text_encoder_lora_layers_to_save = None
+
+        for model in models:
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+        LoraLoaderMixin.save_lora_weights(
+            output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+        )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+        text_encoder_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected model to load: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -540,8 +617,13 @@ def compute_snr(timesteps):
     else:
         optimizer_cls = torch.optim.AdamW
 
+    params_to_optimize = (
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
+        if args.train_text_encoder
+        else unet_lora_parameters
+    )
     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -667,9 +749,14 @@ def collate_fn(examples):
     )
 
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -727,6 +814,8 @@ def collate_fn(examples):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
@@ -799,7 +888,11 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    params_to_clip = (
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
+                        if args.train_text_encoder
+                        else unet_lora_parameters
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -854,6 +947,7 @@ def collate_fn(examples):
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
@@ -866,9 +960,9 @@ def collate_fn(examples):
                     generator = generator.manual_seed(args.seed)
                 images = []
                 for _ in range(args.num_validation_images):
-                    images.append(
-                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                    )
+                    with torch.cuda.amp.autocast():
+                        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                        images.append(image)
 
                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
@@ -890,8 +984,22 @@ def collate_fn(examples):
     # Save the lora layers
    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)
+
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
+            text_encoder = text_encoder.to(torch.float32)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None
+
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
+        )
 
         if args.push_to_hub:
             save_model_card(
@@ -916,7 +1024,7 @@ def collate_fn(examples):
     pipeline = pipeline.to(accelerator.device)
 
     # load attention processors
-    pipeline.unet.load_attn_procs(args.output_dir)
+    pipeline.load_lora_weights(args.output_dir)
 
     # run inference
     generator = torch.Generator(device=accelerator.device)
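
With this change, `pytorch_lora_weights.bin` bundles the UNet LoRA parameters and, when `--train_text_encoder` is passed, the text encoder LoRA parameters in a single state dict whose keys are namespaced by component, which is exactly what the new `test_text_to_image_lora_with_text_encoder` test asserts. Below is a minimal sketch for inspecting that file locally; the output directory name is hypothetical and should be whatever was passed as `--output_dir`:

```python
import torch

# Hypothetical output directory; substitute the value passed as --output_dir.
output_dir = "sd-pokemon-model-lora"

# LoraLoaderMixin.save_lora_weights writes a single file with every key
# prefixed by the component it belongs to ("unet" or "text_encoder").
state_dict = torch.load(f"{output_dir}/pytorch_lora_weights.bin", map_location="cpu")

unet_keys = [k for k in state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder")]

print(f"unet LoRA params: {len(unet_keys)}")
# Empty unless --train_text_encoder was used during training.
print(f"text encoder LoRA params: {len(text_encoder_keys)}")
assert all(k.startswith(("unet", "text_encoder")) for k in state_dict)
```

For inference, `pipeline.load_lora_weights(args.output_dir)` (used at the end of the training script) consumes this same file and routes the `unet.*` and `text_encoder.*` entries to their respective components.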