From ea3b2d14d97ab70078e80492c541bc06bdbbefe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Mon, 19 Aug 2024 20:12:32 +0800 Subject: [PATCH 1/3] Fix dtype error --- examples/text_to_image/train_text_to_image_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 7f4917b5464c..2ca511c857ae 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1084,7 +1084,7 @@ def unwrap_model(model): # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype) # time ids def compute_time_ids(original_size, crops_coords_top_left): @@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Predict the noise residual unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds = batch["prompt_embeds"].to(accelerator.device) + prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype) pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet( From 449aded49d15f445074cb07fe260d5c6e172aac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Thu, 12 Sep 2024 10:54:01 +0800 Subject: [PATCH 2/3] [bugfix] Fixed the issue on sd3 dreambooth training --- examples/dreambooth/train_dreambooth_flux.py | 448 ++++++++++---- examples/dreambooth/train_dreambooth_lora.py | 393 +++++++++--- .../dreambooth/train_dreambooth_lora_flux.py | 424 ++++++++++--- .../dreambooth/train_dreambooth_lora_sd3.py | 443 +++++++++++--- .../dreambooth/train_dreambooth_lora_sdxl.py | 575 ++++++++++++++---- examples/dreambooth/train_dreambooth_sd3.py | 414 ++++++++++--- 6 files changed, 2076 insertions(+), 621 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index da571cc46c57..a47abfaa74b8 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -32,7 +32,11 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from PIL import Image @@ -41,7 +45,13 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + PretrainedConfig, + T5EncoderModel, + T5TokenizerFast, +) import diffusers from diffusers import ( @@ -51,15 +61,14 @@ FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 -from diffusers.utils import ( - check_min_version, - is_wandb_available, +from 
diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, ) +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -83,7 +92,10 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"image_{i}.png"}, + } ) model_description = f""" @@ -140,10 +152,16 @@ def save_model_card( def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) return text_encoder_one, text_encoder_two @@ -154,22 +172,30 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -180,7 +206,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -278,7 +305,12 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ) parser.add_argument( "--class_data_dir", @@ -333,7 +365,12 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + 
help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -349,7 +386,9 @@ def parse_args(input_args=None): default="flux-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -379,10 +418,16 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -463,7 +508,10 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -471,7 +519,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -485,13 +538,21 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + help=( + 'We default to the "none" weighting scheme for uniform sampling and uniform loss' + ), ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( "--mode_scale", @@ -513,10 +574,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." 
+ "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( "--prodigy_beta3", @@ -525,10 +592,23 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for unet params", + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", + type=float, + default=1e-03, + help="Weight decay to use for text_encoder", ) parser.add_argument( @@ -551,9 +631,20 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -607,7 +698,12 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -618,7 +714,9 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + raise ValueError( + "Specify only one of `--dataset_name` or `--instance_data_dir`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -632,9 +730,13 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." 
+ ) return args @@ -713,13 +815,17 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + self.custom_instance_prompts.extend( + itertools.repeat(caption, repeats) + ) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [ + Image.open(path) for path in list(Path(instance_data_root).iterdir()) + ] self.custom_instance_prompts = None self.instance_images = [] @@ -727,8 +833,12 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + train_crop = ( + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -749,7 +859,9 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution) + ) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -771,8 +883,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -797,7 +915,9 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -881,7 +1001,9 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -922,7 +1044,9 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) @@ -969,7 +1093,9 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + 
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to( + device=device, dtype=dtype + ) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -990,7 +1116,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1006,7 +1134,9 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1034,8 +1164,12 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() + ) + torch_dtype = ( + torch.float16 if has_supported_fp16_accelerator else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1054,19 +1188,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -1109,7 +1250,9 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1117,7 +1260,10 @@ def main(args): variant=args.variant, ) transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="transformer", + 
revision=args.revision, + variant=args.variant, ) transformer.requires_grad_(True) @@ -1163,12 +1309,20 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): if isinstance(unwrap_model(model), FluxTransformer2DModel): - unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) - elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "transformer") + ) + elif isinstance( + unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) + ): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "text_encoder") + ) else: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "text_encoder_2") + ) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") @@ -1182,22 +1336,32 @@ def load_model_hook(models, input_dir): # load diffusers style into model if isinstance(unwrap_model(model), FluxTransformer2DModel): - load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") + load_model = FluxTransformer2DModel.from_pretrained( + input_dir, subfolder="transformer" + ) model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) - elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + elif isinstance( + unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) + ): try: - load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") + load_model = CLIPTextModelWithProjection.from_pretrained( + input_dir, subfolder="text_encoder" + ) model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2") + load_model = T5EncoderModel.from_pretrained( + input_dir, subfolder="text_encoder_2" + ) model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - raise ValueError(f"Couldn't load the model of type: ({type(model)}).") + raise ValueError( + f"Couldn't load the model of type: ({type(model)})." 
+ ) else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1213,11 +1377,17 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Optimization parameters - transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} + transformer_parameters_with_lr = { + "params": transformer.parameters(), + "lr": args.learning_rate, + } if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1270,7 +1440,9 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -1339,15 +1511,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) + ( + instance_prompt_hidden_states, + instance_pooled_prompt_embeds, + instance_text_ids, + ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = ( + compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) ) # Clear the memory here @@ -1369,23 +1543,37 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states], dim=0 + ) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 + ) text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) - tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512) + tokens_one = tokenize_prompt( + tokenizer_one, args.instance_prompt, max_sequence_length=77 + ) + tokens_two = tokenize_prompt( + tokenizer_two, args.instance_prompt, max_sequence_length=512 + ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, 
max_sequence_length=512) + class_tokens_one = tokenize_prompt( + tokenizer_one, args.class_prompt, max_sequence_length=77 + ) + class_tokens_two = tokenize_prompt( + tokenizer_two, args.class_prompt, max_sequence_length=512 + ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1420,7 +1608,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1433,14 +1623,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1509,13 +1705,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers + prompt_embeds, pooled_prompt_embeds, text_ids = ( + compute_text_embeddings(prompts, text_encoders, tokenizers) ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_one = tokenize_prompt( + tokenizer_one, prompts, max_sequence_length=77 + ) tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length + tokenizer_two, + prompts, + max_sequence_length=args.max_sequence_length, ) prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], @@ -1536,7 +1736,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = ( + model_input - vae.config.shift_factor + ) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) @@ -1563,11 +1765,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + timesteps = noise_scheduler_copy.timesteps[indices].to( + device=model_input.device + ) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + sigmas = get_sigmas( + timesteps, n_dim=model_input.ndim, dtype=model_input.dtype + ) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise packed_noisy_model_input = FluxPipeline._pack_latents( @@ -1580,7 +1786,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # handle guidance if transformer.config.guidance_embeds: - guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = torch.tensor( + [args.guidance_scale], device=accelerator.device + ) guidance = guidance.expand(model_input.shape[0]) else: guidance = None @@ -1607,7 +1815,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) # flow matching loss target = noise - model_input @@ -1619,16 +1829,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), + ( + weighting.float() + * (model_pred_prior.float() - target_prior.float()) ** 2 + ).reshape(target_prior.shape[0], -1), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. 
loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + ( + weighting.float() * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1640,7 +1853,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + itertools.chain( + transformer.parameters(), text_encoder_one.parameters() + ) if args.train_text_encoder else transformer.parameters() ) @@ -1660,24 +1875,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1689,10 +1916,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) else: # even when training the text encoder we're only training text encoder one text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, @@ -1717,6 +1949,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two @@ -1736,7 +1969,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder=text_encoder_one, ) else: - pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, 
transformer=transformer + ) # save the pipeline pipeline.save_pretrained(args.output_dir) @@ -1761,6 +1996,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 331b2d6ab611..e0fb47dcae30 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -54,7 +54,10 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, +) from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, @@ -65,7 +68,6 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -122,6 +124,7 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( @@ -139,13 +142,19 @@ def log_validation( scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) if args.validation_images is None: images = [] @@ -158,7 +167,9 @@ def log_validation( for image in args.validation_images: image = Image.open(image) with torch.cuda.amp.autocast(): - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + image = pipeline( + **pipeline_args, image=image, generator=generator + ).images[0] images.append(image) for tracker in accelerator.trackers: @@ -170,7 +181,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -181,7 +193,9 @@ def log_validation( return images -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str +): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -194,7 +208,9 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": - from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesModelWithTransformation, + ) return RobertaSeriesModelWithTransformation elif model_class == "T5EncoderModel": @@ -287,7 +303,12 @@ def parse_args(input_args=None): 
action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -303,7 +324,9 @@ def parse_args(input_args=None): default="lora-dreambooth-model", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -328,10 +351,16 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -398,7 +427,10 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -406,7 +438,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -416,15 +453,45 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
- ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -478,9 +545,16 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", ) parser.add_argument( "--pre_compute_text_embeddings", @@ -537,12 +611,18 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." 
+ ) if args.train_text_encoder and args.pre_compute_text_embeddings: - raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + raise ValueError( + "`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`" + ) return args @@ -598,8 +678,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -610,7 +696,9 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = Image.open( + self.instance_images_path[index % self.num_instance_images] + ) instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": @@ -621,13 +709,17 @@ def __getitem__(self, index): example["instance_prompt_ids"] = self.encoder_hidden_states else: text_inputs = tokenize_prompt( - self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + self.tokenizer, + self.instance_prompt, + tokenizer_max_length=self.tokenizer_max_length, ) example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -638,7 +730,9 @@ def __getitem__(self, index): example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states else: class_text_inputs = tokenize_prompt( - self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + self.tokenizer, + self.class_prompt, + tokenizer_max_length=self.tokenizer_max_length, ) example["class_prompt_ids"] = class_text_inputs.input_ids example["class_attention_mask"] = class_text_inputs.attention_mask @@ -713,7 +807,9 @@ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): return text_inputs -def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): +def encode_prompt( + text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None +): text_input_ids = input_ids.to(text_encoder.device) if text_encoder_use_attention_mask: @@ -740,7 +836,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -755,12 +853,18 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. 
For now, we don't allow gradient accumulation when training two models. # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + if ( + args.train_text_encoder + and args.gradient_accumulation_steps > 1 + and accelerator.num_processes > 1 + ): raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." @@ -792,7 +896,9 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = ( + torch.float16 if accelerator.device.type == "cuda" else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -812,19 +918,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -838,12 +951,16 @@ def main(args): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, ).repo_id # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, revision=args.revision, use_fast=False + ) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, @@ -853,16 +970,26 @@ def main(args): ) # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + text_encoder_cls = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + 
args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) try: vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, ) except OSError: # IF does not have a VAE so let's just set it to None @@ -870,7 +997,10 @@ def main(args): vae = None unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, ) # We only train the additional adapter LoRA layers @@ -904,7 +1034,9 @@ def main(args): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError("xformers is not available. Make sure it is installed correctly") + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -945,7 +1077,9 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + unet_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(unwrap_model(text_encoder))): text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -976,11 +1110,19 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) + lora_state_dict, network_alphas = ( + StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) + ) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = { + f'{k.replace("unet.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("unet.") + } unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + incompatible_keys = set_peft_model_state_dict( + unet_, unet_state_dict, adapter_name="default" + ) if incompatible_keys is not None: # check only for unexpected keys @@ -992,7 +1134,9 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_ + ) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: @@ -1015,7 +1159,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Make sure the trainable params are in float32. 
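
The hunks above, like the casts in PATCH 1/3, guard against the same failure mode: under --mixed_precision the frozen base weights sit in fp16/bf16, so tensors fed to them and the parameters the optimizer updates must be cast explicitly. A minimal sketch of that upcasting step, assuming a generic torch.nn.Module whose only trainable parameters are the LoRA layers (the helper name is hypothetical; in these scripts the cast_training_params import from diffusers.training_utils plays this role):

import torch
import torch.nn as nn

def upcast_trainable_params(model: nn.Module, dtype: torch.dtype = torch.float32) -> None:
    # Leave frozen base weights in half precision, but keep the trainable
    # (e.g. LoRA) parameters in float32 so optimizer updates do not underflow.
    for param in model.parameters():
        if param.requires_grad:
            # Reassigning .data keeps the Parameter object (and optimizer references) intact.
            param.data = param.data.to(dtype)
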
@@ -1043,7 +1190,9 @@ def load_model_hook(models, input_dir): # Optimizer creation params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: - params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + params_to_optimize = params_to_optimize + list( + filter(lambda p: p.requires_grad, text_encoder.parameters()) + ) optimizer = optimizer_class( params_to_optimize, @@ -1057,7 +1206,9 @@ def load_model_hook(models, input_dir): def compute_text_embeddings(prompt): with torch.no_grad(): - text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + text_inputs = tokenize_prompt( + tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length + ) prompt_embeds = encode_prompt( text_encoder, text_inputs.input_ids, @@ -1067,16 +1218,22 @@ def compute_text_embeddings(prompt): return prompt_embeds - pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + pre_computed_encoder_hidden_states = compute_text_embeddings( + args.instance_prompt + ) validation_prompt_negative_prompt_embeds = compute_text_embeddings("") if args.validation_prompt is not None: - validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + validation_prompt_encoder_hidden_states = compute_text_embeddings( + args.validation_prompt + ) else: validation_prompt_encoder_hidden_states = None if args.class_prompt is not None: - pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings( + args.class_prompt + ) else: pre_computed_class_prompt_encoder_hidden_states = None @@ -1116,7 +1273,9 @@ def compute_text_embeddings(prompt): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1132,8 +1291,10 @@ def compute_text_embeddings(prompt): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = ( + accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -1141,7 +1302,9 @@ def compute_text_embeddings(prompt): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1155,14 +1318,20 @@ def compute_text_embeddings(prompt): accelerator.init_trackers("dreambooth-lora", config=tracker_config) # Train! 
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1223,13 +1392,18 @@ def compute_text_embeddings(prompt): bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=model_input.device, ) timesteps = timesteps.long() # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise( + model_input, noise, timesteps + ) # Get the text embedding for conditioning if args.pre_compute_text_embeddings: @@ -1243,7 +1417,9 @@ def compute_text_embeddings(prompt): ) if unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + noisy_model_input = torch.cat( + [noisy_model_input, noisy_model_input], dim=1 + ) if args.class_labels_conditioning == "timesteps": class_labels = timesteps @@ -1271,7 +1447,9 @@ def compute_text_embeddings(prompt): elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. @@ -1279,15 +1457,21 @@ def compute_text_embeddings(prompt): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + prior_loss = F.mse_loss( + model_pred_prior.float(), target_prior.float(), reduction="mean" + ) # Add the prior loss to the instance loss. 
loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) accelerator.backward(loss) if accelerator.sync_gradients: @@ -1306,24 +1490,36 @@ def compute_text_embeddings(prompt): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1335,12 +1531,19 @@ def compute_text_embeddings(prompt): break if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=unwrap_model(unet), - text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder), + text_encoder=( + None + if args.pre_compute_text_embeddings + else unwrap_model(text_encoder) + ), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -1360,6 +1563,7 @@ def compute_text_embeddings(prompt): accelerator, pipeline_args, epoch, + torch_dtype=weight_dtype, ) # Save the lora layers @@ -1368,11 +1572,15 @@ def compute_text_embeddings(prompt): unet = unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + unet_lora_state_dict = convert_state_dict_to_diffusers( + get_peft_model_state_dict(unet) + ) if args.train_text_encoder: text_encoder = unwrap_model(text_encoder) - text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder)) + text_encoder_state_dict = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder) + ) else: text_encoder_state_dict = None @@ -1385,16 +1593,24 @@ def compute_text_embeddings(prompt): # Final inference # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, 
variant=args.variant, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) # load attention processors - pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") + pipeline.load_lora_weights( + args.output_dir, weight_name="pytorch_lora_weights.safetensors" + ) # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} + pipeline_args = { + "prompt": args.validation_prompt, + "num_inference_steps": 25, + } images = log_validation( pipeline, args, @@ -1402,6 +1618,7 @@ def compute_text_embeddings(prompt): pipeline_args, epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 48d669418fd8..9c60739f2aec 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -32,7 +32,11 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from peft import LoraConfig, set_peft_model_state_dict @@ -67,7 +71,6 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -91,7 +94,10 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"image_{i}.png"}, + } ) model_description = f""" @@ -156,10 +162,16 @@ def save_model_card( def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) return text_encoder_one, text_encoder_two @@ -170,22 +182,30 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -196,7 +216,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -294,7 +315,12 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ) parser.add_argument( "--class_data_dir", @@ -355,7 +381,12 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -371,7 +402,9 @@ def parse_args(input_args=None): default="flux-dreambooth-lora", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -401,10 +434,16 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -485,7 +524,10 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -493,7 +535,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -507,13 +554,21 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + help=( + 'We default to the "none" weighting scheme for uniform sampling and uniform loss' + ), ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( "--mode_scale", @@ -535,10 +590,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( "--prodigy_beta3", @@ -547,10 +608,23 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for unet params", + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", + type=float, + default=1e-03, + help="Weight decay to use for text_encoder", ) parser.add_argument( @@ -573,9 +647,20 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -629,7 +714,12 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -640,7 +730,9 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + raise ValueError( + "Specify only one of `--dataset_name` or `--instance_data_dir`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -654,9 +746,13 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." 
+ ) return args @@ -735,13 +831,17 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + self.custom_instance_prompts.extend( + itertools.repeat(caption, repeats) + ) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [ + Image.open(path) for path in list(Path(instance_data_root).iterdir()) + ] self.custom_instance_prompts = None self.instance_images = [] @@ -749,8 +849,12 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + train_crop = ( + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -771,7 +875,9 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution) + ) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -793,8 +899,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -819,7 +931,9 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -903,7 +1017,9 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -944,7 +1060,9 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) @@ -991,7 +1109,9 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + 
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to( + device=device, dtype=dtype + ) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -1012,7 +1132,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1028,7 +1150,9 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1056,8 +1180,12 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() + ) + torch_dtype = ( + torch.float16 if has_supported_fp16_accelerator else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1076,19 +1204,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -1131,7 +1266,9 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1139,7 +1276,10 @@ def main(args): variant=args.variant, ) transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="transformer", + 
revision=args.revision, + variant=args.variant, ) # We only train the additional adapter LoRA layers @@ -1204,7 +1344,9 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict( + model + ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1234,10 +1376,14 @@ def load_model_hook(models, input_dir): lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f'{k.replace("transformer.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + incompatible_keys = set_peft_model_state_dict( + transformer_, transformer_state_dict, adapter_name="default" + ) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1248,7 +1394,9 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ + ) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: @@ -1270,7 +1418,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Make sure the trainable params are in float32. 
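Condensed view of what the reformatted load_model_hook above does when it restores a Flux LoRA checkpoint: keep only the transformer-scoped entries, drop the "transformer." namespace, convert the diffusers-style keys to the PEFT layout, and push them into the wrapped model. The helper name below is hypothetical; it only bundles the calls that already appear in the hook:

    from diffusers import FluxPipeline
    from diffusers.utils import convert_unet_state_dict_to_peft
    from peft import set_peft_model_state_dict

    def load_transformer_lora(transformer_, input_dir):
        # full serialized LoRA state dict, keys namespaced per pipeline component
        lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
        # keep transformer keys only and strip the namespace prefix
        transformer_state_dict = {
            k.replace("transformer.", ""): v
            for k, v in lora_state_dict.items()
            if k.startswith("transformer.")
        }
        # diffusers-style LoRA keys -> the layout expected by the PEFT adapter
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        return set_peft_model_state_dict(
            transformer_, transformer_state_dict, adapter_name="default"
        )

Anything reported back in unexpected_keys points at a checkpoint that does not match the configured LoRA layers, which is exactly the condition the hook warns about.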
@@ -1281,12 +1432,19 @@ def load_model_hook(models, input_dir): # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + transformer_lora_parameters = list( + filter(lambda p: p.requires_grad, transformer.parameters()) + ) if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_one = list( + filter(lambda p: p.requires_grad, text_encoder_one.parameters()) + ) # Optimization parameters - transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1339,7 +1497,9 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -1408,15 +1568,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) + ( + instance_prompt_hidden_states, + instance_pooled_prompt_embeds, + instance_text_ids, + ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = ( + compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) ) # Clear the memory here @@ -1438,27 +1600,41 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states], dim=0 + ) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 + ) text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_one = tokenize_prompt( + tokenizer_one, args.instance_prompt, max_sequence_length=77 + ) tokens_two = tokenize_prompt( - tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length + tokenizer_two, + args.instance_prompt, + max_sequence_length=args.max_sequence_length, ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_one = tokenize_prompt( + tokenizer_one, args.class_prompt, max_sequence_length=77 + ) class_tokens_two = tokenize_prompt( - tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length + tokenizer_two, + args.class_prompt, + max_sequence_length=args.max_sequence_length, ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1493,7 +1669,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1506,14 +1684,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! 
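The concatenation of instance and class conditioning above is one half of prior preservation; the other half, later in the training loop, splits the prediction and target back into those two groups and weights the class ("prior") term separately. A self-contained toy illustration of that split, using plain MSE as in the UNet script (the flow-matching scripts apply a per-sample weighting instead, and the tensors here are random stand-ins):

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0
    # first two rows stand in for instance samples, last two for class samples
    model_pred = torch.randn(4, 16)
    target = torch.randn(4, 16)

    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + prior_loss_weight * prior_loss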
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1571,7 +1755,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model( + text_encoder_one + ).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] @@ -1584,13 +1770,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers + prompt_embeds, pooled_prompt_embeds, text_ids = ( + compute_text_embeddings(prompts, text_encoders, tokenizers) ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_one = tokenize_prompt( + tokenizer_one, prompts, max_sequence_length=77 + ) tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length + tokenizer_two, + prompts, + max_sequence_length=args.max_sequence_length, ) prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], @@ -1613,7 +1803,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = ( + model_input - vae.config.shift_factor + ) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) @@ -1639,11 +1831,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + timesteps = noise_scheduler_copy.timesteps[indices].to( + device=model_input.device + ) # Add noise according to flow matching. 
# zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + sigmas = get_sigmas( + timesteps, n_dim=model_input.ndim, dtype=model_input.dtype + ) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise packed_noisy_model_input = FluxPipeline._pack_latents( @@ -1656,7 +1852,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # handle guidance if transformer.config.guidance_embeds: - guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = torch.tensor( + [args.guidance_scale], device=accelerator.device + ) guidance = guidance.expand(model_input.shape[0]) else: guidance = None @@ -1682,7 +1880,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) # flow matching loss target = noise - model_input @@ -1694,16 +1894,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), + ( + weighting.float() + * (model_pred_prior.float() - target_prior.float()) ** 2 + ).reshape(target_prior.shape[0], -1), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + ( + weighting.float() * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1715,7 +1918,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + itertools.chain( + transformer.parameters(), text_encoder_one.parameters() + ) if args.train_text_encoder else transformer.parameters() ) @@ -1735,24 +1940,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) 
shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1764,10 +1981,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1785,6 +2007,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two @@ -1800,7 +2023,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = get_peft_model_state_dict( + text_encoder_one.to(torch.float32) + ) else: text_encoder_lora_layers = None @@ -1832,6 +2057,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 17e6e107b079..061226b9b57a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -31,7 +31,11 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from peft import LoraConfig, set_peft_model_state_dict @@ -67,7 +71,6 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -91,7 +94,10 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"image_{i}.png"}, + } ) model_description = f""" @@ -162,13 +168,22 @@ def save_model_card( def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", 
revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) text_encoder_three = class_three.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_3", + revision=args.revision, + variant=args.variant, ) return text_encoder_one, text_encoder_two, text_encoder_three @@ -179,22 +194,30 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -205,7 +228,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -301,7 +325,12 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ) parser.add_argument( "--class_data_dir", @@ -362,7 +391,12 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -378,7 +412,9 @@ def parse_args(input_args=None): default="sd3-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -409,10 +445,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 
+ "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -486,7 +528,10 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -494,7 +539,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -510,10 +560,16 @@ def parse_args(input_args=None): choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( "--mode_scale", @@ -542,10 +598,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( "--prodigy_beta3", @@ -554,10 +616,23 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. 
Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for unet params", + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", + type=float, + default=1e-03, + help="Weight decay to use for text_encoder", ) parser.add_argument( @@ -580,9 +655,20 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -636,7 +722,12 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -647,7 +738,9 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + raise ValueError( + "Specify only one of `--dataset_name` or `--instance_data_dir`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -661,9 +754,13 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." 
+ ) return args @@ -742,13 +839,17 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + self.custom_instance_prompts.extend( + itertools.repeat(caption, repeats) + ) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [ + Image.open(path) for path in list(Path(instance_data_root).iterdir()) + ] self.custom_instance_prompts = None self.instance_images = [] @@ -756,8 +857,12 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + train_crop = ( + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -778,7 +883,9 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution) + ) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -800,8 +907,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -826,7 +939,9 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -907,7 +1022,9 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -946,7 +1063,9 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + raise ValueError( + "text_input_ids must be provided when the tokenizer is not specified" + ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) @@ -978,7 +1097,9 @@ def encode_prompt( clip_prompt_embeds_list = [] clip_pooled_prompt_embeds_list = [] - for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): + for i, 
(tokenizer, text_encoder) in enumerate( + zip(clip_tokenizers, clip_text_encoders) + ): prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoder, tokenizer=tokenizer, @@ -1004,7 +1125,8 @@ def encode_prompt( ) clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) @@ -1026,7 +1148,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1042,7 +1166,9 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1070,8 +1196,12 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() + ) + torch_dtype = ( + torch.float16 if has_supported_fp16_accelerator else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1090,19 +1220,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) clear_objs_and_retain_memory(objs=[pipeline]) @@ -1161,7 +1298,10 @@ def main(args): variant=args.variant, ) transformer = SD3Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, ) transformer.requires_grad_(False) @@ -1231,9 
+1371,13 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict( + model + ) elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict( + model + ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1267,10 +1411,14 @@ def load_model_hook(models, input_dir): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + f'{k.replace("transformer.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + incompatible_keys = set_peft_model_state_dict( + transformer_, transformer_state_dict, adapter_name="default" + ) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1281,10 +1429,14 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ + ) _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + lora_state_dict, + prefix="text_encoder_2.", + text_encoder=text_encoder_two_, ) # Make sure the trainable params are in float32. This is again needed since the base models @@ -1307,7 +1459,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Make sure the trainable params are in float32. 
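The `--scale_lr` branch above multiplies the base learning rate by the gradient-accumulation steps, the per-device batch size, and the number of processes, i.e. it scales linearly with the effective batch size. A minimal, self-contained sketch of that rule (the concrete numbers are hypothetical, not taken from this PR):

def scaled_lr(base_lr: float, grad_accum_steps: int, per_device_batch: int, num_processes: int) -> float:
    # Linear scaling: the step size grows with the effective batch size
    # (per-device batch * accumulation steps * number of processes).
    return base_lr * grad_accum_steps * per_device_batch * num_processes


if __name__ == "__main__":
    # Example: base lr 1e-4, 4 accumulation steps, batch 2 per device, 8 processes
    # -> effective batch 64 and a scaled lr of 6.4e-3.
    print(scaled_lr(1e-4, 4, 2, 8))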
@@ -1318,13 +1473,22 @@ def load_model_hook(models, input_dir): # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + transformer_lora_parameters = list( + filter(lambda p: p.requires_grad, transformer.parameters()) + ) if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) - text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + text_lora_parameters_one = list( + filter(lambda p: p.requires_grad, text_encoder_one.parameters()) + ) + text_lora_parameters_two = list( + filter(lambda p: p.requires_grad, text_encoder_two.parameters()) + ) # Optimization parameters - transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } if args.train_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { @@ -1383,7 +1547,9 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -1438,22 +1604,28 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): return prompt_embeds, pooled_prompt_embeds if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers + instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( + compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers + class_prompt_hidden_states, class_pooled_prompt_embeds = ( + compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) ) # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection clear_objs_and_retain_memory( - objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] + objs=[ + tokenizers, + text_encoders, + text_encoder_one, + text_encoder_two, + text_encoder_three, + ] ) # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images), @@ -1465,8 +1637,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states], dim=0 + ) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 + ) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1483,7 +1659,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1508,7 +1686,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): train_dataloader, lr_scheduler, ) = accelerator.prepare( - transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, ) assert text_encoder_one is not None assert text_encoder_two is not None @@ -1519,7 +1702,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1532,14 +1717,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1599,8 +1790,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) - accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model( + text_encoder_one + ).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model( + text_encoder_two + ).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] @@ -1619,7 +1814,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_two = tokenize_prompt(tokenizer_two, prompts) tokens_three = tokenize_prompt(tokenizer_three, prompts) prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], + text_encoders=[ + text_encoder_one, + text_encoder_two, + text_encoder_three, + ], tokenizers=[None, None, None], prompt=prompts, max_sequence_length=args.max_sequence_length, @@ -1628,7 +1827,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: if args.train_text_encoder: prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], + text_encoders=[ + text_encoder_one, + text_encoder_two, + text_encoder_three, + ], tokenizers=[None, None, tokenizer_three], prompt=args.instance_prompt, max_sequence_length=args.max_sequence_length, @@ -1637,7 +1840,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = ( + model_input - vae.config.shift_factor + ) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1654,11 +1859,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + timesteps = noise_scheduler_copy.timesteps[indices].to( + device=model_input.device + ) # Add noise according to flow matching. 
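The lines that follow build the forward flow-matching process referenced in this comment: a clean latent x is blended with Gaussian noise z1, using the per-sample noise level as the interpolation weight. A standalone sketch, under the assumption that the sampled sigma plays the role of t_exp and has already been broadcast to the latent's shape:

import torch

def flow_matching_interpolate(latents: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # zt = (1 - sigma) * x + sigma * z1: at sigma=0 the sample is the clean
    # latent, at sigma=1 it is pure noise.
    return (1.0 - sigmas) * latents + sigmas * noise


if __name__ == "__main__":
    x = torch.randn(2, 16, 64, 64)            # stand-in for VAE latents
    z1 = torch.randn_like(x)                  # Gaussian noise
    sigmas = torch.rand(2).view(-1, 1, 1, 1)  # one noise level per sample
    print(flow_matching_interpolate(x, z1, sigmas).shape)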
# zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + sigmas = get_sigmas( + timesteps, n_dim=model_input.ndim, dtype=model_input.dtype + ) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual @@ -1677,7 +1886,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) # flow matching loss if args.precondition_outputs: @@ -1692,16 +1903,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), + ( + weighting.float() + * (model_pred_prior.float() - target_prior.float()) ** 2 + ).reshape(target_prior.shape[0], -1), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + ( + weighting.float() * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1714,7 +1928,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain( - transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two + transformer_lora_parameters, + text_lora_parameters_one, + text_lora_parameters_two, ) if args.train_text_encoder else transformer_lora_parameters @@ -1735,24 +1951,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1764,11 +1992,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if 
args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): if not args.train_text_encoder: # create pipeline - text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three + text_encoder_one, text_encoder_two, text_encoder_three = ( + load_text_encoders( + text_encoder_cls_one, + text_encoder_cls_two, + text_encoder_cls_three, + ) ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1788,10 +2023,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) objs = [] if not args.train_text_encoder: - objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) + objs.extend( + [text_encoder_one, text_encoder_two, text_encoder_three] + ) clear_objs_and_retain_memory(objs=objs) @@ -1804,9 +2042,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = get_peft_model_state_dict( + text_encoder_one.to(torch.float32) + ) text_encoder_two = unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_2_lora_layers = get_peft_model_state_dict( + text_encoder_two.to(torch.float32) + ) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None @@ -1840,6 +2082,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 17cc00db9525..05218041750c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -33,7 +33,11 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) from huggingface_hub import create_repo, hf_hub_download, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version @@ -60,7 +64,11 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_snr, +) from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, @@ -73,7 +81,6 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -89,7 +96,9 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): model_index = os.path.join(pretrained_model_name_or_path, model_index_filename) else: model_index = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision + 
repo_id=pretrained_model_name_or_path, + filename=model_index_filename, + revision=revision, ) with open(model_index, "r") as f: @@ -113,7 +122,10 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"image_{i}.png"}, + } ) model_description = f""" @@ -151,7 +163,11 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++" if "playground" not in base_model else "playground-v2dot5-community", + license=( + "openrail++" + if "playground" not in base_model + else "playground-v2dot5-community" + ), base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -180,6 +196,7 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( @@ -199,22 +216,34 @@ def log_validation( scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. 
Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + if ( + torch.backends.mps.is_available() + or "playground" in args.pretrained_model_name_or_path + ): autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type) with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -225,7 +254,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -330,7 +360,12 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ) parser.add_argument( "--class_data_dir", @@ -385,7 +420,12 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -406,7 +446,9 @@ def parse_args(input_args=None): action="store_true", help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -436,10 +478,16 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -521,7 +569,10 @@ def parse_args(input_args=None): "More details here: https://arxiv.org/abs/2303.09556.", ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -529,7 +580,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -553,10 +609,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( "--prodigy_beta3", @@ -565,10 +627,23 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for unet params", + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", + type=float, + default=1e-03, + help="Weight decay to use for text_encoder", ) parser.add_argument( @@ -591,9 +666,20 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -647,9 +733,16 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", ) parser.add_argument( "--rank", @@ -676,7 +769,9 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + raise ValueError( + "Specify only one of `--dataset_name` or `--instance_data_dir`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -690,9 +785,13 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." + ) return args @@ -771,13 +870,17 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + self.custom_instance_prompts.extend( + itertools.repeat(caption, repeats) + ) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [ + Image.open(path) for path in list(Path(instance_data_root).iterdir()) + ] self.custom_instance_prompts = None self.instance_images = [] @@ -788,8 +891,12 @@ def __init__( self.original_sizes = [] self.crop_top_lefts = [] self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + train_crop = ( + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -811,7 +918,9 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution) + ) image = crop(image, y1, x1, h, w) crop_top_left = (y1, x1) self.crop_top_lefts.append(crop_top_left) @@ -835,8 +944,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -865,7 +980,9 @@ def __getitem__(self, index): 
example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -944,7 +1061,9 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + return_dict=False, ) # We are only ALWAYS interested in the pooled output of the final text encoder @@ -967,7 +1086,9 @@ def main(args): ) if args.do_edm_style_training and args.snr_gamma is not None: - raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") + raise ValueError( + "Min-SNR formulation is not supported when conducting EDM-style training." + ) if torch.backends.mps.is_available() and args.mixed_precision == "bf16": # due to pytorch#99272, MPS does not yet support bfloat16. @@ -977,7 +1098,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -993,7 +1116,9 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Make one log on every process with the configuration for debugging. 
logging.basicConfig( @@ -1021,8 +1146,12 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() + ) + torch_dtype = ( + torch.float16 if has_supported_fp16_accelerator else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1041,19 +1170,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -1067,7 +1203,9 @@ def main(args): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, ).repo_id # Load the tokenizers @@ -1093,10 +1231,14 @@ def main(args): ) # Load scheduler and models - scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision) + scheduler_type = determine_scheduler_type( + args.pretrained_model_name_or_path, args.revision + ) if "EDM" in scheduler_type: args.do_edm_style_training = True - noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler = EDMEulerScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) logger.info("Performing EDM-style training!") elif args.do_edm_style_training: noise_scheduler = EulerDiscreteScheduler.from_pretrained( @@ -1104,13 +1246,21 @@ def main(args): ) logger.info("Performing EDM-style training!") else: - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + 
args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) vae_path = ( args.pretrained_model_name_or_path @@ -1130,7 +1280,10 @@ def main(args): latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, ) # We only train the additional adapter LoRA layers @@ -1174,7 +1327,9 @@ def main(args): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError("xformers is not available. Make sure it is installed correctly") + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -1221,14 +1376,20 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + unet_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = ( + convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + ) elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) + text_encoder_two_lora_layers_to_save = ( + convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1260,11 +1421,19 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) + lora_state_dict, network_alphas = ( + StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) + ) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = { + f'{k.replace("unet.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("unet.") + } unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + incompatible_keys = set_peft_model_state_dict( + unet_, unet_state_dict, adapter_name="default" + ) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1276,10 +1445,14 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ + ) _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + lora_state_dict, + prefix="text_encoder_2.", + text_encoder=text_encoder_two_, ) # Make sure the trainable params are in float32. 
This is again needed since the base models @@ -1302,7 +1475,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Make sure the trainable params are in float32. @@ -1317,11 +1493,18 @@ def load_model_hook(models, input_dir): unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) - text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + text_lora_parameters_one = list( + filter(lambda p: p.requires_grad, text_encoder_one.parameters()) + ) + text_lora_parameters_two = list( + filter(lambda p: p.requires_grad, text_encoder_two.parameters()) + ) # Optimization parameters - unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + unet_lora_parameters_with_lr = { + "params": unet_lora_parameters, + "lr": args.learning_rate, + } if args.train_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { @@ -1380,7 +1563,9 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -1450,7 +1635,9 @@ def compute_time_ids(original_size, crops_coords_top_left): def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders, tokenizers, prompt + ) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds @@ -1459,15 +1646,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers + instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( + compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) ) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers + class_prompt_hidden_states, class_pooled_prompt_embeds = ( + compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) ) # Clear the memory here @@ -1486,8 +1673,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states unet_add_text_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states], dim=0 + ) + unet_add_text_embeds = torch.cat( + [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0 + ) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1501,7 +1692,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1517,8 +1710,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ( + unet, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + unet, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -1526,7 +1731,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1543,14 +1750,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! 
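The bookkeeping above, and the `total_batch_size` computed just below, follow directly from the gradient-accumulation setup: one optimizer step consumes `gradient_accumulation_steps` dataloader batches on every process. A small arithmetic sketch with hypothetical numbers:

import math

# Hypothetical values, chosen only to illustrate the bookkeeping.
train_batch_size = 2            # per device
num_processes = 4               # e.g. 4 GPUs
gradient_accumulation_steps = 8
num_batches_per_epoch = 100     # len(train_dataloader) on each process

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)

print(total_batch_size)            # 64 samples contribute to each optimizer step
print(num_update_steps_per_epoch)  # 13 optimizer steps per epoch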
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1611,8 +1824,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) - accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model( + text_encoder_one + ).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model( + text_encoder_two + ).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): @@ -1637,9 +1854,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype) else: - latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) - latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) - model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std + latents_mean = latents_mean.to( + device=model_input.device, dtype=model_input.dtype + ) + latents_std = latents_std.to( + device=model_input.device, dtype=model_input.dtype + ) + model_input = ( + (model_input - latents_mean) + * vae.config.scaling_factor + / latents_std + ) model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1649,26 +1874,39 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Sample a random timestep for each image if not args.do_edm_style_training: timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=model_input.device, ) timesteps = timesteps.long() else: # in EDM formulation, the model is conditioned on the pre-conditioned noise levels # instead of discrete timesteps, so here we sample indices to get the noise levels # from `scheduler.timesteps` - indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) - timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) + indices = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,) + ) + timesteps = noise_scheduler.timesteps[indices].to( + device=model_input.device + ) # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise( + model_input, noise, timesteps + ) # For 
EDM-style training, we first obtain the sigmas based on the continuous timesteps. # We then precondition the final model inputs based on these sigmas instead of the timesteps. # Follow: Section 5 of https://arxiv.org/abs/2206.00364. if args.do_edm_style_training: - sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype) + sigmas = get_sigmas( + timesteps, len(noisy_model_input.shape), noisy_model_input.dtype + ) if "EDM" in scheduler_type: - inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas) + inp_noisy_latents = noise_scheduler.precondition_inputs( + noisy_model_input, sigmas + ) else: inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5) @@ -1676,13 +1914,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): add_time_ids = torch.cat( [ compute_time_ids(original_size=s, crops_coords_top_left=c) - for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"]) + for s, c in zip( + batch["original_sizes"], batch["crop_top_lefts"] + ) ] ) # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. if not train_dataset.custom_instance_prompts: - elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_text_embeds = ( + bsz // 2 if args.with_prior_preservation else bsz + ) else: elems_to_repeat_text_embeds = 1 @@ -1690,11 +1932,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if not args.train_text_encoder: unet_added_conditions = { "time_ids": add_time_ids, - "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), + "text_embeds": unet_add_text_embeds.repeat( + elems_to_repeat_text_embeds, 1 + ), } - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + prompt_embeds_input = prompt_embeds.repeat( + elems_to_repeat_text_embeds, 1, 1 + ) model_pred = unet( - inp_noisy_latents if args.do_edm_style_training else noisy_model_input, + ( + inp_noisy_latents + if args.do_edm_style_training + else noisy_model_input + ), timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, @@ -1709,11 +1959,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_input_ids_list=[tokens_one, tokens_two], ) unet_added_conditions.update( - {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + { + "text_embeds": pooled_prompt_embeds.repeat( + elems_to_repeat_text_embeds, 1 + ) + } + ) + prompt_embeds_input = prompt_embeds.repeat( + elems_to_repeat_text_embeds, 1, 1 ) - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - inp_noisy_latents if args.do_edm_style_training else noisy_model_input, + ( + inp_noisy_latents + if args.do_edm_style_training + else noisy_model_input + ), timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, @@ -1726,14 +1986,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # on noised model inputs (before preconditioning) and the sigmas. # Follow: Section 5 of https://arxiv.org/abs/2206.00364. 
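The branch below converts the raw network output into a denoised-sample prediction as a function of the continuous noise level sigma, following Section 5 of the EDM paper. A self-contained sketch of the non-EDM-scheduler path; the epsilon and v-prediction formulas mirror the hunk that follows, while the function name is illustrative:

import torch

def precondition_model_outputs(model_pred, noisy_model_input, sigmas, prediction_type):
    # Recover an estimate of the clean sample x0 from the network output.
    if prediction_type == "epsilon":
        # x_t = x0 + sigma * eps  =>  x0_hat = x_t - sigma * eps_hat
        return model_pred * (-sigmas) + noisy_model_input
    if prediction_type == "v_prediction":
        return model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
            noisy_model_input / (sigmas**2 + 1)
        )
    raise ValueError(f"Unknown prediction type {prediction_type}")


if __name__ == "__main__":
    x_t = torch.randn(1, 4, 64, 64)
    eps_hat = torch.randn_like(x_t)
    sigmas = torch.full((1, 1, 1, 1), 0.5)
    print(precondition_model_outputs(eps_hat, x_t, sigmas, "epsilon").shape)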
if "EDM" in scheduler_type: - model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas) + model_pred = noise_scheduler.precondition_outputs( + noisy_model_input, model_pred, sigmas + ) else: if noise_scheduler.config.prediction_type == "epsilon": model_pred = model_pred * (-sigmas) + noisy_model_input elif noise_scheduler.config.prediction_type == "v_prediction": - model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + ( - noisy_model_input / (sigmas**2 + 1) - ) + model_pred = model_pred * ( + -sigmas / (sigmas**2 + 1) ** 0.5 + ) + (noisy_model_input / (sigmas**2 + 1)) # We are not doing weighting here because it tends result in numerical problems. # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 # There might be other alternatives for weighting as well: @@ -1751,7 +2013,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else noise_scheduler.get_velocity(model_input, noise, timesteps) ) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. @@ -1761,33 +2025,44 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss if weighting is not None: prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), + ( + weighting.float() + * (model_pred_prior.float() - target_prior.float()) ** 2 + ).reshape(target_prior.shape[0], -1), 1, ) prior_loss = prior_loss.mean() else: - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + prior_loss = F.mse_loss( + model_pred_prior.float(), + target_prior.float(), + reduction="mean", + ) if args.snr_gamma is None: if weighting is not None: loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape( - target.shape[0], -1 - ), + ( + weighting.float() + * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), 1, ) loss = loss.mean() else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr ) if noise_scheduler.config.prediction_type == "v_prediction": @@ -1797,8 +2072,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Epsilon and sample both use the same loss weights. 
mse_loss_weights = base_weight - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) loss = loss.mean() if args.with_prior_preservation: @@ -1808,7 +2088,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + itertools.chain( + unet_lora_parameters, + text_lora_parameters_one, + text_lora_parameters_two, + ) if args.train_text_encoder else unet_lora_parameters ) @@ -1828,24 +2112,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1857,7 +2153,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): # create pipeline if not args.train_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( @@ -1890,6 +2189,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator, pipeline_args, epoch, + torch_dtype=weight_dtype, ) # Save the lora layers @@ -1897,7 +2197,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: unet = unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + unet_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(unet) + ) if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) @@ -1919,10 +2221,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): 
text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) if args.output_kohya_format: - lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") + lora_state_dict = load_file( + f"{args.output_dir}/pytorch_lora_weights.safetensors" + ) peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) - save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors") + save_file( + kohya_state_dict, + f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors", + ) # Final inference # Load previous pipeline @@ -1947,7 +2254,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} + pipeline_args = { + "prompt": args.validation_prompt, + "num_inference_steps": 25, + } images = log_validation( pipeline, args, @@ -1955,6 +2265,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args, epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 985814205d06..4095447d4765 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -32,7 +32,11 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import ( + DistributedDataParallelKwargs, + ProjectConfiguration, + set_seed, +) from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from PIL import Image @@ -41,7 +45,13 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + PretrainedConfig, + T5EncoderModel, + T5TokenizerFast, +) import diffusers from diffusers import ( @@ -51,15 +61,14 @@ StableDiffusion3Pipeline, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 -from diffusers.utils import ( - check_min_version, - is_wandb_available, +from diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, ) +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -83,7 +92,10 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"image_{i}.png"}, + } ) model_description = f""" @@ -140,13 +152,22 @@ def save_model_card( def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", 
revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) text_encoder_three = class_three.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_3", + revision=args.revision, + variant=args.variant, ) return text_encoder_one, text_encoder_two, text_encoder_three @@ -157,22 +178,30 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed + else None + ) # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -183,7 +212,8 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -281,7 +311,12 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data.", + ) parser.add_argument( "--class_data_dir", @@ -336,7 +371,12 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) parser.add_argument( "--num_class_images", type=int, @@ -352,7 +392,9 @@ def parse_args(input_args=None): default="sd3-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -382,10 +424,16 @@ def parse_args(input_args=None): help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -459,7 +507,10 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--lr_num_cycles", @@ -467,7 +518,12 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -483,10 +539,16 @@ def parse_args(input_args=None): choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme.", ) parser.add_argument( "--mode_scale", @@ -515,10 +577,16 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( "--prodigy_beta3", @@ -527,10 +595,23 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. 
Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--prodigy_decouple", + type=bool, + default=True, + help="Use AdamW style decoupled weight decay", + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for unet params", + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", + type=float, + default=1e-03, + help="Weight decay to use for text_encoder", ) parser.add_argument( @@ -553,9 +634,20 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -609,7 +701,12 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -620,7 +717,9 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + raise ValueError( + "Specify only one of `--dataset_name` or `--instance_data_dir`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -634,9 +733,13 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "You need not use --class_data_dir without --with_prior_preservation." + ) if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "You need not use --class_prompt without --with_prior_preservation." 
+ ) return args @@ -715,13 +818,17 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + self.custom_instance_prompts.extend( + itertools.repeat(caption, repeats) + ) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [ + Image.open(path) for path in list(Path(instance_data_root).iterdir()) + ] self.custom_instance_prompts = None self.instance_images = [] @@ -729,8 +836,12 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + train_crop = ( + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + ) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -751,7 +862,9 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params( + image, (args.resolution, args.resolution) + ) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -773,8 +886,14 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -799,7 +918,9 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -962,7 +1083,8 @@ def encode_prompt( ) clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) @@ -984,7 +1106,9 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1000,7 +1124,9 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for 
logging during training.") + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1028,8 +1154,12 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() - torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() + ) + torch_dtype = ( + torch.float16 if has_supported_fp16_accelerator else torch.float32 + ) if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1048,19 +1178,26 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -1121,7 +1258,10 @@ def main(args): variant=args.variant, ) transformer = SD3Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, ) transformer.requires_grad_(True) @@ -1172,16 +1312,26 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): if isinstance(unwrap_model(model), SD3Transformer2DModel): - unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) - elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "transformer") + ) + elif isinstance( + unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) + ): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): hidden_size = unwrap_model(model).config.hidden_size if hidden_size == 768: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "text_encoder") + ) elif hidden_size == 1280: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "text_encoder_2") + ) else: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3")) + unwrap_model(model).save_pretrained( + os.path.join(output_dir, "text_encoder_3") + ) 
else: raise ValueError(f"Wrong model supplied: {type(model)=}.") @@ -1195,27 +1345,39 @@ def load_model_hook(models, input_dir): # load diffusers style into model if isinstance(unwrap_model(model), SD3Transformer2DModel): - load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") + load_model = SD3Transformer2DModel.from_pretrained( + input_dir, subfolder="transformer" + ) model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) - elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + elif isinstance( + unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) + ): try: - load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") + load_model = CLIPTextModelWithProjection.from_pretrained( + input_dir, subfolder="text_encoder" + ) model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2") + load_model = CLIPTextModelWithProjection.from_pretrained( + input_dir, subfolder="text_encoder_2" + ) model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3") + load_model = T5EncoderModel.from_pretrained( + input_dir, subfolder="text_encoder_3" + ) model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - raise ValueError(f"Couldn't load the model of type: ({type(model)}).") + raise ValueError( + f"Couldn't load the model of type: ({type(model)})." + ) else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1231,11 +1393,17 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Optimization parameters - transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} + transformer_parameters_with_lr = { + "params": transformer.parameters(), + "lr": args.learning_rate, + } if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1300,7 +1468,9 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -1369,15 +1539,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers + instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( + compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) ) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers + class_prompt_hidden_states, class_pooled_prompt_embeds = ( + compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) ) # Clear the memory here @@ -1398,8 +1568,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + prompt_embeds = torch.cat( + [prompt_embeds, class_prompt_hidden_states], dim=0 + ) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 + ) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1416,7 +1590,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1455,7 +1631,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1468,14 +1646,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1538,7 +1722,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] if args.train_text_encoder: - models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three]) + models_to_accumulate.extend( + [text_encoder_one, text_encoder_two, text_encoder_three] + ) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] @@ -1556,7 +1742,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = ( + model_input - vae.config.shift_factor + ) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1573,11 +1761,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + timesteps = noise_scheduler_copy.timesteps[indices].to( + device=model_input.device + ) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + sigmas = get_sigmas( + timesteps, n_dim=model_input.ndim, dtype=model_input.dtype + ) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual @@ -1591,7 +1783,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] else: prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], + text_encoders=[ + text_encoder_one, + text_encoder_two, + text_encoder_three, + ], tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two, tokens_three], @@ -1611,7 +1807,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) # flow matching loss if args.precondition_outputs: @@ -1626,16 +1824,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), + ( + weighting.float() + * (model_pred_prior.float() - target_prior.float()) ** 2 + ).reshape(target_prior.shape[0], -1), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. 
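# A minimal sketch of the per-sample weighted flow-matching loss computed below, assuming a
# batch of 4 latents with 16 channels at 64x64; all tensor shapes here are illustrative.
import torch

weighting = torch.rand(4, 1, 1, 1)          # per-sample weights, broadcast over latent dims
model_pred = torch.randn(4, 16, 64, 64)
target = torch.randn(4, 16, 64, 64)
per_sample = (
    (weighting.float() * (model_pred.float() - target.float()) ** 2)
    .reshape(4, -1)
    .mean(dim=1)
)
example_loss = per_sample.mean()            # averaged first per sample, then over the batch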
loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + ( + weighting.float() * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1672,24 +1873,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1701,11 +1914,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompt is not None + and epoch % args.validation_epochs == 0 + ): # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three + text_encoder_one, text_encoder_two, text_encoder_three = ( + load_text_encoders( + text_encoder_cls_one, + text_encoder_cls_two, + text_encoder_cls_three, + ) ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1725,6 +1945,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two, text_encoder_three @@ -1775,6 +1996,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: From 5d6f5a9cc42c4fbc202a37386418f56c90270b66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Thu, 12 Sep 2024 13:11:19 +0800 Subject: [PATCH 3/3] [bugfix] Fixed the issue on sd3 dreambooth training --- examples/dreambooth/train_dreambooth_flux.py | 443 ++++---------- examples/dreambooth/train_dreambooth_lora.py | 388 +++--------- .../dreambooth/train_dreambooth_lora_flux.py | 419 +++---------- 
.../dreambooth/train_dreambooth_lora_sd3.py | 438 +++----------- .../dreambooth/train_dreambooth_lora_sdxl.py | 570 ++++-------------- examples/dreambooth/train_dreambooth_sd3.py | 409 +++---------- 6 files changed, 615 insertions(+), 2052 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index a47abfaa74b8..8e0f4e09a461 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -32,11 +32,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - ProjectConfiguration, - set_seed, -) +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from PIL import Image @@ -45,13 +41,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import ( - CLIPTextModelWithProjection, - CLIPTokenizer, - PretrainedConfig, - T5EncoderModel, - T5TokenizerFast, -) +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast import diffusers from diffusers import ( @@ -61,14 +51,15 @@ FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import ( + check_min_version, + is_wandb_available, ) -from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -92,10 +83,7 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - { - "text": validation_prompt if validation_prompt else " ", - "output": {"url": f"image_{i}.png"}, - } + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) model_description = f""" @@ -152,16 +140,10 @@ def save_model_card( def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) return text_encoder_one, text_encoder_two @@ -183,19 +165,12 @@ def log_validation( pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - 
images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -206,8 +181,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -305,12 +279,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="How many times to repeat the training data.", - ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -365,12 +334,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -386,9 +350,7 @@ def parse_args(input_args=None): default="flux-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -418,16 +380,10 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -508,10 +464,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( "--lr_num_cycles", @@ -519,12 +472,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -538,21 +486,13 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=( - 'We default to the "none" weighting scheme for uniform sampling and uniform loss' - ), + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), ) parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme.", + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme.", + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( "--mode_scale", @@ -574,16 +514,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( "--prodigy_beta3", @@ -592,23 +526,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--prodigy_decouple", - type=bool, - default=True, - help="Use AdamW style decoupled weight decay", - ) - parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for unet params", - ) - parser.add_argument( - "--adam_weight_decay_text_encoder", - type=float, - default=1e-03, - help="Weight decay to use for text_encoder", + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -631,20 +552,9 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -698,12 +608,7 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") if input_args is not None: args = parser.parse_args(input_args) @@ -714,9 +619,7 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError( - "Specify only one of `--dataset_name` or `--instance_data_dir`" - ) + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -730,13 +633,9 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") return args @@ -815,17 +714,13 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend( - itertools.repeat(caption, repeats) - ) + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [ - Image.open(path) for path in list(Path(instance_data_root).iterdir()) - ] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None self.instance_images = [] @@ -833,12 +728,8 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ) - train_crop = ( - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - ) + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -859,9 +750,7 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params( - image, (args.resolution, args.resolution) - ) + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -883,14 +772,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -915,9 +798,7 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -1001,9 +882,7 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -1044,9 +923,7 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) @@ -1093,9 +970,7 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = 
torch.zeros(batch_size, prompt_embeds.shape[1], 3).to( - device=device, dtype=dtype - ) + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -1116,9 +991,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1134,9 +1007,7 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." - ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1164,12 +1035,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = ( - torch.cuda.is_available() or torch.backends.mps.is_available() - ) - torch_dtype = ( - torch.float16 if has_supported_fp16_accelerator else torch.float32 - ) + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1188,26 +1055,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -1250,9 +1110,7 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1260,10 +1118,7 @@ def main(args): variant=args.variant, ) transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - 
variant=args.variant, + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) transformer.requires_grad_(True) @@ -1309,20 +1164,12 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): if isinstance(unwrap_model(model), FluxTransformer2DModel): - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "transformer") - ) - elif isinstance( - unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) - ): + unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "text_encoder") - ) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) else: - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "text_encoder_2") - ) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") @@ -1336,32 +1183,22 @@ def load_model_hook(models, input_dir): # load diffusers style into model if isinstance(unwrap_model(model), FluxTransformer2DModel): - load_model = FluxTransformer2DModel.from_pretrained( - input_dir, subfolder="transformer" - ) + load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) - elif isinstance( - unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) - ): + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): try: - load_model = CLIPTextModelWithProjection.from_pretrained( - input_dir, subfolder="text_encoder" - ) + load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = T5EncoderModel.from_pretrained( - input_dir, subfolder="text_encoder_2" - ) + load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - raise ValueError( - f"Couldn't load the model of type: ({type(model)})." 
- ) + raise ValueError(f"Couldn't load the model of type: ({type(model)}).") else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1377,17 +1214,11 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Optimization parameters - transformer_parameters_with_lr = { - "params": transformer.parameters(), - "lr": args.learning_rate, - } + transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1440,9 +1271,7 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy @@ -1511,17 +1340,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - ( - instance_prompt_hidden_states, - instance_pooled_prompt_embeds, - instance_text_ids, - ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = ( - compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here @@ -1543,37 +1370,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids if args.with_prior_preservation: - prompt_embeds = torch.cat( - [prompt_embeds, class_prompt_hidden_states], dim=0 - ) - pooled_prompt_embeds = torch.cat( - [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 - ) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: - tokens_one = tokenize_prompt( - tokenizer_one, args.instance_prompt, max_sequence_length=77 - ) - tokens_two = tokenize_prompt( - tokenizer_two, args.instance_prompt, max_sequence_length=512 - ) + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt( - tokenizer_one, args.class_prompt, max_sequence_length=77 - ) - class_tokens_two = tokenize_prompt( - tokenizer_two, args.class_prompt, max_sequence_length=512 - ) + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1608,9 +1421,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1623,20 +1434,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! 
- total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1705,17 +1510,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = ( - compute_text_embeddings(prompts, text_encoders, tokenizers) + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt( - tokenizer_one, prompts, max_sequence_length=77 - ) + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( - tokenizer_two, - prompts, - max_sequence_length=args.max_sequence_length, + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length ) prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], @@ -1736,9 +1537,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = ( - model_input - vae.config.shift_factor - ) * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) @@ -1765,15 +1564,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to( - device=model_input.device - ) + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. 
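The noising step that follows interpolates linearly between the clean latents and pure noise, which is the flow-matching forward process; a compact sketch with dummy tensors (shapes are illustrative only):

    import torch

    model_input = torch.randn(2, 16, 64, 64)      # clean latents (dummy)
    noise = torch.randn_like(model_input)
    sigmas = torch.rand(2).view(-1, 1, 1, 1)      # one sigma per sample, broadcastable

    # zt = (1 - sigma) * x + sigma * z1
    noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

    # The flow-matching regression target is the velocity from data towards noise.
    target = noise - model_input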
# zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas( - timesteps, n_dim=model_input.ndim, dtype=model_input.dtype - ) + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise packed_noisy_model_input = FluxPipeline._pack_latents( @@ -1786,9 +1581,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # handle guidance if transformer.config.guidance_embeds: - guidance = torch.tensor( - [args.guidance_scale], device=accelerator.device - ) + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: guidance = None @@ -1815,9 +1608,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = noise - model_input @@ -1829,19 +1620,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - ( - weighting.float() - * (model_pred_prior.float() - target_prior.float()) ** 2 - ).reshape(target_prior.shape[0], -1), + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. loss = torch.mean( - ( - weighting.float() * (model_pred.float() - target.float()) ** 2 - ).reshape(target.shape[0], -1), + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1853,9 +1641,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain( - transformer.parameters(), text_encoder_one.parameters() - ) + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) if args.train_text_encoder else transformer.parameters() ) @@ -1875,36 +1661,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 
shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1916,15 +1690,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if ( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) else: # even when training the text encoder we're only training text encoder one text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, @@ -1969,9 +1738,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder=text_encoder_one, ) else: - pipeline = FluxPipeline.from_pretrained( - args.pretrained_model_name_or_path, transformer=transformer - ) + pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer) # save the pipeline pipeline.save_pretrained(args.output_dir) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e0fb47dcae30..5d7d697bb21d 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -54,10 +54,7 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - _set_state_dict_into_text_encoder, - cast_training_params, -) +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, @@ -68,6 +65,7 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -142,19 +140,13 @@ def log_validation( scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None if args.validation_images is None: images = [] @@ -167,9 +159,7 @@ def log_validation( for image in args.validation_images: image = Image.open(image) with torch.cuda.amp.autocast(): - image = pipeline( - **pipeline_args, image=image, generator=generator - ).images[0] + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: @@ -181,8 +171,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ 
-193,9 +182,7 @@ def log_validation( return images -def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str -): +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", @@ -208,9 +195,7 @@ def import_model_class_from_model_name_or_path( return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": - from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( - RobertaSeriesModelWithTransformation, - ) + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation elif model_class == "T5EncoderModel": @@ -303,12 +288,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -324,9 +304,7 @@ def parse_args(input_args=None): default="lora-dreambooth-model", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -351,16 +329,10 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -427,10 +399,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( "--lr_num_cycles", @@ -438,12 +407,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -453,45 +417,15 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.", - ) - parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." - ) - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -545,16 +479,9 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--pre_compute_text_embeddings", @@ -611,18 +538,12 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") if args.train_text_encoder and args.pre_compute_text_embeddings: - raise ValueError( - "`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`" - ) + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") return args @@ -678,14 +599,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -696,9 +611,7 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open( - self.instance_images_path[index % self.num_instance_images] - ) + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": @@ -709,17 +622,13 @@ def __getitem__(self, index): example["instance_prompt_ids"] = self.encoder_hidden_states else: text_inputs = tokenize_prompt( - self.tokenizer, - self.instance_prompt, - tokenizer_max_length=self.tokenizer_max_length, + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length ) example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -730,9 +639,7 @@ def __getitem__(self, index): example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states else: class_text_inputs = tokenize_prompt( - self.tokenizer, - self.class_prompt, - tokenizer_max_length=self.tokenizer_max_length, + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length ) example["class_prompt_ids"] = class_text_inputs.input_ids example["class_attention_mask"] = class_text_inputs.attention_mask @@ -807,9 +714,7 @@ def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): return text_inputs -def encode_prompt( - text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None -): +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): text_input_ids = input_ids.to(text_encoder.device) if text_encoder_use_attention_mask: @@ -836,9 +741,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -853,18 +756,12 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." 
- ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. - if ( - args.train_text_encoder - and args.gradient_accumulation_steps > 1 - and accelerator.num_processes > 1 - ): + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." @@ -896,9 +793,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = ( - torch.float16 if accelerator.device.type == "cuda" else torch.float32 - ) + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -918,26 +813,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -951,16 +839,12 @@ def main(args): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - token=args.hub_token, + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, revision=args.revision, use_fast=False - ) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, @@ -970,26 +854,16 @@ def main(args): ) # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" - ) + noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) try: vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) except OSError: # IF does not have a VAE so let's just set it to None @@ -997,10 +871,7 @@ def main(args): vae = None unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1034,9 +905,7 @@ def main(args): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError( - "xformers is not available. Make sure it is installed correctly" - ) + raise ValueError("xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -1077,9 +946,7 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(unwrap_model(text_encoder))): text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -1110,19 +977,11 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = ( - StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - ) + lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = { - f'{k.replace("unet.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("unet.") - } + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict( - unet_, unet_state_dict, adapter_name="default" - ) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: # check only for unexpected keys @@ -1134,9 +993,7 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_ - ) + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. 
More details: @@ -1159,10 +1016,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. @@ -1190,9 +1044,7 @@ def load_model_hook(models, input_dir): # Optimizer creation params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: - params_to_optimize = params_to_optimize + list( - filter(lambda p: p.requires_grad, text_encoder.parameters()) - ) + params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) optimizer = optimizer_class( params_to_optimize, @@ -1206,9 +1058,7 @@ def load_model_hook(models, input_dir): def compute_text_embeddings(prompt): with torch.no_grad(): - text_inputs = tokenize_prompt( - tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length - ) + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) prompt_embeds = encode_prompt( text_encoder, text_inputs.input_ids, @@ -1218,22 +1068,16 @@ def compute_text_embeddings(prompt): return prompt_embeds - pre_computed_encoder_hidden_states = compute_text_embeddings( - args.instance_prompt - ) + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) validation_prompt_negative_prompt_embeds = compute_text_embeddings("") if args.validation_prompt is not None: - validation_prompt_encoder_hidden_states = compute_text_embeddings( - args.validation_prompt - ) + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) else: validation_prompt_encoder_hidden_states = None if args.class_prompt is not None: - pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings( - args.class_prompt - ) + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) else: pre_computed_class_prompt_encoder_hidden_states = None @@ -1273,9 +1117,7 @@ def compute_text_embeddings(prompt): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1291,10 +1133,8 @@ def compute_text_embeddings(prompt): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = ( - accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -1302,9 +1142,7 @@ def compute_text_embeddings(prompt): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
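When --pre_compute_text_embeddings is set, the prompts are encoded once up front so the text encoder never has to sit in GPU memory during training; a rough sketch of the idea with a hypothetical helper (not the script's exact API):

    import torch

    @torch.no_grad()
    def precompute_embeddings(prompts, tokenizer, text_encoder, device="cuda"):
        # Encode each prompt once and cache the hidden states on the CPU.
        cache = {}
        for prompt in prompts:
            tokens = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
            cache[prompt] = text_encoder(tokens.input_ids.to(device))[0].cpu()
        return cache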
- num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1318,20 +1156,14 @@ def compute_text_embeddings(prompt): accelerator.init_trackers("dreambooth-lora", config=tracker_config) # Train! - total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1392,18 +1224,13 @@ def compute_text_embeddings(prompt): bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (bsz,), - device=model_input.device, + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device ) timesteps = timesteps.long() # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise( - model_input, noise, timesteps - ) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning if args.pre_compute_text_embeddings: @@ -1417,9 +1244,7 @@ def compute_text_embeddings(prompt): ) if unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat( - [noisy_model_input, noisy_model_input], dim=1 - ) + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": class_labels = timesteps @@ -1447,9 +1272,7 @@ def compute_text_embeddings(prompt): elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
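The prior-preservation branch that follows runs instance and class samples through the model as one doubled batch and then splits the prediction back in two before weighting; a minimal sketch with dummy tensors (prior_loss_weight mirrors the script's flag):

    import torch
    import torch.nn.functional as F

    # Doubled batch: first half instance samples, second half class (prior) samples.
    model_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)
    prior_loss_weight = 1.0

    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + prior_loss_weight * prior_loss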
@@ -1457,21 +1280,15 @@ def compute_text_embeddings(prompt): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss( - model_pred.float(), target.float(), reduction="mean" - ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss - prior_loss = F.mse_loss( - model_pred_prior.float(), target_prior.float(), reduction="mean" - ) + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss( - model_pred.float(), target.float(), reduction="mean" - ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -1490,36 +1307,24 @@ def compute_text_embeddings(prompt): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1531,19 +1336,12 @@ def compute_text_embeddings(prompt): break if accelerator.is_main_process: - if ( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=unwrap_model(unet), - text_encoder=( - None - if args.pre_compute_text_embeddings - else unwrap_model(text_encoder) - ), + text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -1572,15 +1370,11 @@ def compute_text_embeddings(prompt): unet = unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_state_dict = convert_state_dict_to_diffusers( - get_peft_model_state_dict(unet) - ) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: text_encoder = unwrap_model(text_encoder) - text_encoder_state_dict = convert_state_dict_to_diffusers( - 
get_peft_model_state_dict(text_encoder) - ) + text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder)) else: text_encoder_state_dict = None @@ -1593,24 +1387,16 @@ def compute_text_embeddings(prompt): # Final inference # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype ) # load attention processors - pipeline.load_lora_weights( - args.output_dir, weight_name="pytorch_lora_weights.safetensors" - ) + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - pipeline_args = { - "prompt": args.validation_prompt, - "num_inference_steps": 25, - } + pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} images = log_validation( pipeline, args, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 9c60739f2aec..bd5b46cc9fa9 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -32,11 +32,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - ProjectConfiguration, - set_seed, -) +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from peft import LoraConfig, set_peft_model_state_dict @@ -71,6 +67,7 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -94,10 +91,7 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - { - "text": validation_prompt if validation_prompt else " ", - "output": {"url": f"image_{i}.png"}, - } + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) model_description = f""" @@ -162,16 +156,10 @@ def save_model_card( def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) return text_encoder_one, text_encoder_two @@ -193,19 +181,12 @@ def log_validation( pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with 
autocast_ctx: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -216,8 +197,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -315,12 +295,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="How many times to repeat the training data.", - ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -381,12 +356,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -402,9 +372,7 @@ def parse_args(input_args=None): default="flux-dreambooth-lora", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -434,16 +402,10 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -524,10 +486,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( "--lr_num_cycles", @@ -535,12 +494,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -554,21 +508,13 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=( - 'We default to the "none" weighting scheme for uniform sampling and uniform loss' - ), + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), ) parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme.", + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme.", + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( "--mode_scale", @@ -590,16 +536,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( "--prodigy_beta3", @@ -608,23 +548,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--prodigy_decouple", - type=bool, - default=True, - help="Use AdamW style decoupled weight decay", - ) - parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for unet params", - ) - parser.add_argument( - "--adam_weight_decay_text_encoder", - type=float, - default=1e-03, - help="Weight decay to use for text_encoder", + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -647,20 +574,9 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -714,12 +630,7 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") if input_args is not None: args = parser.parse_args(input_args) @@ -730,9 +641,7 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError( - "Specify only one of `--dataset_name` or `--instance_data_dir`" - ) + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -746,13 +655,9 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") return args @@ -831,17 +736,13 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend( - itertools.repeat(caption, repeats) - ) + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [ - Image.open(path) for path in list(Path(instance_data_root).iterdir()) - ] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None self.instance_images = [] @@ -849,12 +750,8 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ) - train_crop = ( - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - ) + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -875,9 +772,7 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params( - image, (args.resolution, args.resolution) - ) + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -899,14 +794,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -931,9 +820,7 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -1017,9 +904,7 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -1060,9 +945,7 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) @@ -1109,9 +992,7 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = 
torch.zeros(batch_size, prompt_embeds.shape[1], 3).to( - device=device, dtype=dtype - ) + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -1132,9 +1013,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1150,9 +1029,7 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." - ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1180,12 +1057,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = ( - torch.cuda.is_available() or torch.backends.mps.is_available() - ) - torch_dtype = ( - torch.float16 if has_supported_fp16_accelerator else torch.float32 - ) + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1204,26 +1077,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -1266,9 +1132,7 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1276,10 +1140,7 @@ def main(args): variant=args.variant, ) transformer = FluxTransformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - 
variant=args.variant, + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1344,9 +1205,7 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict( - model - ) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1376,14 +1235,10 @@ def load_model_hook(models, input_dir): lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict( - transformer_, transformer_state_dict, adapter_name="default" - ) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1394,9 +1249,7 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ - ) + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: @@ -1418,10 +1271,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. 
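The float32 handling referenced above upcasts only the trainable LoRA parameters while the frozen base weights stay in the lower-precision weight_dtype; a hand-rolled sketch of that filter (diffusers' cast_training_params does the equivalent over a list of models):

    import torch

    def upcast_trainable_params(model, dtype=torch.float32):
        # Frozen base weights keep their fp16/bf16 dtype; only LoRA params train in fp32.
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(dtype)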
@@ -1432,19 +1282,12 @@ def load_model_hook(models, input_dir): # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - transformer_lora_parameters = list( - filter(lambda p: p.requires_grad, transformer.parameters()) - ) + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) if args.train_text_encoder: - text_lora_parameters_one = list( - filter(lambda p: p.requires_grad, text_encoder_one.parameters()) - ) + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) # Optimization parameters - transformer_parameters_with_lr = { - "params": transformer_lora_parameters, - "lr": args.learning_rate, - } + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1497,9 +1340,7 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy @@ -1568,17 +1409,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - ( - instance_prompt_hidden_states, - instance_pooled_prompt_embeds, - instance_text_ids, - ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = ( - compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here @@ -1600,41 +1439,27 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids if args.with_prior_preservation: - prompt_embeds = torch.cat( - [prompt_embeds, class_prompt_hidden_states], dim=0 - ) - pooled_prompt_embeds = torch.cat( - [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 - ) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: - tokens_one = tokenize_prompt( - tokenizer_one, args.instance_prompt, max_sequence_length=77 - ) + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) tokens_two = tokenize_prompt( - tokenizer_two, - args.instance_prompt, - max_sequence_length=args.max_sequence_length, + tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt( - tokenizer_one, args.class_prompt, max_sequence_length=77 - ) + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) class_tokens_two = tokenize_prompt( - tokenizer_two, - args.class_prompt, - max_sequence_length=args.max_sequence_length, + tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1669,9 +1494,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1684,20 +1507,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! 
- total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1755,9 +1572,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model( - text_encoder_one - ).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] @@ -1770,17 +1585,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = ( - compute_text_embeddings(prompts, text_encoders, tokenizers) + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt( - tokenizer_one, prompts, max_sequence_length=77 - ) + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( - tokenizer_two, - prompts, - max_sequence_length=args.max_sequence_length, + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length ) prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], @@ -1803,9 +1614,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = ( - model_input - vae.config.shift_factor - ) * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) @@ -1831,15 +1640,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to( - device=model_input.device - ) + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. 
# zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas( - timesteps, n_dim=model_input.ndim, dtype=model_input.dtype - ) + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise packed_noisy_model_input = FluxPipeline._pack_latents( @@ -1852,9 +1657,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # handle guidance if transformer.config.guidance_embeds: - guidance = torch.tensor( - [args.guidance_scale], device=accelerator.device - ) + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: guidance = None @@ -1880,9 +1683,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = noise - model_input @@ -1894,19 +1695,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - ( - weighting.float() - * (model_pred_prior.float() - target_prior.float()) ** 2 - ).reshape(target_prior.shape[0], -1), + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. loss = torch.mean( - ( - weighting.float() * (model_pred.float() - target.float()) ** 2 - ).reshape(target.shape[0], -1), + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1918,9 +1716,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain( - transformer.parameters(), text_encoder_one.parameters() - ) + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) if args.train_text_encoder else transformer.parameters() ) @@ -1940,36 +1736,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 
shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1981,15 +1765,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if ( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -2023,9 +1802,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = get_peft_model_state_dict( - text_encoder_one.to(torch.float32) - ) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) else: text_encoder_lora_layers = None diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 061226b9b57a..3060813bbbdc 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -31,11 +31,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - ProjectConfiguration, - set_seed, -) +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from peft import LoraConfig, set_peft_model_state_dict @@ -71,6 +67,7 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -94,10 +91,7 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - { - "text": validation_prompt if validation_prompt else " ", - "output": {"url": f"image_{i}.png"}, - } + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) model_description = f""" @@ -168,22 +162,13 @@ def save_model_card( def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) text_encoder_three = class_three.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_3", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) return 
text_encoder_one, text_encoder_two, text_encoder_three @@ -205,19 +190,12 @@ def log_validation( pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -228,8 +206,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -325,12 +302,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="How many times to repeat the training data.", - ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -391,12 +363,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -412,9 +379,7 @@ def parse_args(input_args=None): default="sd3-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -445,16 +410,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -528,10 +487,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( "--lr_num_cycles", @@ -539,12 +495,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -560,16 +511,10 @@ def parse_args(input_args=None): choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], ) parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme.", + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme.", + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( "--mode_scale", @@ -598,16 +543,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( "--prodigy_beta3", @@ -616,23 +555,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--prodigy_decouple", - type=bool, - default=True, - help="Use AdamW style decoupled weight decay", - ) - parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for unet params", - ) - parser.add_argument( - "--adam_weight_decay_text_encoder", - type=float, - default=1e-03, - help="Weight decay to use for text_encoder", + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -655,20 +581,9 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -722,12 +637,7 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") if input_args is not None: args = parser.parse_args(input_args) @@ -738,9 +648,7 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError( - "Specify only one of `--dataset_name` or `--instance_data_dir`" - ) + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -754,13 +662,9 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") return args @@ -839,17 +743,13 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend( - itertools.repeat(caption, repeats) - ) + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [ - Image.open(path) for path in list(Path(instance_data_root).iterdir()) - ] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None self.instance_images = [] @@ -857,12 +757,8 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ) - train_crop = ( - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - ) + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -883,9 +779,7 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params( - image, (args.resolution, args.resolution) - ) + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -907,14 +801,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -939,9 +827,7 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -1022,9 +908,7 @@ def _encode_prompt_with_t5( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] @@ -1063,9 +947,7 @@ def _encode_prompt_with_clip( text_input_ids = text_inputs.input_ids else: if text_input_ids is None: - raise ValueError( - "text_input_ids must be provided when the tokenizer is not specified" - ) + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) @@ -1097,9 +979,7 @@ def encode_prompt( clip_prompt_embeds_list = [] clip_pooled_prompt_embeds_list = [] - for i, (tokenizer, 
text_encoder) in enumerate( - zip(clip_tokenizers, clip_text_encoders) - ): + for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoder, tokenizer=tokenizer, @@ -1125,8 +1005,7 @@ def encode_prompt( ) clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, - (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) @@ -1148,9 +1027,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1166,9 +1043,7 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." - ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1196,12 +1071,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = ( - torch.cuda.is_available() or torch.backends.mps.is_available() - ) - torch_dtype = ( - torch.float16 if has_supported_fp16_accelerator else torch.float32 - ) + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1220,26 +1091,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) clear_objs_and_retain_memory(objs=[pipeline]) @@ -1298,10 +1162,7 @@ def main(args): variant=args.variant, ) transformer = SD3Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="transformer", 
revision=args.revision, variant=args.variant ) transformer.requires_grad_(False) @@ -1371,13 +1232,9 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict( - model - ) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = get_peft_model_state_dict( - model - ) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1411,14 +1268,10 @@ def load_model_hook(models, input_dir): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict( - transformer_, transformer_state_dict, adapter_name="default" - ) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1429,14 +1282,10 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ - ) + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) _set_state_dict_into_text_encoder( - lora_state_dict, - prefix="text_encoder_2.", - text_encoder=text_encoder_two_, + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ ) # Make sure the trainable params are in float32. This is again needed since the base models @@ -1459,10 +1308,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. 
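
# Illustrative sketch, not part of the patch. The next hunk collects the trainable LoRA
# parameters and puts the transformer and text-encoder weights into separate optimizer
# parameter groups so each can use its own learning rate. The helper name and the default
# rates below are hypothetical; the scripts read the actual values from their CLI
# learning-rate arguments.
import torch

def build_lora_param_groups(
    transformer: torch.nn.Module,
    text_encoder: torch.nn.Module,
    transformer_lr: float = 1e-4,
    text_encoder_lr: float = 5e-6,
) -> list:
    transformer_params = [p for p in transformer.parameters() if p.requires_grad]
    text_encoder_params = [p for p in text_encoder.parameters() if p.requires_grad]
    return [
        {"params": transformer_params, "lr": transformer_lr},
        {"params": text_encoder_params, "lr": text_encoder_lr},
    ]

# Usage (hypothetical): optimizer = torch.optim.AdamW(build_lora_param_groups(transformer, text_encoder_one))
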
@@ -1473,22 +1319,13 @@ def load_model_hook(models, input_dir): # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - transformer_lora_parameters = list( - filter(lambda p: p.requires_grad, transformer.parameters()) - ) + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) if args.train_text_encoder: - text_lora_parameters_one = list( - filter(lambda p: p.requires_grad, text_encoder_one.parameters()) - ) - text_lora_parameters_two = list( - filter(lambda p: p.requires_grad, text_encoder_two.parameters()) - ) + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) # Optimization parameters - transformer_parameters_with_lr = { - "params": transformer_lora_parameters, - "lr": args.learning_rate, - } + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { @@ -1547,9 +1384,7 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy @@ -1604,28 +1439,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): return prompt_embeds, pooled_prompt_embeds if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( - compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = ( - compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection clear_objs_and_retain_memory( - objs=[ - tokenizers, - text_encoders, - text_encoder_one, - text_encoder_two, - text_encoder_three, - ] + objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] ) # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images), @@ -1637,12 +1466,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat( - [prompt_embeds, class_prompt_hidden_states], dim=0 - ) - pooled_prompt_embeds = torch.cat( - [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 - ) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1659,9 +1484,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1686,12 +1509,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): train_dataloader, lr_scheduler, ) = accelerator.prepare( - transformer, - text_encoder_one, - text_encoder_two, - optimizer, - train_dataloader, - lr_scheduler, + transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler ) assert text_encoder_one is not None assert text_encoder_two is not None @@ -1702,9 +1520,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1717,20 +1533,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! - total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1790,12 +1600,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model( - text_encoder_one - ).text_model.embeddings.requires_grad_(True) - accelerator.unwrap_model( - text_encoder_two - ).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] @@ -1814,11 +1620,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_two = tokenize_prompt(tokenizer_two, prompts) tokens_three = tokenize_prompt(tokenizer_three, prompts) prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[ - text_encoder_one, - text_encoder_two, - text_encoder_three, - ], + text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], tokenizers=[None, None, None], prompt=prompts, max_sequence_length=args.max_sequence_length, @@ -1827,11 +1629,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: if args.train_text_encoder: prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[ - text_encoder_one, - text_encoder_two, - text_encoder_three, - ], + text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], tokenizers=[None, None, tokenizer_three], prompt=args.instance_prompt, max_sequence_length=args.max_sequence_length, @@ -1840,9 +1638,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = ( - model_input - vae.config.shift_factor - ) * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1859,15 +1655,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to( - device=model_input.device - ) + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. 
# zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas( - timesteps, n_dim=model_input.ndim, dtype=model_input.dtype - ) + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual @@ -1886,9 +1678,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss if args.precondition_outputs: @@ -1903,19 +1693,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - ( - weighting.float() - * (model_pred_prior.float() - target_prior.float()) ** 2 - ).reshape(target_prior.shape[0], -1), + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. loss = torch.mean( - ( - weighting.float() * (model_pred.float() - target.float()) ** 2 - ).reshape(target.shape[0], -1), + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1928,9 +1715,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain( - transformer_lora_parameters, - text_lora_parameters_one, - text_lora_parameters_two, + transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two ) if args.train_text_encoder else transformer_lora_parameters @@ -1951,36 +1736,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1992,18 +1765,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if 
( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: if not args.train_text_encoder: # create pipeline - text_encoder_one, text_encoder_two, text_encoder_three = ( - load_text_encoders( - text_encoder_cls_one, - text_encoder_cls_two, - text_encoder_cls_three, - ) + text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -2027,9 +1793,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) objs = [] if not args.train_text_encoder: - objs.extend( - [text_encoder_one, text_encoder_two, text_encoder_three] - ) + objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) clear_objs_and_retain_memory(objs=objs) @@ -2042,13 +1806,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = get_peft_model_state_dict( - text_encoder_one.to(torch.float32) - ) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = get_peft_model_state_dict( - text_encoder_two.to(torch.float32) - ) + text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 05218041750c..016464165c44 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -33,11 +33,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - ProjectConfiguration, - set_seed, -) +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, hf_hub_download, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version @@ -64,11 +60,7 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - _set_state_dict_into_text_encoder, - cast_training_params, - compute_snr, -) +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, @@ -81,6 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -96,9 +89,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): model_index = os.path.join(pretrained_model_name_or_path, model_index_filename) else: model_index = hf_hub_download( - repo_id=pretrained_model_name_or_path, - filename=model_index_filename, - revision=revision, + repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision ) with open(model_index, "r") as f: @@ -122,10 +113,7 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - { - "text": validation_prompt if 
validation_prompt else " ", - "output": {"url": f"image_{i}.png"}, - } + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) model_description = f""" @@ -163,11 +151,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license=( - "openrail++" - if "playground" not in base_model - else "playground-v2dot5-community" - ), + license="openrail++" if "playground" not in base_model else "playground-v2dot5-community", base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -216,34 +200,22 @@ def log_validation( scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - if ( - torch.backends.mps.is_available() - or "playground" in args.pretrained_model_name_or_path - ): + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type) with autocast_ctx: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -254,8 +226,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -360,12 +331,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="How many times to repeat the training data.", - ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -420,12 +386,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -446,9 +407,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." 
- ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -478,16 +437,10 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -569,10 +522,7 @@ def parse_args(input_args=None): "More details here: https://arxiv.org/abs/2303.09556.", ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", @@ -580,12 +530,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -609,16 +554,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( "--prodigy_beta3", @@ -627,23 +566,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--prodigy_decouple", - type=bool, - default=True, - help="Use AdamW style decoupled weight decay", - ) - parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for unet params", - ) - parser.add_argument( - "--adam_weight_decay_text_encoder", - type=float, - default=1e-03, - help="Weight decay to use for text_encoder", + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -666,20 +592,9 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -733,16 +648,9 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--rank", @@ -769,9 +677,7 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError( - "Specify only one of `--dataset_name` or `--instance_data_dir`" - ) + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -785,13 +691,9 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") return args @@ -870,17 +772,13 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend( - itertools.repeat(caption, repeats) - ) + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [ - Image.open(path) for path in list(Path(instance_data_root).iterdir()) - ] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None self.instance_images = [] @@ -891,12 +789,8 @@ def __init__( self.original_sizes = [] self.crop_top_lefts = [] self.pixel_values = [] - train_resize = transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ) - train_crop = ( - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - ) + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -918,9 +812,7 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params( - image, (args.resolution, args.resolution) - ) + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) crop_top_left = (y1, x1) self.crop_top_lefts.append(crop_top_left) @@ -944,14 +836,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -980,9 +866,7 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -1061,9 +945,7 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, - return_dict=False, + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder @@ -1086,9 +968,7 @@ def main(args): ) if args.do_edm_style_training and args.snr_gamma is not None: - raise ValueError( - "Min-SNR formulation is not supported when conducting EDM-style training." - ) + raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") if torch.backends.mps.is_available() and args.mixed_precision == "bf16": # due to pytorch#99272, MPS does not yet support bfloat16. 
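The guards at the start of main() reduce to mapping the --mixed_precision flag onto a torch dtype and rejecting combinations the backend cannot run, such as bfloat16 on Apple MPS (pytorch#99272). A minimal standalone sketch of that mapping, assuming a hypothetical helper name (resolve_weight_dtype is not a function from these scripts):

import torch

def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    # Map an accelerate-style --mixed_precision value onto the dtype used for
    # frozen weights ("weight_dtype" in the training scripts).
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        if torch.backends.mps.is_available():
            # MPS has no bfloat16 support yet (pytorch#99272), mirroring the guard above.
            raise ValueError("bf16 mixed precision is not supported on Apple MPS.")
        return torch.bfloat16
    return torch.float32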
@@ -1098,9 +978,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1116,9 +994,7 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." - ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1146,12 +1022,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = ( - torch.cuda.is_available() or torch.backends.mps.is_available() - ) - torch_dtype = ( - torch.float16 if has_supported_fp16_accelerator else torch.float32 - ) + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1170,26 +1042,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -1203,9 +1068,7 @@ def main(args): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - token=args.hub_token, + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizers @@ -1231,14 +1094,10 @@ def main(args): ) # Load scheduler and models - scheduler_type = determine_scheduler_type( - args.pretrained_model_name_or_path, args.revision - ) + scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision) if "EDM" in scheduler_type: args.do_edm_style_training = True - noise_scheduler = EDMEulerScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" - ) + noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") logger.info("Performing EDM-style training!") elif args.do_edm_style_training: 
noise_scheduler = EulerDiscreteScheduler.from_pretrained( @@ -1246,21 +1105,13 @@ def main(args): ) logger.info("Performing EDM-style training!") else: - noise_scheduler = DDPMScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" - ) + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -1280,10 +1131,7 @@ def main(args): latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1327,9 +1175,7 @@ def main(args): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError( - "xformers is not available. Make sure it is installed correctly" - ) + raise ValueError("xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -1376,20 +1222,14 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = ( - convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) ) elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = ( - convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1421,19 +1261,11 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = ( - StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - ) + lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - unet_state_dict = { - f'{k.replace("unet.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("unet.") - } + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict( - unet_, unet_state_dict, adapter_name="default" - ) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, 
adapter_name="default") if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -1445,14 +1277,10 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_ - ) + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) _set_state_dict_into_text_encoder( - lora_state_dict, - prefix="text_encoder_2.", - text_encoder=text_encoder_two_, + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ ) # Make sure the trainable params are in float32. This is again needed since the base models @@ -1475,10 +1303,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. @@ -1493,18 +1318,11 @@ def load_model_hook(models, input_dir): unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: - text_lora_parameters_one = list( - filter(lambda p: p.requires_grad, text_encoder_one.parameters()) - ) - text_lora_parameters_two = list( - filter(lambda p: p.requires_grad, text_encoder_two.parameters()) - ) + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) # Optimization parameters - unet_lora_parameters_with_lr = { - "params": unet_lora_parameters, - "lr": args.learning_rate, - } + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { @@ -1563,9 +1381,7 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy @@ -1635,9 +1451,7 @@ def compute_time_ids(original_size, crops_coords_top_left): def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders, tokenizers, prompt - ) + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds @@ -1646,15 +1460,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. 
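The comment above describes the encode-once optimization: when the text encoders stay frozen and a single --instance_prompt is used for every image, the prompt embeddings are computed once under torch.no_grad() and reused on every step. A minimal sketch of that pattern, with a hypothetical encode_fn standing in for the scripts' encode_prompt:

import torch

@torch.no_grad()
def precompute_prompt_embeddings(encode_fn, prompt, device):
    # Run the frozen text encoders a single time and keep the results on the
    # training device, so no text encoding happens inside the training loop.
    prompt_embeds, pooled_prompt_embeds = encode_fn(prompt)
    return prompt_embeds.to(device), pooled_prompt_embeds.to(device)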
if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( - compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = ( - compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here @@ -1673,12 +1487,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states unet_add_text_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat( - [prompt_embeds, class_prompt_hidden_states], dim=0 - ) - unet_add_text_embeds = torch.cat( - [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0 - ) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1692,9 +1502,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1710,20 +1518,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Prepare everything with our `accelerator`. if args.train_text_encoder: - ( - unet, - text_encoder_one, - text_encoder_two, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - unet, - text_encoder_one, - text_encoder_two, - optimizer, - train_dataloader, - lr_scheduler, + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -1731,9 +1527,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1750,20 +1544,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! 
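The step bookkeeping above, together with the total_batch_size line in the next hunk, is plain arithmetic: optimizer updates per epoch are dataloader batches divided by the accumulation factor, the epoch count is derived from --max_train_steps when it is set, and the effective batch size multiplies the per-device batch by the process count and the accumulation factor. A worked example with hypothetical numbers:

import math

batches_per_epoch = 1000           # len(train_dataloader), hypothetical
gradient_accumulation_steps = 4
train_batch_size = 4               # per device
num_processes = 2                  # e.g. two GPUs

num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 250
max_train_steps = 1500
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)               # 6

# Effective batch size per optimizer step, as reported in the training log.
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps        # 32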
- total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1824,12 +1612,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model( - text_encoder_one - ).text_model.embeddings.requires_grad_(True) - accelerator.unwrap_model( - text_encoder_two - ).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): @@ -1854,17 +1638,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype) else: - latents_mean = latents_mean.to( - device=model_input.device, dtype=model_input.dtype - ) - latents_std = latents_std.to( - device=model_input.device, dtype=model_input.dtype - ) - model_input = ( - (model_input - latents_mean) - * vae.config.scaling_factor - / latents_std - ) + latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) + latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) + model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1874,39 +1650,26 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Sample a random timestep for each image if not args.do_edm_style_training: timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (bsz,), - device=model_input.device, + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device ) timesteps = timesteps.long() else: # in EDM formulation, the model is conditioned on the pre-conditioned noise levels # instead of discrete timesteps, so here we sample indices to get the noise levels # from `scheduler.timesteps` - indices = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,) - ) - timesteps = noise_scheduler.timesteps[indices].to( - device=model_input.device - ) + indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) + timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise( - model_input, noise, timesteps - ) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # For 
EDM-style training, we first obtain the sigmas based on the continuous timesteps. # We then precondition the final model inputs based on these sigmas instead of the timesteps. # Follow: Section 5 of https://arxiv.org/abs/2206.00364. if args.do_edm_style_training: - sigmas = get_sigmas( - timesteps, len(noisy_model_input.shape), noisy_model_input.dtype - ) + sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype) if "EDM" in scheduler_type: - inp_noisy_latents = noise_scheduler.precondition_inputs( - noisy_model_input, sigmas - ) + inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas) else: inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5) @@ -1914,17 +1677,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): add_time_ids = torch.cat( [ compute_time_ids(original_size=s, crops_coords_top_left=c) - for s, c in zip( - batch["original_sizes"], batch["crop_top_lefts"] - ) + for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"]) ] ) # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. if not train_dataset.custom_instance_prompts: - elems_to_repeat_text_embeds = ( - bsz // 2 if args.with_prior_preservation else bsz - ) + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz else: elems_to_repeat_text_embeds = 1 @@ -1932,19 +1691,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if not args.train_text_encoder: unet_added_conditions = { "time_ids": add_time_ids, - "text_embeds": unet_add_text_embeds.repeat( - elems_to_repeat_text_embeds, 1 - ), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), } - prompt_embeds_input = prompt_embeds.repeat( - elems_to_repeat_text_embeds, 1, 1 - ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - ( - inp_noisy_latents - if args.do_edm_style_training - else noisy_model_input - ), + inp_noisy_latents if args.do_edm_style_training else noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, @@ -1959,21 +1710,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_input_ids_list=[tokens_one, tokens_two], ) unet_added_conditions.update( - { - "text_embeds": pooled_prompt_embeds.repeat( - elems_to_repeat_text_embeds, 1 - ) - } - ) - prompt_embeds_input = prompt_embeds.repeat( - elems_to_repeat_text_embeds, 1, 1 + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - ( - inp_noisy_latents - if args.do_edm_style_training - else noisy_model_input - ), + inp_noisy_latents if args.do_edm_style_training else noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, @@ -1986,16 +1727,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # on noised model inputs (before preconditioning) and the sigmas. # Follow: Section 5 of https://arxiv.org/abs/2206.00364. 
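In the non-EDM-scheduler branch above, dividing by ((sigmas**2 + 1) ** 0.5) coincides with the c_in input preconditioner of Karras et al. (Section 5 / Table 1 of arXiv:2206.00364) evaluated at sigma_data = 1. A hedged sketch of that scaling; edm_precondition_input is illustrative, not an API of the schedulers:

import torch

def edm_precondition_input(noisy_latents, sigmas, sigma_data=1.0):
    # c_in(sigma) = 1 / sqrt(sigma^2 + sigma_data^2); with sigma_data = 1 this is
    # exactly the noisy_model_input / ((sigmas**2 + 1) ** 0.5) branch above.
    c_in = 1.0 / torch.sqrt(sigmas**2 + sigma_data**2)
    return c_in * noisy_latents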
if "EDM" in scheduler_type: - model_pred = noise_scheduler.precondition_outputs( - noisy_model_input, model_pred, sigmas - ) + model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas) else: if noise_scheduler.config.prediction_type == "epsilon": model_pred = model_pred * (-sigmas) + noisy_model_input elif noise_scheduler.config.prediction_type == "v_prediction": - model_pred = model_pred * ( - -sigmas / (sigmas**2 + 1) ** 0.5 - ) + (noisy_model_input / (sigmas**2 + 1)) + model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + ( + noisy_model_input / (sigmas**2 + 1) + ) # We are not doing weighting here because it tends result in numerical problems. # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 # There might be other alternatives for weighting as well: @@ -2013,9 +1752,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else noise_scheduler.get_velocity(model_input, noise, timesteps) ) else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. @@ -2025,44 +1762,33 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss if weighting is not None: prior_loss = torch.mean( - ( - weighting.float() - * (model_pred_prior.float() - target_prior.float()) ** 2 - ).reshape(target_prior.shape[0], -1), + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), 1, ) prior_loss = prior_loss.mean() else: - prior_loss = F.mse_loss( - model_pred_prior.float(), - target_prior.float(), - reduction="mean", - ) + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") if args.snr_gamma is None: if weighting is not None: loss = torch.mean( - ( - weighting.float() - * (model_pred.float() - target.float()) ** 2 - ).reshape(target.shape[0], -1), + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape( + target.shape[0], -1 + ), 1, ) loss = loss.mean() else: - loss = F.mse_loss( - model_pred.float(), target.float(), reduction="mean" - ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) base_weight = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": @@ -2072,13 +1798,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Epsilon and sample both use the same loss weights. 
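The base_weight expression above is the Min-SNR-gamma weighting from Section 3.4 of arXiv:2303.09556: clamp the per-timestep SNR at gamma, then divide by the SNR (the epsilon-prediction case shown here). An equivalent, more direct formulation, assuming a hypothetical helper name:

import torch

def min_snr_gamma_weights(snr: torch.Tensor, snr_gamma: float) -> torch.Tensor:
    # Equivalent to torch.stack([snr, gamma * ones]).min(dim=1)[0] / snr above.
    return torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr

The reduction that follows multiplies these weights into an unreduced MSE, averages over all non-batch dimensions per sample, and then averages over the batch.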
mse_loss_weights = base_weight - loss = F.mse_loss( - model_pred.float(), target.float(), reduction="none" - ) - loss = ( - loss.mean(dim=list(range(1, len(loss.shape)))) - * mse_loss_weights - ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() if args.with_prior_preservation: @@ -2088,11 +1809,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain( - unet_lora_parameters, - text_lora_parameters_one, - text_lora_parameters_two, - ) + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) if args.train_text_encoder else unet_lora_parameters ) @@ -2112,36 +1829,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -2153,10 +1858,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if ( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( @@ -2197,9 +1899,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: unet = unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = convert_state_dict_to_diffusers( - get_peft_model_state_dict(unet) - ) + unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) @@ -2221,15 +1921,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) if args.output_kohya_format: - lora_state_dict = load_file( - f"{args.output_dir}/pytorch_lora_weights.safetensors" - ) + lora_state_dict = 
load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) - save_file( - kohya_state_dict, - f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors", - ) + save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors") # Final inference # Load previous pipeline @@ -2254,10 +1949,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - pipeline_args = { - "prompt": args.validation_prompt, - "num_inference_steps": 25, - } + pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} images = log_validation( pipeline, args, diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 4095447d4765..c34024f478c1 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -32,11 +32,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - ProjectConfiguration, - set_seed, -) +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from PIL import Image @@ -45,13 +41,7 @@ from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm -from transformers import ( - CLIPTextModelWithProjection, - CLIPTokenizer, - PretrainedConfig, - T5EncoderModel, - T5TokenizerFast, -) +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast import diffusers from diffusers import ( @@ -61,14 +51,15 @@ StableDiffusion3Pipeline, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import ( + check_min_version, + is_wandb_available, ) -from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -92,10 +83,7 @@ def save_model_card( for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( - { - "text": validation_prompt if validation_prompt else " ", - "output": {"url": f"image_{i}.png"}, - } + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) model_description = f""" @@ -152,22 +140,13 @@ def save_model_card( def load_text_encoders(class_one, class_two, class_three): text_encoder_one = class_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_2", 
revision=args.revision, variant=args.variant ) text_encoder_three = class_three.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_3", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant ) return text_encoder_one, text_encoder_two, text_encoder_three @@ -189,19 +168,12 @@ def log_validation( pipeline.set_progress_bar_config(disable=True) # run inference - generator = ( - torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.seed - else None - ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -212,8 +184,7 @@ def log_validation( tracker.log( { phase_name: [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) @@ -311,12 +282,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="How many times to repeat the training data.", - ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -371,12 +337,7 @@ def parse_args(input_args=None): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument( - "--prior_loss_weight", - type=float, - default=1.0, - help="The weight of prior preservation loss.", - ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -392,9 +353,7 @@ def parse_args(input_args=None): default="sd3-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -424,16 +383,10 @@ def parse_args(input_args=None): help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) parser.add_argument( - "--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.", + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument( - "--sample_batch_size", - type=int, - default=4, - help="Batch size (per device) for sampling images.", + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 
) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -507,10 +460,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", @@ -518,12 +468,7 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -539,16 +484,10 @@ def parse_args(input_args=None): choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], ) parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme.", + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme.", + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." ) parser.add_argument( "--mode_scale", @@ -577,16 +516,10 @@ def parse_args(input_args=None): ) parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." ) parser.add_argument( "--prodigy_beta3", @@ -595,23 +528,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--prodigy_decouple", - type=bool, - default=True, - help="Use AdamW style decoupled weight decay", - ) - parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for unet params", - ) - parser.add_argument( - "--adam_weight_decay_text_encoder", - type=float, - default=1e-03, - help="Weight decay to use for text_encoder", + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -634,20 +554,9 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -701,12 +610,7 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." ), ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") if input_args is not None: args = parser.parse_args(input_args) @@ -717,9 +621,7 @@ def parse_args(input_args=None): raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") if args.dataset_name is not None and args.instance_data_dir is not None: - raise ValueError( - "Specify only one of `--dataset_name` or `--instance_data_dir`" - ) + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -733,13 +635,9 @@ def parse_args(input_args=None): else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn( - "You need not use --class_data_dir without --with_prior_preservation." - ) + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn( - "You need not use --class_prompt without --with_prior_preservation." 
- ) + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") return args @@ -818,17 +716,13 @@ def __init__( # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: - self.custom_instance_prompts.extend( - itertools.repeat(caption, repeats) - ) + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - instance_images = [ - Image.open(path) for path in list(Path(instance_data_root).iterdir()) - ] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None self.instance_images = [] @@ -836,12 +730,8 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ) - train_crop = ( - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - ) + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -862,9 +752,7 @@ def __init__( x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params( - image, (args.resolution, args.resolution) - ) + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -886,14 +774,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - ( - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size) - ), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -918,9 +800,7 @@ def __getitem__(self, index): example["instance_prompt"] = self.instance_prompt if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) if not class_image.mode == "RGB": @@ -1083,8 +963,7 @@ def encode_prompt( ) clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, - (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) @@ -1106,9 +985,7 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1124,9 +1001,7 @@ def main(args): if args.report_to == "wandb": if not 
is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training." - ) + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1154,12 +1029,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = ( - torch.cuda.is_available() or torch.backends.mps.is_available() - ) - torch_dtype = ( - torch.float16 if has_supported_fp16_accelerator else torch.float32 - ) + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1178,26 +1049,19 @@ def main(args): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader( - sample_dataset, batch_size=args.sample_batch_size - ) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process, + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -1258,10 +1122,7 @@ def main(args): variant=args.variant, ) transformer = SD3Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - variant=args.variant, + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) transformer.requires_grad_(True) @@ -1312,26 +1173,16 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): if isinstance(unwrap_model(model), SD3Transformer2DModel): - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "transformer") - ) - elif isinstance( - unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) - ): + unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): hidden_size = unwrap_model(model).config.hidden_size if hidden_size == 768: - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "text_encoder") - ) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) elif hidden_size == 1280: - unwrap_model(model).save_pretrained( - os.path.join(output_dir, "text_encoder_2") - ) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) else: - unwrap_model(model).save_pretrained( - os.path.join(output_dir, 
"text_encoder_3") - ) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3")) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") @@ -1345,39 +1196,27 @@ def load_model_hook(models, input_dir): # load diffusers style into model if isinstance(unwrap_model(model), SD3Transformer2DModel): - load_model = SD3Transformer2DModel.from_pretrained( - input_dir, subfolder="transformer" - ) + load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) - elif isinstance( - unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel) - ): + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): try: - load_model = CLIPTextModelWithProjection.from_pretrained( - input_dir, subfolder="text_encoder" - ) + load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = CLIPTextModelWithProjection.from_pretrained( - input_dir, subfolder="text_encoder_2" - ) + load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = T5EncoderModel.from_pretrained( - input_dir, subfolder="text_encoder_3" - ) + load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - raise ValueError( - f"Couldn't load the model of type: ({type(model)})." - ) + raise ValueError(f"Couldn't load the model of type: ({type(model)}).") else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1393,17 +1232,11 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Optimization parameters - transformer_parameters_with_lr = { - "params": transformer.parameters(), - "lr": args.learning_rate, - } + transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { @@ -1468,9 +1301,7 @@ def load_model_hook(models, input_dir): try: import prodigyopt except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy @@ -1539,15 +1370,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = ( - compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. 
if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = ( - compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers ) # Clear the memory here @@ -1568,12 +1399,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds if args.with_prior_preservation: - prompt_embeds = torch.cat( - [prompt_embeds, class_prompt_hidden_states], dim=0 - ) - pooled_prompt_embeds = torch.cat( - [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0 - ) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: @@ -1590,9 +1417,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -1631,9 +1456,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -1646,20 +1469,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.init_trackers(tracker_name, config=vars(args)) # Train! - total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -1722,9 +1539,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] if args.train_text_encoder: - models_to_accumulate.extend( - [text_encoder_one, text_encoder_two, text_encoder_three] - ) + models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three]) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] @@ -1742,9 +1557,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = ( - model_input - vae.config.shift_factor - ) * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1761,15 +1574,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to( - device=model_input.device - ) + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas( - timesteps, n_dim=model_input.ndim, dtype=model_input.dtype - ) + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual @@ -1783,11 +1592,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] else: prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[ - text_encoder_one, - text_encoder_two, - text_encoder_three, - ], + text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three], tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two, tokens_three], @@ -1807,9 +1612,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss if args.precondition_outputs: @@ -1824,19 +1627,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Compute prior loss prior_loss = torch.mean( - ( - weighting.float() - * (model_pred_prior.float() - target_prior.float()) ** 2 - ).reshape(target_prior.shape[0], -1), + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), 1, ) prior_loss = prior_loss.mean() # Compute regular loss. 
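The SD3 branch uses the rectified-flow interpolation spelled out in its comment, zt = (1 - sigma) * x + sigma * noise, with per-sample sigmas broadcast over the latent dimensions, and both the prior and the regular loss terms share the same per-sample weighted reduction. A small sketch with toy shapes (all tensors below are hypothetical placeholders):

import torch

latents = torch.randn(2, 16, 64, 64)       # stands in for the VAE latents
noise = torch.randn_like(latents)
sigmas = torch.rand(2, 1, 1, 1)            # per-sample and broadcastable, as from get_sigmas(...)

# Clean latents at sigma = 0, pure noise at sigma = 1.
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

# Per-sample weighted MSE: flatten all non-batch dimensions, average per sample,
# then average over the batch. model_pred, target and weighting are placeholders.
model_pred = torch.randn_like(latents)
target = latents
weighting = torch.ones(2, 1, 1, 1)
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
).mean()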
loss = torch.mean( - ( - weighting.float() * (model_pred.float() - target.float()) ** 2 - ).reshape(target.shape[0], -1), + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1, ) loss = loss.mean() @@ -1873,36 +1673,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint") - ] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1]) - ) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = ( - len(checkpoints) - args.checkpoints_total_limit + 1 - ) + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}" - ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint - ) + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1914,18 +1702,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if accelerator.is_main_process: - if ( - args.validation_prompt is not None - and epoch % args.validation_epochs == 0 - ): + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two, text_encoder_three = ( - load_text_encoders( - text_encoder_cls_one, - text_encoder_cls_two, - text_encoder_cls_three, - ) + text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path,