From aa65d6563e796cb0161a5108216d9bc0a2848727 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Mon, 1 Jul 2024 14:42:38 +0300
Subject: [PATCH 1/2] fix text encoder training for instance prompt

---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 2c66c341f78f..d7b136191716 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -910,6 +910,7 @@ def _encode_prompt_with_clip(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
+    print("batch_size", batch_size)
 
     if tokenizer is not None:
         text_inputs = tokenizer(
@@ -933,6 +934,7 @@ def _encode_prompt_with_clip(
 
     _, seq_len, _ = prompt_embeds.shape
     # duplicate text embeddings for each generation per prompt, using mps friendly method
+    print("HEREEEEEE", prompt_embeds.shape)
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
@@ -976,7 +978,7 @@ def encode_prompt(
         max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[:-1],
+        text_input_ids=text_input_ids_list[-1],
         device=device if device is not None else text_encoders[-1].device,
     )
 
@@ -1451,6 +1453,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
     # batch prompts on all training steps
     else:
+        print("tokenizing pre-training loop")
         tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
         tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
         tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
@@ -1607,8 +1610,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.train_text_encoder:
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                        tokenizers=[None, None, tokenizer_three],
-                        prompt=prompts,
+                        tokenizers=[None, None, None],
+                        prompt=args.instance_prompt,
                         max_sequence_length=args.max_sequence_length,
                         text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                     )

From b0c1bc3b81ec574a5e332c575419a8527799c09c Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Mon, 1 Jul 2024 14:57:08 +0300
Subject: [PATCH 2/2] fix text encoder training for instance prompt

---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index d7b136191716..dfbd629d4809 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -910,7 +910,6 @@ def _encode_prompt_with_clip(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    print("batch_size", batch_size)
 
     if tokenizer is not None:
         text_inputs = tokenizer(
@@ -934,7 +933,6 @@ def _encode_prompt_with_clip(
 
     _, seq_len, _ = prompt_embeds.shape
     # duplicate text embeddings for each generation per prompt, using mps friendly method
-    print("HEREEEEEE", prompt_embeds.shape)
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
@@ -1453,7 +1451,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
     # batch prompts on all training steps
     else:
-        print("tokenizing pre-training loop")
         tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
         tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
         tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)