diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index c681943f2e94..4a3048a0ba23 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1038,7 +1038,6 @@ def compute_time_ids(original_size, crops_coords_top_left): prompt_embeds = batch["prompt_embeds"].to(accelerator.device) pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) - prompt_embeds = prompt_embeds model_pred = unet( noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions ).sample