From 9451288fa6a4a6b02c957371912b8b32e0dbfd2a Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Sep 2023 10:34:40 -0700 Subject: [PATCH] SNR gamma fixes for v_prediction training --- examples/controlnet/train_controlnet_flax.py | 3 +++ .../onnxruntime/text_to_image/train_text_to_image.py | 3 +++ examples/text_to_image/train_text_to_image.py | 3 +++ examples/text_to_image/train_text_to_image_lora.py | 3 +++ examples/text_to_image/train_text_to_image_lora_sdxl.py | 3 +++ 5 files changed, 15 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 75087284cbb4..d04c616c57eb 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -908,6 +908,9 @@ def compute_loss(params, minibatch, sample_rng): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + snr_loss_weights = snr_loss_weights + 1 loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2548c3a286a6..b89de5e001c5 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -875,6 +875,9 @@ def collate_fn(examples): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b00884bfb7ea..0d14e6ccd548 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -955,6 +955,9 @@ def collate_fn(examples): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b9830a83ae8a..5845bda0e54f 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -786,6 +786,9 @@ def collate_fn(examples): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 45ae1cc9ef7a..7a8c2c353eb0 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1075,6 +1075,9 @@ def compute_time_ids(original_size, crops_coords_top_left): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss.