From be84b0337182a2739b5b08a1fea135542fcc402b Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 7 Aug 2023 11:16:48 +0000
Subject: [PATCH] fix batch size lora

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index a2b6e4a38278..6f99dbc64d70 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1103,11 +1103,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
                     }
-                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
                         noisy_model_input,
                         timesteps,
-                        prompt_embeds,
+                        prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                     ).sample
                 else:
@@ -1119,9 +1119,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         text_input_ids_list=[tokens_one, tokens_two],
                     )
                     unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
-                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
                     ).sample

                 # Get the target for loss depending on the prediction type
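
The rename fixes an aliasing bug: rebinding prompt_embeds to its own repeated copy means that on later passes through this code the already-repeated tensor can be repeated again, so its batch dimension keeps growing. With elems_to_repeat equal to 1 the repeat is a no-op and the aliasing is harmless, which is consistent with the subject line: the problem only surfaces at larger batch sizes. A minimal sketch of the failure mode follows; the embedding shape (77 tokens, 2048 dims, roughly SDXL-like) and the elems_to_repeat value of 2 are assumptions for illustration, not values taken from the script.

    import torch

    # Assumed for illustration: embeddings computed once and reused every step.
    prompt_embeds = torch.randn(1, 77, 2048)
    elems_to_repeat = 2  # hypothetical value, e.g. batch size > 1

    for step in range(3):
        # Buggy version: rebinding the reused name compounds the repeat,
        # yielding shapes (2, 77, 2048), (4, 77, 2048), (8, 77, 2048), ...
        #   prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)

        # Patched version: repeat into a fresh name; the source is untouched.
        prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
        assert prompt_embeds_input.shape == (elems_to_repeat, 77, 2048)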