diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 34e8c69ff64b..68162d7824ab 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -907,17 +907,10 @@ def compute_loss(params, minibatch, sample_rng):
 
             if args.snr_gamma is not None:
                 snr = jnp.array(compute_snr(timesteps))
-                base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    snr_loss_weights = base_weights + 1
-                else:
-                    # Epsilon and sample prediction use the base weights.
-                    snr_loss_weights = base_weights
-                # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                snr_loss_weights[snr == 0] = 1.0
-
+                    # Velocity objective requires that we add one to SNR values before we divide by them.
+                    snr = snr + 1
+                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
                 loss = loss * snr_loss_weights
 
             loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 0938d38ba487..4ca95ecebea9 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -781,25 +781,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 91091f4d80fb..19245724ecf5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -631,25 +631,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 3099c613df73..7305137218ef 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -664,25 +664,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 74fa504345fe..d21eaf3dd0b0 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -811,25 +811,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index 15c17063bd68..f7100788cde2 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -848,24 +848,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample prediction use the base weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
+
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 03efc3fa13c5..e216529b2f54 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -929,25 +929,13 @@ def collate_fn(examples):
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 6ae157bce9b3..eac0f18f49f4 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -759,25 +759,13 @@ def collate_fn(examples):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index d685d468db4d..ed7a15cd95fe 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -1050,25 +1050,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 4cf966af77cd..c681943f2e94 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -1067,25 +1067,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
 
-                    if noise_scheduler.config.prediction_type == "v_prediction":
-                        # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
-                    else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
-
-                    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-                    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-                    mse_loss_weights[snr == 0] = 1.0
-
-                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-                    # rebalance the sample-wise losses with their respective loss weights.
-                    # Finally, we take the mean of the rebalanced loss.
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
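Note (not part of the patch): every PyTorch hunk above installs the same min-SNR-gamma weighting, so a standalone sketch may be easier to read than the repeated diffs. The helper below is illustrative only; the function name and signature are invented for this note, and `snr` is assumed to be the per-timestep SNR the scripts obtain from `compute_snr(noise_scheduler, timesteps)`.

```python
# Illustrative sketch only -- not a function from the patch or from diffusers.
import torch
import torch.nn.functional as F


def min_snr_weighted_mse(model_pred, target, snr, snr_gamma, prediction_type):
    # `snr` has shape (batch,), one value per sampled timestep.
    if prediction_type == "v_prediction":
        # Velocity objective: add one to the SNR values before dividing by them.
        snr = snr + 1
    # Per-sample weight: min(snr, snr_gamma) / snr.
    mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    # Unreduced MSE, averaged over non-batch dimensions, rebalanced per sample, then averaged.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
    return loss.mean()
```

For the velocity objective the divisor becomes `snr + 1`, so a zero-SNR terminal timestep no longer divides by zero in that branch, which is why the removed `mse_loss_weights[snr == 0] = 1.0` guard is no longer needed there.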