diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 64fd0a6986ed..d54d9f1b2402 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -31,8 +31,6 @@ import numpy as np import torch import torch.nn.functional as F - -# imports of the TokenEmbeddingsHandler class import torch.utils.checkpoint import transformers from accelerate import Accelerator @@ -77,6 +75,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.30.0.dev0") @@ -101,12 +102,12 @@ def save_model_card( repo_id: str, use_dora: bool, images=None, - base_model=str, + base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, token_abstraction_dict=None, - instance_prompt=str, - validation_prompt=str, + instance_prompt: str = None, + validation_prompt: str = None, repo_folder=None, vae_path=None, ): @@ -135,6 +136,14 @@ def save_model_card( diffusers_imports_pivotal = "" diffusers_example_pivotal = "" webui_example_pivotal = "" + license = "" + if "playground" in base_model: + license = """\n + ## License + + Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md). + """ + if train_text_encoder_ti: trigger_str = ( "To trigger image generation of trained concept(or concepts) replace each concept identifier " @@ -223,11 +232,75 @@ def save_model_card( Special VAE used for training: {vae_path}. +{license} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if not args.do_edm_style_training: + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better + # way to condition it. 
Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): @@ -390,6 +463,7 @@ def parse_args(input_args=None): ) parser.add_argument( "--do_edm_style_training", + default=False, action="store_true", help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.", ) @@ -571,7 +645,7 @@ def parse_args(input_args=None): parser.add_argument( "--optimizer", type=str, - default="adamW", + default="AdamW", help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), ) @@ -906,11 +980,6 @@ def __init__( instance_data_root, instance_prompt, class_prompt, - dataset_name, - dataset_config_name, - cache_dir, - image_column, - caption_column, train_text_encoder_ti, class_data_root=None, class_num=None, @@ -929,7 +998,7 @@ def __init__( self.train_text_encoder_ti = train_text_encoder_ti # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset - if dataset_name is not None: + if args.dataset_name is not None: try: from datasets import load_dataset except ImportError: @@ -942,25 +1011,26 @@ def __init__( # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script dataset = load_dataset( - dataset_name, - dataset_config_name, - cache_dir=cache_dir, + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, ) # Preprocessing the datasets. column_names = dataset["train"].column_names # 6. Get the column names for input/target. - if image_column is None: + if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") else: + image_column = args.image_column if image_column not in column_names: raise ValueError( - f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) instance_images = dataset["train"][image_column] - if caption_column is None: + if args.caption_column is None: logger.info( "No caption column provided, defaulting to instance_prompt for all images. 
If your dataset " "contains captions/prompts for the images, make sure to specify the " @@ -968,11 +1038,11 @@ def __init__( ) self.custom_instance_prompts = None else: - if caption_column not in column_names: + if args.caption_column not in column_names: raise ValueError( - f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) - custom_instance_prompts = dataset["train"][caption_column] + custom_instance_prompts = dataset["train"][args.caption_column] # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: @@ -1178,13 +1248,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -1200,9 +1269,16 @@ def main(args): "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." ) + if args.do_edm_style_training and args.snr_gamma is not None: raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) @@ -1215,10 +1291,13 @@ def main(args): kwargs_handlers=[kwargs], ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1246,7 +1325,8 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -1404,6 +1484,12 @@ def main(args): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. 
+ raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) @@ -1508,15 +1594,13 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(unwrap_model(text_encoder_one))): - if args.train_text_encoder: - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(unwrap_model(text_encoder_two))): - if args.train_text_encoder: - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1564,6 +1648,7 @@ def load_model_hook(models, input_dir): ) if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) _set_state_dict_into_text_encoder( @@ -1578,14 +1663,14 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: + if args.allow_tf32 and torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: @@ -1711,12 +1796,7 @@ def load_model_hook(models, input_dir): instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_prompt=args.class_prompt, - dataset_name=args.dataset_name, - dataset_config_name=args.dataset_config_name, - cache_dir=args.cache_dir, - image_column=args.image_column, train_text_encoder_ti=args.train_text_encoder_ti, - caption_column=args.caption_column, class_data_root=args.class_data_dir if args.with_prior_preservation else None, token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, class_num=args.num_class_images, @@ -1740,8 +1820,6 @@ def load_model_hook(models, input_dir): def compute_time_ids(crops_coords_top_left, original_size=None): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - if original_size is None: - original_size = (args.resolution, args.resolution) target_size = (args.resolution, args.resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids]) @@ -1778,7 +1856,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if freeze_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1946,8 +2025,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works if args.train_text_encoder: - text_encoder_one.text_model.embeddings.requires_grad_(True) - text_encoder_two.text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): if pivoted: @@ -2040,7 +2119,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if freeze_text_encoder: unet_added_conditions = { "time_ids": add_time_ids, - # "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), } prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) @@ -2220,10 +2298,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) # create pipeline if freeze_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( @@ -2250,70 +2324,29 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): variant=args.variant, torch_dtype=weight_dtype, ) - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if not args.do_edm_style_training: - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) - - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type) - - with autocast_ctx: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + ) # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) unet = unet.to(torch.float32) unet_lora_layers = 
convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_one = unwrap_model(text_encoder_one) text_encoder_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_one.to(torch.float32)) ) - text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_two = unwrap_model(text_encoder_two) text_encoder_2_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_two.to(torch.float32)) ) @@ -2332,85 +2365,39 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" embedding_handler.save_embeddings(embeddings_path) + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, + pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=True, ) - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if not args.do_edm_style_training: - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # load new tokens - if args.train_text_encoder_ti: - state_dict = load_file(embeddings_path) - all_new_tokens = [] - for key, value in token_abstraction_dict.items(): - all_new_tokens.extend(value) - pipeline.load_textual_inversion( - state_dict["clip_l"], - token=all_new_tokens, - text_encoder=pipeline.text_encoder, - tokenizer=pipeline.tokenizer, - ) - pipeline.load_textual_inversion( - state_dict["clip_g"], - token=all_new_tokens, - text_encoder=pipeline.text_encoder_2, - tokenizer=pipeline.tokenizer_2, - ) - - # run inference - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - # Convert to WebUI format lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) @@ -2430,6 +2417,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): repo_folder=args.output_dir, vae_path=args.pretrained_vae_model_name_or_path, ) + if args.push_to_hub: upload_folder( repo_id=repo_id,
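
Illustrative usage (a minimal sketch, not part of the patch): after training, the LoRA layers are saved under args.output_dir and, when --train_text_encoder_ti is used, the pivotal-tuning embeddings are written to <output_dir>/<output_dir>_emb.safetensors, mirroring the load_lora_weights / load_textual_inversion calls visible in the hunks above. The base model id, the local "output_dir" path, and the inserted tokens <s0>, <s1> below are placeholder assumptions.

import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionXLPipeline

# Base SDXL checkpoint; swap for the model actually used during training.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# LoRA layers saved by the script under args.output_dir.
pipeline.load_lora_weights("output_dir")

# Pivotal-tuning embeddings saved by embedding_handler.save_embeddings(...).
state_dict = load_file("output_dir/output_dir_emb.safetensors")
pipeline.load_textual_inversion(
    state_dict["clip_l"],
    token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder,
    tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"],
    token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2,
    tokenizer=pipeline.tokenizer_2,
)

image = pipeline("a photo of <s0><s1>", num_inference_steps=25).images[0]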