From 45a36aa2e9e1ce7548495aeca3d4d99203b372f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Sun, 30 Jun 2024 06:55:27 -0400 Subject: [PATCH 1/4] fix --- examples/dreambooth/train_dreambooth_lora_sd3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2c66c341f78f..69ad2d125d3a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -962,7 +962,7 @@ def encode_prompt( prompt=prompt, device=device if device is not None else text_encoder.device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[i], + text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, ) clip_prompt_embeds_list.append(prompt_embeds) clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) @@ -976,7 +976,7 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[:-1], + text_input_ids=text_input_ids_list[:-1] if text_input_ids_list else None, device=device if device is not None else text_encoders[-1].device, ) @@ -1687,7 +1687,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = itertools.chain( transformer_lora_parameters, - text_lora_parameters_one, + text_lora_parameters_one if args.train_text_encoder else transformer_lora_parameters, text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters, ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) From 0b4da4ba826b622e14b3c23f3fccfc19dbec7b32 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 1 Jul 2024 18:02:36 +0530 Subject: [PATCH 2/4] fix things. Co-authored-by: Linoy Tsaban --- dreambooth.patch | 57 +++++++++++++++++++ .../dreambooth/train_dreambooth_lora_sd3.py | 18 +++--- 2 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 dreambooth.patch diff --git a/dreambooth.patch b/dreambooth.patch new file mode 100644 index 000000000000..891dad70fef9 --- /dev/null +++ b/dreambooth.patch @@ -0,0 +1,57 @@ +diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py +index 69ad2d12..17198449 100644 +--- a/examples/dreambooth/train_dreambooth_lora_sd3.py ++++ b/examples/dreambooth/train_dreambooth_lora_sd3.py +@@ -1491,6 +1491,9 @@ def main(args): + ) = accelerator.prepare( + transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) ++ assert text_encoder_one is not None ++ assert text_encoder_two is not None ++ assert text_encoder_three is not None + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler +@@ -1598,7 +1601,7 @@ def main(args): + tokens_three = tokenize_prompt(tokenizer_three, prompts) + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], +- tokenizers=[None, None, tokenizer_three], ++ tokenizers=[None, None, None], + prompt=prompts, + max_sequence_length=args.max_sequence_length, + text_input_ids_list=[tokens_one, tokens_two, tokens_three], +@@ -1608,7 +1611,7 @@ def main(args): + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], + tokenizers=[None, None, tokenizer_three], +- prompt=prompts, ++ prompt=args.instance_prompt, + max_sequence_length=args.max_sequence_length, + text_input_ids_list=[tokens_one, tokens_two, tokens_three], + ) +@@ -1741,13 +1744,6 @@ def main(args): + text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three + ) +- else: +- text_encoder_three = text_encoder_cls_three.from_pretrained( +- args.pretrained_model_name_or_path, +- subfolder="text_encoder_3", +- revision=args.revision, +- variant=args.variant, +- ) + pipeline = StableDiffusion3Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, +@@ -1767,7 +1763,9 @@ def main(args): + pipeline_args=pipeline_args, + epoch=epoch, + ) +- del text_encoder_one, text_encoder_two, text_encoder_three ++ if not args.train_text_encoder: ++ del text_encoder_one, text_encoder_two, text_encoder_three ++ + torch.cuda.empty_cache() + gc.collect() + diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 69ad2d125d3a..c802d1399121 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) = accelerator.prepare( transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler ) + assert text_encoder_one is not None + assert text_encoder_two is not None + assert text_encoder_three is not None else: transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler @@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_three = tokenize_prompt(tokenizer_three, prompts) prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], - tokenizers=[None, None, tokenizer_three], + tokenizers=[None, None, None], prompt=prompts, max_sequence_length=args.max_sequence_length, text_input_ids_list=[tokens_one, tokens_two, tokens_three], @@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], tokenizers=[None, None, tokenizer_three], - prompt=prompts, + prompt=args.instance_prompt, max_sequence_length=args.max_sequence_length, text_input_ids_list=[tokens_one, tokens_two, tokens_three], ) @@ -1741,13 +1744,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) - else: - text_encoder_three = text_encoder_cls_three.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_3", - revision=args.revision, - variant=args.variant, - ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1767,7 +1763,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, ) - del text_encoder_one, text_encoder_two, text_encoder_three + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two, text_encoder_three + torch.cuda.empty_cache() gc.collect() From c3e5da7eaaeb0e803c8fd850c9582f59ec0e3150 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 1 Jul 2024 18:02:56 +0530 Subject: [PATCH 3/4] remove patch --- dreambooth.patch | 57 ------------------------------------------------ 1 file changed, 57 deletions(-) delete mode 100644 dreambooth.patch diff --git a/dreambooth.patch b/dreambooth.patch deleted file mode 100644 index 891dad70fef9..000000000000 --- a/dreambooth.patch +++ /dev/null @@ -1,57 +0,0 @@ -diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py -index 69ad2d12..17198449 100644 ---- a/examples/dreambooth/train_dreambooth_lora_sd3.py -+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py -@@ -1491,6 +1491,9 @@ def main(args): - ) = accelerator.prepare( - transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler - ) -+ assert text_encoder_one is not None -+ assert text_encoder_two is not None -+ assert text_encoder_three is not None - else: - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler -@@ -1598,7 +1601,7 @@ def main(args): - tokens_three = tokenize_prompt(tokenizer_three, prompts) - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], -- tokenizers=[None, None, tokenizer_three], -+ tokenizers=[None, None, None], - prompt=prompts, - max_sequence_length=args.max_sequence_length, - text_input_ids_list=[tokens_one, tokens_two, tokens_three], -@@ -1608,7 +1611,7 @@ def main(args): - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], - tokenizers=[None, None, tokenizer_three], -- prompt=prompts, -+ prompt=args.instance_prompt, - max_sequence_length=args.max_sequence_length, - text_input_ids_list=[tokens_one, tokens_two, tokens_three], - ) -@@ -1741,13 +1744,6 @@ def main(args): - text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three - ) -- else: -- text_encoder_three = text_encoder_cls_three.from_pretrained( -- args.pretrained_model_name_or_path, -- subfolder="text_encoder_3", -- revision=args.revision, -- variant=args.variant, -- ) - pipeline = StableDiffusion3Pipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, -@@ -1767,7 +1763,9 @@ def main(args): - pipeline_args=pipeline_args, - epoch=epoch, - ) -- del text_encoder_one, text_encoder_two, text_encoder_three -+ if not args.train_text_encoder: -+ del text_encoder_one, text_encoder_two, text_encoder_three -+ - torch.cuda.empty_cache() - gc.collect() - From 1137349492d8ae38070b235fc7f42d079e43a00b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Mon, 1 Jul 2024 10:27:53 -0400 Subject: [PATCH 4/4] apply suggestions --- examples/dreambooth/train_dreambooth_lora_sd3.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index c802d1399121..3aad7216f6aa 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -976,7 +976,7 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[:-1] if text_input_ids_list else None, + text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, device=device if device is not None else text_encoders[-1].device, ) @@ -1688,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = itertools.chain( - transformer_lora_parameters, - text_lora_parameters_one if args.train_text_encoder else transformer_lora_parameters, - text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters, + params_to_clip = ( + itertools.chain( + transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two + ) + if args.train_text_encoder + else transformer_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)