diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2c66c341f78f..dfbd629d4809 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -976,7 +976,7 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[:-1], + text_input_ids=text_input_ids_list[-1], device=device if device is not None else text_encoders[-1].device, ) @@ -1607,8 +1607,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], - tokenizers=[None, None, tokenizer_three], - prompt=prompts, + tokenizers=[None, None, None], + prompt=args.instance_prompt, max_sequence_length=args.max_sequence_length, text_input_ids_list=[tokens_one, tokens_two, tokens_three], )