From a515af22a9fbb6703abe7b28174517dcb0202777 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 18 Sep 2023 22:04:47 -0700 Subject: [PATCH] Add x-prediction / prediction_type=sample support for SDXL fine-tuning --- examples/text_to_image/train_text_to_image_sdxl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 299d0f0d7523..22486298c984 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -998,6 +998,11 @@ def compute_time_ids(original_size, crops_coords_top_left): target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) + elif noise_scheduler.config.prediction_type == "sample": + # We set the target to latents here, but the model_pred will return the noise sample prediction. + target = model_input + # We will have to subtract the noise residual from the prediction to get the target sample. + model_pred = model_pred - noise else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")