From f4b119ef15089171bc896fd0e23307175808907a Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 29 Sep 2023 07:59:49 -0700 Subject: [PATCH] Min-SNR Gamma: correct the fix for SNR weighted loss in v-prediction by adding 1 to SNR rather than the resulting loss weights --- examples/controlnet/train_controlnet_flax.py | 13 +++--------- .../train_text_to_image_decoder.py | 20 ++++-------------- .../train_text_to_image_lora_decoder.py | 20 ++++-------------- .../train_text_to_image_lora_prior.py | 20 ++++-------------- .../train_text_to_image_prior.py | 20 ++++-------------- .../text_to_image/train_text_to_image.py | 21 +++++-------------- examples/text_to_image/train_text_to_image.py | 20 ++++-------------- .../text_to_image/train_text_to_image_lora.py | 20 ++++-------------- .../train_text_to_image_lora_sdxl.py | 20 ++++-------------- .../text_to_image/train_text_to_image_sdxl.py | 20 ++++-------------- 10 files changed, 40 insertions(+), 154 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 34e8c69ff64b..68162d7824ab 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -907,17 +907,10 @@ def compute_loss(params, minibatch, sample_rng): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) - base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr if noise_scheduler.config.prediction_type == "v_prediction": - snr_loss_weights = base_weights + 1 - else: - # Epsilon and sample prediction use the base weights. - snr_loss_weights = base_weights - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - snr_loss_weights[snr == 0] = 1.0 - + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index dd79c88f8a76..e769d2b28057 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -777,25 +777,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. 
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 5e5f4b9cbf5d..075fe8897864 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -631,25 +631,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index d2aabb948969..8d32dc45c3b9 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -661,25 +661,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. 
- mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index b86df4de600c..6ab402f0f0e9 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -809,25 +809,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 15c17063bd68..f7100788cde2 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -848,24 +848,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. 
+ snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample prediction use the base weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 535942629314..ff37d03ef0c5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -928,25 +928,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index d4d13a144f38..d44f41d8e0fd 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -760,25 +760,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. 
snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 991f8a84a243..261d024dbbc0 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1049,25 +1049,13 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. 
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 649b82ed3baa..0d5ed60f8b6a 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1065,25 +1065,13 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. - # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. - # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. - mse_loss_weights[snr == 0] = 1.0 - - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean()
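
Note (not part of the patch): every hunk above applies the same fix, so the resulting weighting can be summarized in one place. Below is a minimal, self-contained PyTorch sketch of the corrected Min-SNR gamma weighting; the helper name min_snr_loss_weights is hypothetical, and snr is assumed to be the per-timestep SNR tensor returned by diffusers' compute_snr(noise_scheduler, timesteps), as in the scripts.

    import torch

    def min_snr_loss_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
        # Min-SNR gamma weighting: min(SNR, gamma) / SNR, per the paper referenced in the scripts' comments.
        if prediction_type == "v_prediction":
            # The velocity objective requires adding one to the SNR before the clamp and the division.
            # A side effect: a zero-terminal-SNR timestep (snr == 0) now gets the finite weight
            # min(1, gamma) instead of the 0/0 the unshifted formula would produce.
            snr = snr + 1
        # The patched scripts build the elementwise minimum via torch.stack(...).min(dim=1)[0];
        # torch.minimum is used here for brevity and gives the same result.
        return torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr

Usage mirrors the patched training loops (F is torch.nn.functional, as imported in the scripts):

    snr = compute_snr(noise_scheduler, timesteps)
    mse_loss_weights = min_snr_loss_weights(snr, args.snr_gamma, noise_scheduler.config.prediction_type)
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = (loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights).mean()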