Skip to content

Commit

Permalink
Add min snr to text2img lora training script (#3459)
Browse files Browse the repository at this point in the history
add min snr to text2img lora training script
  • Loading branch information
wfng92 committed May 17, 2023
1 parent bd78f63 commit 2faf91d
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ def parse_args():
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
Expand Down Expand Up @@ -472,6 +479,30 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def compute_snr(timesteps):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

# Compute SNR.
snr = (alpha / sigma) ** 2
return snr

lora_layers = AttnProcsLayers(unet.attn_processors)

# Enable TF32 for faster training on Ampere GPUs,
Expand Down Expand Up @@ -727,7 +758,23 @@ def collate_fn(examples):

# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()

# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
Expand Down

0 comments on commit 2faf91d

Please sign in to comment.