From 8e6553e597cdeef00a7b25748d8f8256f5fe6fd1 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 30 Jun 2023 16:49:31 +0900 Subject: [PATCH 01/14] ft text encoder --- .../text_to_image/train_text_to_image_lora.py | 140 ++++++++++++++++-- 1 file changed, 127 insertions(+), 13 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 29259e408eff..6c29ed15769a 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -15,6 +15,7 @@ """Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" import argparse +import itertools import logging import math import os @@ -40,10 +41,10 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -195,6 +196,11 @@ def parse_args(): action="store_true", help="whether to randomly flip images horizontally", ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) parser.add_argument( "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." ) @@ -458,7 +464,7 @@ def main(): # => 32 layers # Set correct lora layers - lora_attn_procs = {} + unet_lora_attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -470,13 +476,13 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRAAttnProcessor( + unet_lora_attn_procs[name] = LoRAAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank, ) - unet.set_attn_processor(lora_attn_procs) + unet.set_attn_processor(unet_lora_attn_procs) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -515,7 +521,88 @@ def compute_snr(timesteps): snr = (alpha / sigma) ** 2 return snr - lora_layers = AttnProcsLayers(unet.attn_processors) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, + # we first load a dummy pipeline with the text encoder and then do the monkey-patching. 
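Before the patch moves on to the text-encoder hooks, it may help to recall what each `LoRAAttnProcessor` created above actually contributes: a pair of small low-rank matrices trained on top of a frozen projection. The sketch below is illustrative only (a plain linear layer with an assumed rank and scale), not the diffusers implementation:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Toy LoRA layer: y = W x + scale * B(A(x)), with W frozen and only A, B trained."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)                      # frozen pretrained projection
        self.down = nn.Linear(base.in_features, rank, bias=False)   # A: d_in -> r
        self.up = nn.Linear(rank, base.out_features, bias=False)    # B: r -> d_out
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)                               # the update starts at zero
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.up(self.down(x))


layer = LoRALinear(nn.Linear(768, 768), rank=4)
trainable = [p for p in layer.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 2 * 768 * 4 = 6144 trainable values vs. ~590k frozen ones
```

Only the tiny down/up matrices are trained, which is why the script collects just the attention-processor parameters for the optimizer and for the saved checkpoint.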
+ text_encoder_lora_layers = None + if args.train_text_encoder: + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_proj.out_features, + cross_attention_dim=None, + rank=args.rank, + ) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=text_encoder + ) + temp_pipeline._modify_text_encoder(text_lora_attn_procs) + text_encoder = temp_pipeline.text_encoder + del temp_pipeline + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + for model in models: + state_dict = model.state_dict() + + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_keys: + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here an purpose + # so that the we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. 
We simply use the pipeline class as + # an easy way to load the lora checkpoints + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -540,8 +627,13 @@ def compute_snr(timesteps): else: optimizer_cls = torch.optim.AdamW + params_to_optimize = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) optimizer = optimizer_cls( - lora_layers.parameters(), + params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -667,9 +759,14 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - lora_layers, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + else: + unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -727,6 +824,8 @@ def collate_fn(examples): for epoch in range(first_epoch, args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step @@ -799,7 +898,11 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = lora_layers.parameters() + params_to_clip = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -854,6 +957,7 @@ def collate_fn(examples): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), revision=args.revision, torch_dtype=weight_dtype, ) @@ -891,7 +995,17 @@ def collate_fn(examples): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) - unet.save_attn_procs(args.output_dir) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + + if text_encoder is not None: + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) if args.push_to_hub: save_model_card( @@ -916,7 +1030,7 @@ def collate_fn(examples): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) + pipeline.load_lora_weights(args.output_dir) # run inference generator = torch.Generator(device=accelerator.device) From 3b47815e9f5df24b9289f222fff5e26efc398c1c Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 15 Jul 2023 12:29:42 +0900 Subject: [PATCH 02/14] merge refactor --- .../text_to_image/train_text_to_image_lora.py | 141 ++++++++---------- 1 file changed, 66 insertions(+), 75 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index e42442b6c1b8..c7c77f67063a 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -22,6 +22,7 @@ import random import shutil from pathlib import Path +from typing import Dict import datasets import numpy as np @@ -41,10 +42,10 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -373,6 +374,22 @@ def parse_args(): } +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + r""" + Returns: + a state dict containing just the attention processor parameters. 
+ """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter + + return attn_processors_state_dict + + def main(): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) @@ -465,6 +482,7 @@ def main(): # Set correct lora layers unet_lora_attn_procs = {} + unet_lora_parameters = [] for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -476,11 +494,9 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - unet_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=args.rank, - ) + module = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank) + unet_lora_attn_procs[name] = module + unet_lora_parameters.extend(module.parameters()) unet.set_attn_processor(unet_lora_attn_procs) @@ -521,28 +537,11 @@ def compute_snr(timesteps): snr = (alpha / sigma) ** 2 return snr - unet_lora_layers = AttnProcsLayers(unet.attn_processors) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, - # we first load a dummy pipeline with the text encoder and then do the monkey-patching. - text_encoder_lora_layers = None + # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_proj.out_features, - cross_attention_dim=None, - rank=args.rank, - ) - text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, text_encoder=text_encoder - ) - temp_pipeline._modify_text_encoder(text_lora_attn_procs) - text_encoder = temp_pipeline.text_encoder - del temp_pipeline + text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -551,23 +550,13 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None - if args.train_text_encoder: - text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() - unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() - for model in models: - state_dict = model.state_dict() - - if ( - text_encoder_lora_layers is not None - and text_encoder_keys is not None - and state_dict.keys() == text_encoder_keys - ): - # text encoder - text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_keys: - # unet - unet_lora_layers_to_save = state_dict + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = 
text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -579,27 +568,23 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - # Note we DON'T pass the unet and text encoder here an purpose - # so that the we don't accidentally override the LoRA layers of - # unet_lora_layers and text_encoder_lora_layers which are stored in `models` - # with new torch.nn.Modules / weights. We simply use the pipeline class as - # an easy way to load the lora checkpoints - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - torch_dtype=weight_dtype, - ) - temp_pipeline.load_lora_weights(input_dir) - - # load lora weights into models - models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) - if len(models) > 1: - models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + unet_ = None + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") - # delete temporary pipeline and pop models - del temp_pipeline - for _ in range(len(models)): - models.pop() + lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + ) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -628,9 +613,9 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW params_to_optimize = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) optimizer = optimizer_cls( params_to_optimize, @@ -760,12 +745,12 @@ def collate_fn(examples): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) else: - unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
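The hunk above follows a general Accelerate pattern: when a second module is optionally trained, its parameters are chained into the same optimizer and both modules are passed to `accelerator.prepare`. Here is a minimal, self-contained sketch of that pattern, using toy `nn.Linear` stand-ins and assumed hyperparameters rather than the script's real objects:

```python
import itertools

import torch
from accelerate import Accelerator

# Stand-ins for the UNet and text encoder; only the shape of the pattern matters here.
train_text_encoder = True
unet_like = torch.nn.Linear(8, 8)
text_encoder_like = torch.nn.Linear(8, 8)

params_to_optimize = (
    itertools.chain(unet_like.parameters(), text_encoder_like.parameters())
    if train_text_encoder
    else unet_like.parameters()
)
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)

accelerator = Accelerator()
if train_text_encoder:
    unet_like, text_encoder_like, optimizer = accelerator.prepare(unet_like, text_encoder_like, optimizer)
else:
    unet_like, optimizer = accelerator.prepare(unet_like, optimizer)

# One toy step: gradients flow through both modules, and the single optimizer updates both.
loss = text_encoder_like(unet_like(torch.randn(4, 8, device=accelerator.device))).pow(2).mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```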
@@ -899,9 +884,9 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -954,6 +939,8 @@ def collate_fn(examples): f" {args.validation_prompt}." ) # create pipeline + print(weight_dtype) + kk pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), @@ -970,9 +957,9 @@ def collate_fn(examples): generator = generator.manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + with torch.cuda.amp.autocast(): + image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -994,12 +981,16 @@ def collate_fn(examples): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + unet_lora_layers = unet_attn_processors_state_dict(unet) if text_encoder is not None: + text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = text_encoder.to(torch.float32) - text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + else: + text_encoder_lora_layers = None LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, From 1f6908052e2ea114d5f01e64c678b51851c71ee9 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 15 Jul 2023 12:30:21 +0900 Subject: [PATCH 03/14] merge refactor --- examples/text_to_image/train_text_to_image_lora.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c7c77f67063a..77f3b9db2a33 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -537,7 +537,6 @@ def compute_snr(timesteps): snr = (alpha / sigma) ** 2 return snr - # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: @@ -939,8 +938,6 @@ def collate_fn(examples): f" {args.validation_prompt}." 
) # create pipeline - print(weight_dtype) - kk pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), From fc2b7ec7985a19eae0769cefabd1e2b39bfa3804 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 15 Jul 2023 14:48:28 +0900 Subject: [PATCH 04/14] merge refactor --- examples/text_to_image/train_text_to_image_lora.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 77f3b9db2a33..1ad97bbd5673 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -43,7 +43,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0 from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -494,7 +494,13 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - module = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank) + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + + module = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank + ) unet_lora_attn_procs[name] = module unet_lora_parameters.extend(module.parameters()) From 80e54c83bb7da54e20df343d1fa28f67024ec69c Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 15 Jul 2023 18:51:17 +0900 Subject: [PATCH 05/14] merge refactor --- examples/text_to_image/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 1ad97bbd5673..1edfc6a10968 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -1024,7 +1024,7 @@ def collate_fn(examples): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.load_lora_weights(args.output_dir) + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin") # run inference generator = torch.Generator(device=accelerator.device) From 8905a4921926eb0e78516bec98f8e95d3cadfb07 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sun, 16 Jul 2023 15:30:09 +0900 Subject: [PATCH 06/14] fix xformers place --- .../text_to_image/train_text_to_image_lora.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 1edfc6a10968..befed77d2cc7 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -467,6 +467,19 @@ def main(): vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = 
version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + # now we will add new LoRA weights to the attention layers # It's important to realize here how many attention weights will be added and of which sizes # The sizes of the attention layers consist only of two different variables: @@ -506,19 +519,6 @@ def main(): unet.set_attn_processor(unet_lora_attn_procs) - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - def compute_snr(timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 From 1350bd8cdda8341ccc3ff2937c04c519ce851960 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sun, 16 Jul 2023 15:35:14 +0900 Subject: [PATCH 07/14] fix text encoder rule --- examples/text_to_image/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index befed77d2cc7..0d491211cf0b 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -988,7 +988,7 @@ def collate_fn(examples): unet = unet.to(torch.float32) unet_lora_layers = unet_attn_processors_state_dict(unet) - if text_encoder is not None: + if text_encoder is not None and args.train_text_encoder: text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = text_encoder.to(torch.float32) text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) From 760ad2078dd27c7993eecaace6a27e3acb1c4d81 Mon Sep 17 00:00:00 2001 From: okotaku Date: Tue, 18 Jul 2023 09:46:29 +0900 Subject: [PATCH 08/14] update docs --- docs/source/en/training/lora.mdx | 52 ++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index dfb31c7ef87a..fa74d4e885a6 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -89,6 +89,40 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ --seed=1337 ``` +## Finetuning the text encoder and UNet + +The script also allows you to finetune the `text_encoder` along with the `unet`. + + + +Training the text encoder requires additional memory and it won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option. 
+ + + +Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`: + +```bash +accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --dataloader_num_workers=8 \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-04 \ + --max_grad_norm=1 \ + --lr_scheduler="cosine" --lr_warmup_steps=0 \ + --output_dir=${OUTPUT_DIR} \ + --push_to_hub \ + --hub_model_id=${HUB_MODEL_ID} \ + --report_to=wandb \ + --checkpointing_steps=500 \ + --validation_prompt="A pokemon with blue eyes." \ + --train_text_encoder \ + --seed=1337 +``` + ### Inference[[text-to-image-inference]] Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]: @@ -144,6 +178,24 @@ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch. +If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA +weights. For example: + +```python +from huggingface_hub.repocard import RepoCard +from diffusers import StableDiffusionPipeline +import torch + +lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.load_lora_weights(lora_model_id) +image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] +``` + ## DreamBooth From c274dc7d0149fc5b6a0d223d6fb021fb863bc9f6 Mon Sep 17 00:00:00 2001 From: okotaku Date: Tue, 18 Jul 2023 10:09:31 +0900 Subject: [PATCH 09/14] update test --- examples/test_examples.py | 46 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/examples/test_examples.py b/examples/test_examples.py index cc3b3fbf7478..5b3c980d3042 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -849,6 +849,52 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, ) + def test_text_to_image_lora_with_text_encoder(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + --train_text_encoder + --num_validation_images=0 + """.split() + + run_command(self._launch_args + initial_run_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # check `text_encoder` is present at all. 
+ lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + keys = lora_state_dict.keys() + is_text_encoder_present = any(k.startswith("text_encoder") for k in keys) + self.assertTrue(is_text_encoder_present) + + # the names of the keys of the state dict should either start with `unet` + # or `text_encoder`. + is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys) + self.assertTrue(is_correct_naming) + def test_unconditional_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: initial_run_args = f""" From 00cff1529ce2eab4aec9777b86f7271c5c7002e7 Mon Sep 17 00:00:00 2001 From: takuoko Date: Wed, 19 Jul 2023 08:36:50 +0900 Subject: [PATCH 10/14] Update docs/source/en/training/lora.mdx Co-authored-by: Sayak Paul --- docs/source/en/training/lora.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index fa74d4e885a6..967045650c76 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -89,7 +89,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ --seed=1337 ``` -## Finetuning the text encoder and UNet +### Finetuning the text encoder and UNet The script also allows you to finetune the `text_encoder` along with the `unet`. From c0e80b0c57319d4bddf8688ec6b8b8c2a48ce880 Mon Sep 17 00:00:00 2001 From: takuoko Date: Wed, 19 Jul 2023 08:37:59 +0900 Subject: [PATCH 11/14] Update examples/test_examples.py Co-authored-by: Sayak Paul --- examples/test_examples.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 5b3c980d3042..409fcc6d31a5 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -853,10 +853,6 @@ def test_text_to_image_lora_with_text_encoder(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" with tempfile.TemporaryDirectory() as tmpdir: - # Run training script with checkpointing - # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 - # Should create checkpoints at steps 2, 4, 6 - # with checkpoint at step 2 deleted initial_run_args = f""" examples/text_to_image/train_text_to_image_lora.py From 3e9d86963e64e946e7e3df5cea4ef15bfd2446f4 Mon Sep 17 00:00:00 2001 From: okotaku Date: Wed, 19 Jul 2023 08:47:50 +0900 Subject: [PATCH 12/14] format --- examples/test_examples.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 409fcc6d31a5..e65c3dcc4899 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -853,7 +853,6 @@ def test_text_to_image_lora_with_text_encoder(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" with tempfile.TemporaryDirectory() as tmpdir: - initial_run_args = f""" examples/text_to_image/train_text_to_image_lora.py --pretrained_model_name_or_path {pretrained_model_name_or_path} From caf921b53830c3f5f73136c02261c03cbbdf57bd Mon Sep 17 00:00:00 2001 From: okotaku Date: Wed, 19 Jul 2023 08:49:20 +0900 Subject: [PATCH 13/14] update docs --- docs/source/en/training/lora.mdx | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index 967045650c76..3c4400eca073 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -182,18 +182,16 @@ If you used `--train_text_encoder` 
during training, then use `pipe.load_lora_wei weights. For example: ```python -from huggingface_hub.repocard import RepoCard from diffusers import StableDiffusionPipeline import torch -lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" -card = RepoCard.load(lora_model_id) -base_model_id = card.data.to_dict()["base_model"] +lora_model_id = "takuoko/classic-anime-expressions-lora" +base_model_id = "stablediffusionapi/anything-v5" pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) pipe = pipe.to("cuda") -pipe.load_lora_weights(lora_model_id) -image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] +pipe.load_lora_weights(lora_model_id, weight_name="pytorch_lora_weights.bin") +image = pipe("1girl, >_<", num_inference_steps=50).images[0] ``` From 74b074952e926096df747e27864b099d553c64da Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 21 Jul 2023 11:49:00 +0900 Subject: [PATCH 14/14] del weight_name --- examples/text_to_image/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 0d491211cf0b..8034ad83cc8a 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -1024,7 +1024,7 @@ def collate_fn(examples): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin") + pipeline.load_lora_weights(args.output_dir) # run inference generator = torch.Generator(device=accelerator.device)
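As a follow-up to the final loading change, a trained checkpoint can be sanity-checked before calling `pipeline.load_lora_weights(...)`. The output path below is a placeholder; the file name and the `unet` / `text_encoder` key prefixes come from what the training script saves and what the test added in patch 09 asserts:

```python
# Quick inspection of a LoRA file produced by this script (path is a placeholder).
import torch

state_dict = torch.load("sd-model-finetuned-lora/pytorch_lora_weights.bin", map_location="cpu")

unet_keys = [k for k in state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder")]
unexpected = [k for k in state_dict if not (k.startswith("unet") or k.startswith("text_encoder"))]

print(f"unet LoRA tensors:         {len(unet_keys)}")
print(f"text encoder LoRA tensors: {len(text_encoder_keys)}")  # 0 unless --train_text_encoder was used
assert not unexpected, f"unexpected key namespaces: {unexpected[:5]}"
```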