diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 82370fc4e2dd..bbf7bf9b85bb 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -115,7 +115,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( - "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1." + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." ) parser.add_argument( "--pretrained_model_name_or_path", @@ -830,8 +830,8 @@ def collate_fn(examples): noise += args.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - if args.input_pertubation: - new_noise = noise + args.input_pertubation * torch.randn_like(noise) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -839,7 +839,7 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - if args.input_pertubation: + if args.input_perturbation: noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)