From 92814a33dffbb7c268da3528a9c0a0f1b4e4631e Mon Sep 17 00:00:00 2001 From: mrepetto Date: Tue, 13 Jun 2023 21:03:27 -0300 Subject: [PATCH 01/28] - Added validation parameters - Changed some parameter descriptions to better explain their use. - Fixed a few typos. - Added concept_list parameter for better management of multiple subjects - changed logic for image validation --- examples/dreambooth/train_dreambooth.py | 12 +- .../train_multi_subject_dreambooth.py | 340 +++++++++++++++--- setup.py | 3 +- 3 files changed, 293 insertions(+), 62 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 37b06acb6977..6bfc84ba7d9a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -175,7 +175,8 @@ def log_validation( ) del pipeline - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return images @@ -757,7 +758,7 @@ def main(args): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, - logging_dir=logging_dir, + project_dir=logging_dir, project_config=accelerator_project_config, ) @@ -834,7 +835,10 @@ def main(args): image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) + # Clean up the memory deleting one-time-use variables. del pipeline + del sample_dataloader + del sample_dataset if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -1094,7 +1098,7 @@ def compute_text_embeddings(prompt): args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + # The trackers initialize automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("dreambooth", config=vars(args)) @@ -1266,7 +1270,7 @@ def compute_text_embeddings(prompt): if global_step >= args.max_train_steps: break - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline_args = {} diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index a1016b50e7b2..67729806e2d5 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -1,13 +1,16 @@ import argparse import hashlib import itertools +import json import logging import math -import os import warnings +from os import environ, listdir, makedirs +from os.path import basename, join, normpath from pathlib import Path import datasets +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -23,11 +26,14 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.13.0.dev0") @@ -35,6 +41,100 @@ logger = get_logger(__name__) +def log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch): + logger.info( + f"Logging images to tracker." + ) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{label}_{i}: {validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + +# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings` +# argument is implemented. +def generate_validation_images(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype): + logger.info(f"Running validation images.") + + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if vae is not None: + pipeline_args["vae"] = vae + + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + **pipeline_args, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the + # scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images_sets = [] + for vp, nvi, vnp, vis, vgs in zip(args.validation_prompt, args.validation_number_images, + args.validation_negative_prompt, args.validation_inference_steps, + args.validation_guidance_scale): + images = [] + if vp is not None: + logger.info( + f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, " + f"guidance scale: {vgs}." + ) + + pipeline_args = {"prompt": vp, + "negative_prompt": vnp, + "num_inference_steps": vis, + "guidance_scale": vgs + } + + # run inference + # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a + # time or in small batches + for _ in range(nvi): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0] + images.append(image) + + images_sets.append(images) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images_sets + + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -81,7 +181,7 @@ def parse_args(input_args=None): "--instance_data_dir", type=str, default=None, - required=True, + required=False, help="A folder containing the training data of instance images.", ) parser.add_argument( @@ -95,7 +195,7 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, + required=False, help="The prompt with identifier specifying the instance", ) parser.add_argument( @@ -272,6 +372,46 @@ def parse_args(input_args=None): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `validation_prompt` multiple times: `validation_number_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--validation_number_images", + type=int, + default=4, + help="Number of images that should be generated during validation with the validation parameters given.", + ) + parser.add_argument( + "--validation_negative_prompt", + type=str, + default=None, + help="A negative prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--validation_inference_steps", + type=int, + default=25, + help="Number of inference steps (denoising steps) to run during validation.", + ) + parser.add_argument( + "--validation_guidance_scale", + type=float, + default=7.5, + help="To control how much the image generation process follows the text prompt", + ) parser.add_argument( "--mixed_precision", type=str, @@ -297,27 +437,56 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--concepts_list", + type=str, + default=None, + help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt," + " class_prompt, etc.", + ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if args.concepts_list is None and (args.instance_data_dir is None or args.instance_prompt is None): + raise ValueError("You must specify either instance parameters (data directory, prompt, etc.) or use " + "the `concept_list` parameter and specify them within the file.") + env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("You must specify a data directory for class images.") - if args.class_prompt is None: - raise ValueError("You must specify prompt for class images.") + if args.concepts_list is None: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + raise ValueError("If you are using `concepts_list` parameter define the class data directory within " + "the file.") + if args.class_prompt is not None: + raise ValueError("If you are using `concepts_list` parameter define the class prompt within " + "the file.") else: # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + warnings.warn( + "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`.") if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + warnings.warn( + "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`.") return args @@ -329,14 +498,14 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - tokenizer, - class_data_root=None, - class_prompt=None, - size=512, - center_crop=False, + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -362,6 +531,7 @@ def __init__( self.instance_prompt.append(instance_prompt[i]) self._length += self.num_instance_images[i] + self.class_data_root = None if class_data_root is not None: self.class_data_root.append(Path(class_data_root[i])) self.class_data_root[i].mkdir(parents=True, exist_ok=True) @@ -371,8 +541,6 @@ def __init__( self._length -= self.num_instance_images[i] self._length += self.num_class_images[i] self.class_prompt.append(class_prompt[i]) - else: - self.class_data_root = None self.image_transforms = transforms.Compose( [ @@ -446,7 +614,7 @@ def collate_fn(num_instances, examples, with_prior_preservation=False): class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" def __init__(self, prompt, num_samples): self.prompt = prompt @@ -471,10 +639,14 @@ def main(args): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, - logging_dir=logging_dir, + project_dir=logging_dir, project_config=accelerator_project_config, ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. @@ -484,23 +656,58 @@ def main(args): "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) - # Parse instance and class inputs, and double check that lengths match - instance_data_dir = args.instance_data_dir.split(",") - instance_prompt = args.instance_prompt.split(",") - assert all( - x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] - ), "Instance data dir and prompt inputs are not of the same length." - - if args.with_prior_preservation: - class_data_dir = args.class_data_dir.split(",") - class_prompt = args.class_prompt.split(",") - assert all( - x == len(instance_data_dir) - for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)] - ), "Instance & class data dir or prompt inputs are not of the same length." + instance_data_dir = [] + instance_prompt = [] + class_data_dir = [] if args.with_prior_preservation else None + class_prompt = [] if args.with_prior_preservation else None + if args.concepts_list is not None: + with open(args.concepts_list, "r") as f: + concepts_list = json.load(f) + + if args.validation_steps is not None: + args.validation_prompt = [] + args.validation_number_images = [] + args.validation_negative_prompt = [] + args.validation_inference_steps = [] + args.validation_guidance_scale = [] + for concept in concepts_list: + instance_data_dir.append(concept['instance_data_dir']) + instance_prompt.append(concept['instance_prompt']) + if args.with_prior_preservation: + try: + class_data_dir.append(concept['class_data_dir']) + class_prompt.append(concept['class_prompt']) + except KeyError: + raise KeyError("`class_data_dir` or `class_prompt` not found in concepts_list while using " + "`with_prior_preservation`.") + if args.validation_steps is not None: + args.validation_prompt.append(concept.get('validation_prompt', None)) + args.validation_number_images.append(concept.get('validation_number_images', 4)) + args.validation_negative_prompt.append(concept.get('validation_negative_prompt', None)) + args.validation_inference_steps.append(concept.get('validation_inference_steps', 25)) + args.validation_guidance_scale.append(concept.get('validation_guidance_scale', 7.5)) else: - class_data_dir = args.class_data_dir - class_prompt = args.class_prompt + # Parse instance and class inputs, and double check that lengths match + instance_data_dir = args.instance_data_dir.split(",") + instance_prompt = args.instance_prompt.split(",") + assert all( + x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] + ), "Instance data dir and prompt inputs are not of the same length." + + if args.with_prior_preservation: + class_data_dir = args.class_data_dir.split(",") + class_prompt = args.class_prompt.split(",") + assert all( + x == len(instance_data_dir) + for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)] + ), "Instance & class data dir or prompt inputs are not of the same length." + + if args.validation_steps is not None: + args.validation_prompt = [args.validation_prompt] + args.validation_number_images = [args.validation_number_images] + args.validation_negative_prompt = [args.validation_negative_prompt] + args.validation_inference_steps = [args.validation_inference_steps] + args.validation_guidance_scale = [args.validation_guidance_scale] # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -556,25 +763,28 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images - for i, image in enumerate(images): + for ii, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = ( - class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg" ) image.save(image_filename) + # Clean up the memory deleting one-time-use variables. del pipeline + del sample_dataloader + del sample_dataset if torch.cuda.is_available(): torch.cuda.empty_cache() # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( @@ -627,7 +837,7 @@ def main(args): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs @@ -659,7 +869,7 @@ def main(args): train_dataset = DreamBoothDataset( instance_data_root=instance_data_dir, instance_prompt=instance_prompt, - class_data_root=class_data_dir if args.with_prior_preservation else None, + class_data_root=class_data_dir, class_prompt=class_prompt, tokenizer=tokenizer, size=args.resolution, @@ -721,7 +931,7 @@ def main(args): args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + # The trackers initialize automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("dreambooth", config=vars(args)) @@ -742,10 +952,10 @@ def main(args): # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) + path = basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) + dirs = listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None @@ -757,7 +967,7 @@ def main(args): args.resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) + accelerator.load_state(join(args.output_dir, path)) global_step = int(path.split("-")[1]) resume_global_step = global_step * args.gradient_accumulation_steps @@ -835,18 +1045,34 @@ def main(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + save_path = join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if any(args.validation_prompt) and global_step % args.validation_steps == 0: + images_set = generate_validation_images( + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype + ) + for ix, (images, instance_data_dir, validation_prompt) in enumerate(zip(images_set, args.instance_data_dir, args.validation_prompt)): + if len(images) > 0: + # Get the label from the instance data directory + label = basename(normpath(instance_data_dir)) if validation_prompt is None else f"image{ix}" + log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -855,7 +1081,7 @@ def main(args): if global_step >= args.max_train_steps: break - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline = DiffusionPipeline.from_pretrained( diff --git a/setup.py b/setup.py index a972df80b509..cc3e67fad5ce 100644 --- a/setup.py +++ b/setup.py @@ -69,10 +69,11 @@ import os import re -from distutils.core import Command from setuptools import find_packages, setup +from distutils.core import Command + # IMPORTANT: # 1. all dependencies should be listed here with their version requirements if any From 728f28b96556431da3af772c4481ee8653ff8cfd Mon Sep 17 00:00:00 2001 From: mrepetto Date: Tue, 13 Jun 2023 22:25:59 -0300 Subject: [PATCH 02/28] - Fixed bad logic for class data root directories --- .../multi_subject_dreambooth/train_multi_subject_dreambooth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 67729806e2d5..2406c87e7121 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -515,7 +515,7 @@ def __init__( self.instance_images_path = [] self.num_instance_images = [] self.instance_prompt = [] - self.class_data_root = [] + self.class_data_root = [] if class_data_root is not None else None self.class_images_path = [] self.num_class_images = [] self.class_prompt = [] @@ -531,7 +531,6 @@ def __init__( self.instance_prompt.append(instance_prompt[i]) self._length += self.num_instance_images[i] - self.class_data_root = None if class_data_root is not None: self.class_data_root.append(Path(class_data_root[i])) self.class_data_root[i].mkdir(parents=True, exist_ok=True) From 2f34eb6f8245f7d36e47a160dfbb67e44e661b8a Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 11:53:19 -0300 Subject: [PATCH 03/28] Defaulting validation_steps to None for an easier logic --- .../train_multi_subject_dreambooth.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 2406c87e7121..3a71d99a1a60 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -375,7 +375,7 @@ def parse_args(input_args=None): parser.add_argument( "--validation_steps", type=int, - default=100, + default=None, help=( "Run validation every X steps. Validation consists of running the prompt" " `validation_prompt` multiple times: `validation_number_images`" @@ -663,7 +663,7 @@ def main(args): with open(args.concepts_list, "r") as f: concepts_list = json.load(f) - if args.validation_steps is not None: + if args.validation_steps: args.validation_prompt = [] args.validation_number_images = [] args.validation_negative_prompt = [] @@ -679,7 +679,7 @@ def main(args): except KeyError: raise KeyError("`class_data_dir` or `class_prompt` not found in concepts_list while using " "`with_prior_preservation`.") - if args.validation_steps is not None: + if args.validation_steps: args.validation_prompt.append(concept.get('validation_prompt', None)) args.validation_number_images.append(concept.get('validation_number_images', 4)) args.validation_negative_prompt.append(concept.get('validation_negative_prompt', None)) @@ -701,7 +701,7 @@ def main(args): for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)] ), "Instance & class data dir or prompt inputs are not of the same length." - if args.validation_steps is not None: + if args.validation_steps: args.validation_prompt = [args.validation_prompt] args.validation_number_images = [args.validation_number_images] args.validation_negative_prompt = [args.validation_negative_prompt] @@ -1057,7 +1057,7 @@ def main(args): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if any(args.validation_prompt) and global_step % args.validation_steps == 0: + if args.validation_steps and any(args.validation_prompt) and global_step % args.validation_steps == 0: images_set = generate_validation_images( text_encoder, tokenizer, @@ -1067,7 +1067,8 @@ def main(args): accelerator, weight_dtype ) - for ix, (images, instance_data_dir, validation_prompt) in enumerate(zip(images_set, args.instance_data_dir, args.validation_prompt)): + for ix, (images, instance_data_dir, validation_prompt) in enumerate( + zip(images_set, args.instance_data_dir, args.validation_prompt)): if len(images) > 0: # Get the label from the instance data directory label = basename(normpath(instance_data_dir)) if validation_prompt is None else f"image{ix}" From 1f007a787147906370933e9fc661fdebcbc4aff0 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 14:15:17 -0300 Subject: [PATCH 04/28] Fixed multiple validation prompts --- .../train_multi_subject_dreambooth.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 3a71d99a1a60..c8e3d7fdbfa6 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -702,11 +702,22 @@ def main(args): ), "Instance & class data dir or prompt inputs are not of the same length." if args.validation_steps: - args.validation_prompt = [args.validation_prompt] - args.validation_number_images = [args.validation_number_images] - args.validation_negative_prompt = [args.validation_negative_prompt] - args.validation_inference_steps = [args.validation_inference_steps] - args.validation_guidance_scale = [args.validation_guidance_scale] + validation_prompts = args.validation_prompt.split(",") + num_of_validation_prompts = len(validation_prompts) + args.validation_prompt = validation_prompts + args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts + if args.validation_negative_prompt: + negative_validation_prompts = args.validation_negative_prompt.split(",") + while len(negative_validation_prompts) < validation_prompts: + negative_validation_prompts.append(None) + args.validation_negative_prompt = negative_validation_prompts + else: + args.validation_negative_prompt = [None] * num_of_validation_prompts + + assert num_of_validation_prompts == len(negative_validation_prompts), \ + "The length of negative prompts for validation is greater than the prompt." + args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts + args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts # Make one log on every process with the configuration for debugging. logging.basicConfig( From c5756c4d31dfc69642c7180925bc639cd60995d4 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 14:23:30 -0300 Subject: [PATCH 05/28] Fixed bug on validation negative prompt --- .../train_multi_subject_dreambooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index c8e3d7fdbfa6..1196ff77e719 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -706,13 +706,13 @@ def main(args): num_of_validation_prompts = len(validation_prompts) args.validation_prompt = validation_prompts args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts + + negative_validation_prompts = [None] * num_of_validation_prompts if args.validation_negative_prompt: negative_validation_prompts = args.validation_negative_prompt.split(",") while len(negative_validation_prompts) < validation_prompts: negative_validation_prompts.append(None) - args.validation_negative_prompt = negative_validation_prompts - else: - args.validation_negative_prompt = [None] * num_of_validation_prompts + args.validation_negative_prompt = negative_validation_prompts assert num_of_validation_prompts == len(negative_validation_prompts), \ "The length of negative prompts for validation is greater than the prompt." From f171686bf5195831dcf8033455c431dc72c73f00 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 16:16:57 -0300 Subject: [PATCH 06/28] Changed validation logic for tracker. --- .../train_multi_subject_dreambooth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 1196ff77e719..dcdd61a0bd44 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -43,7 +43,7 @@ def log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch): logger.info( - f"Logging images to tracker." + f"Logging images to tracker for validation prompt: {validation_prompt}." ) for tracker in accelerator.trackers: @@ -1078,11 +1078,11 @@ def main(args): accelerator, weight_dtype ) - for ix, (images, instance_data_dir, validation_prompt) in enumerate( - zip(images_set, args.instance_data_dir, args.validation_prompt)): + for images, instance_data_dir, validation_prompt in zip(images_set, args.instance_data_dir, + args.validation_prompt): if len(images) > 0: # Get the label from the instance data directory - label = basename(normpath(instance_data_dir)) if validation_prompt is None else f"image{ix}" + label = basename(normpath(instance_data_dir)) log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} From c23faef7d27e6114e97ec4bf69a340c4870191fe Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 17:11:55 -0300 Subject: [PATCH 07/28] Added uuid for validation image labeling --- .../train_multi_subject_dreambooth.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index dcdd61a0bd44..af06402a2402 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -4,9 +4,10 @@ import json import logging import math +import uuid import warnings from os import environ, listdir, makedirs -from os.path import basename, join, normpath +from os.path import basename, join from pathlib import Path import datasets @@ -1078,11 +1079,9 @@ def main(args): accelerator, weight_dtype ) - for images, instance_data_dir, validation_prompt in zip(images_set, args.instance_data_dir, - args.validation_prompt): + for images, validation_prompt in zip(images_set, args.validation_prompt): if len(images) > 0: - # Get the label from the instance data directory - label = basename(normpath(instance_data_dir)) + label = str(uuid.uuid1())[:8] # generate an id for different set of images log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} From d1885dc4bcf50aedabb146dd38e4df21ffb465ca Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 17:20:43 -0300 Subject: [PATCH 08/28] Fix error when comparing validation prompts and validation negative prompts --- .../multi_subject_dreambooth/train_multi_subject_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index af06402a2402..e47c0da0139f 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -711,7 +711,7 @@ def main(args): negative_validation_prompts = [None] * num_of_validation_prompts if args.validation_negative_prompt: negative_validation_prompts = args.validation_negative_prompt.split(",") - while len(negative_validation_prompts) < validation_prompts: + while len(negative_validation_prompts) < num_of_validation_prompts: negative_validation_prompts.append(None) args.validation_negative_prompt = negative_validation_prompts From f1a5f8f72f09e6296c9d5929e8a92e24cbd5b81b Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 17:34:37 -0300 Subject: [PATCH 09/28] Improved error message when negative prompts for validation are more than the number of prompts --- .../multi_subject_dreambooth/train_multi_subject_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index e47c0da0139f..ece318f015d6 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -716,7 +716,7 @@ def main(args): args.validation_negative_prompt = negative_validation_prompts assert num_of_validation_prompts == len(negative_validation_prompts), \ - "The length of negative prompts for validation is greater than the prompt." + "The length of negative prompts for validation is greater than the number of validation prompts." args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts From 6eb258784b5b5d7104d49d4a218c52342326ea36 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 18:02:45 -0300 Subject: [PATCH 10/28] - Changed image tracking number from epoch to global_step - Added Typing for functions --- .../train_multi_subject_dreambooth.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index ece318f015d6..6ccee2498972 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -9,6 +9,7 @@ from os import environ, listdir, makedirs from os.path import basename, join from pathlib import Path +from typing import List import datasets import numpy as np @@ -42,7 +43,8 @@ logger = get_logger(__name__) -def log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch): +def log_validation_images_to_tracker(images: List[np.array], label: str, validation_prompt: str, + accelerator: Accelerator, epoch: int): logger.info( f"Logging images to tracker for validation prompt: {validation_prompt}." ) @@ -63,7 +65,8 @@ def log_validation_images_to_tracker(images, label, validation_prompt, accelerat # TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings` # argument is implemented. -def generate_validation_images(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype): +def generate_validation_images(text_encoder: object, tokenizer: object, unet: object, vae: object, + arguments: argparse.Namespace, accelerator: Accelerator, weight_dtype: str): logger.info(f"Running validation images.") pipeline_args = {} @@ -76,10 +79,10 @@ def generate_validation_images(text_encoder, tokenizer, unet, vae, args, acceler # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, + arguments.pretrained_model_name_or_path, tokenizer=tokenizer, unet=accelerator.unwrap_model(unet), - revision=args.revision, + revision=arguments.revision, torch_dtype=weight_dtype, **pipeline_args, ) @@ -100,12 +103,12 @@ def generate_validation_images(text_encoder, tokenizer, unet, vae, args, acceler pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed) images_sets = [] - for vp, nvi, vnp, vis, vgs in zip(args.validation_prompt, args.validation_number_images, - args.validation_negative_prompt, args.validation_inference_steps, - args.validation_guidance_scale): + for vp, nvi, vnp, vis, vgs in zip(arguments.validation_prompt, arguments.validation_number_images, + arguments.validation_negative_prompt, arguments.validation_inference_steps, + arguments.validation_guidance_scale): images = [] if vp is not None: logger.info( @@ -1082,7 +1085,7 @@ def main(args): for images, validation_prompt in zip(images_set, args.validation_prompt): if len(images) > 0: label = str(uuid.uuid1())[:8] # generate an id for different set of images - log_validation_images_to_tracker(images, label, validation_prompt, accelerator, epoch) + log_validation_images_to_tracker(images, label, validation_prompt, accelerator, global_step) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From e2e7f2593decb9ba9c5e3482af2d1592cdd187d7 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 18:57:48 -0300 Subject: [PATCH 11/28] Added some validations more when using concept_list parameter and the regular ones. --- .../train_multi_subject_dreambooth.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 6ccee2498972..ce142eef38c9 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -458,37 +458,46 @@ def parse_args(input_args=None): " class_prompt, etc.", ) - if input_args is not None: + if input_args: args = parser.parse_args(input_args) else: args = parser.parse_args() - if args.concepts_list is None and (args.instance_data_dir is None or args.instance_prompt is None): + if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt): raise ValueError("You must specify either instance parameters (data directory, prompt, etc.) or use " "the `concept_list` parameter and specify them within the file.") + + if args.concepts_list: + if args.instance_prompt: + raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " + "the file.") + if args.instance_data_dir: + raise ValueError("If you are using `concepts_list` parameter, define the instance within " + "the file.") + env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank if args.with_prior_preservation: - if args.concepts_list is None: - if args.class_data_dir is None: + if not args.concepts_list: + if not args.class_data_dir: raise ValueError("You must specify a data directory for class images.") - if args.class_prompt is None: + if not args.class_prompt: raise ValueError("You must specify prompt for class images.") else: - if args.class_data_dir is not None: - raise ValueError("If you are using `concepts_list` parameter define the class data directory within " - "the file.") - if args.class_prompt is not None: - raise ValueError("If you are using `concepts_list` parameter define the class prompt within " - "the file.") + if args.class_data_dir: + raise ValueError(f"If you are using `concepts_list` parameter, define the class data directory within " + f"the file.") + if args.class_prompt: + raise ValueError(f"If you are using `concepts_list` parameter, define the class prompt within " + f"the file.") else: # logger is not available yet - if args.class_data_dir is not None: + if not args.class_data_dir: warnings.warn( "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`.") - if args.class_prompt is not None: + if not args.class_prompt: warnings.warn( "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`.") @@ -663,7 +672,7 @@ def main(args): instance_prompt = [] class_data_dir = [] if args.with_prior_preservation else None class_prompt = [] if args.with_prior_preservation else None - if args.concepts_list is not None: + if args.concepts_list: with open(args.concepts_list, "r") as f: concepts_list = json.load(f) @@ -673,6 +682,7 @@ def main(args): args.validation_negative_prompt = [] args.validation_inference_steps = [] args.validation_guidance_scale = [] + for concept in concepts_list: instance_data_dir.append(concept['instance_data_dir']) instance_prompt.append(concept['instance_prompt']) From 514494441237fe97ee5523685b53fdbc1083365d Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 19:01:19 -0300 Subject: [PATCH 12/28] Fixed error message --- .../train_multi_subject_dreambooth.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index ce142eef38c9..a8d2cf865ae6 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -469,10 +469,9 @@ def parse_args(input_args=None): if args.concepts_list: if args.instance_prompt: - raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " - "the file.") + raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.") if args.instance_data_dir: - raise ValueError("If you are using `concepts_list` parameter, define the instance within " + raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " "the file.") env_local_rank = int(environ.get("LOCAL_RANK", -1)) From 04f45183898d9d0ac090d2e529e66f1ac424b7f7 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 19:13:09 -0300 Subject: [PATCH 13/28] Added more validations for validation parameters --- .../train_multi_subject_dreambooth.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index a8d2cf865ae6..0d695d0e8463 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -474,6 +474,12 @@ def parse_args(input_args=None): raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " "the file.") + if args.validation_steps: + if args.validation_prompt or args.validation_negative_prompt or args.validation_guidance_scale \ + or args.validation_number_images or args.validation_inference_steps: + raise ValueError("If you are using `concepts_list` parameter, define all validation parameters for " + "each subject within the file.") + env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank From aa5bb74e0dc4a26e8e9c327e893a7ed62a996c73 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 19:36:58 -0300 Subject: [PATCH 14/28] Improved messaging for errors --- .../train_multi_subject_dreambooth.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 0d695d0e8463..682f112dd454 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -381,40 +381,46 @@ def parse_args(input_args=None): type=int, default=None, help=( - "Run validation every X steps. Validation consists of running the prompt" - " `validation_prompt` multiple times: `validation_number_images`" - " and logging the images." + "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` " + "multiple times (`validation_number_images`) and logging the images." ), ) parser.add_argument( "--validation_prompt", type=str, default=None, - help="A prompt that is used during validation to verify that the model is learning.", + help="A prompt that is used during validation to verify that the model is learning. You can use commas to " + "define multiple negative prompts. This parameter can be defined also within the file given by " + "`concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_number_images", type=int, default=4, - help="Number of images that should be generated during validation with the validation parameters given.", + help="Number of images that should be generated during validation with the validation parameters given. This " + "can be defined within the file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_negative_prompt", type=str, default=None, - help="A negative prompt that is used during validation to verify that the model is learning.", + help="A negative prompt that is used during validation to verify that the model is learning. You can use commas" + " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can " + "be defined also within the file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_inference_steps", type=int, default=25, - help="Number of inference steps (denoising steps) to run during validation.", + help="Number of inference steps (denoising steps) to run during validation. This can be defined within the " + "file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_guidance_scale", type=float, default=7.5, - help="To control how much the image generation process follows the text prompt", + help="To control how much the image generation process follows the text prompt. This can be defined within the " + "file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--mixed_precision", @@ -477,8 +483,12 @@ def parse_args(input_args=None): if args.validation_steps: if args.validation_prompt or args.validation_negative_prompt or args.validation_guidance_scale \ or args.validation_number_images or args.validation_inference_steps: - raise ValueError("If you are using `concepts_list` parameter, define all validation parameters for " - "each subject within the file.") + raise ValueError("If you are using `concepts_list` parameter, define validation parameters for " + "each subject within the file:\n - `validation_prompt`." + "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." + "\n - `validation_number_images`.\n - `validation_prompt`." + "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " + "that needs to be defined outside the file.") env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: From d01502b831540f75e10f749b5549a8d534e94027 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 20:00:13 -0300 Subject: [PATCH 15/28] Fixed validation error for parameters with default values --- .../train_multi_subject_dreambooth.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 682f112dd454..3f763f20b7f3 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -479,16 +479,13 @@ def parse_args(input_args=None): if args.instance_data_dir: raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " "the file.") - - if args.validation_steps: - if args.validation_prompt or args.validation_negative_prompt or args.validation_guidance_scale \ - or args.validation_number_images or args.validation_inference_steps: - raise ValueError("If you are using `concepts_list` parameter, define validation parameters for " - "each subject within the file:\n - `validation_prompt`." - "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." - "\n - `validation_number_images`.\n - `validation_prompt`." - "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " - "that needs to be defined outside the file.") + if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt): + raise ValueError("If you are using `concepts_list` parameter, define validation parameters for " + "each subject within the file:\n - `validation_prompt`." + "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." + "\n - `validation_number_images`.\n - `validation_prompt`." + "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " + "that needs to be defined outside the file.") env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: From 29be7860d2925f8cf36c287b3975b4f47a1c9fee Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 20:16:44 -0300 Subject: [PATCH 16/28] - Added train step to image name for validation - reformatted code --- .../train_multi_subject_dreambooth.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 3f763f20b7f3..77a858a57cb6 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -36,7 +36,6 @@ if is_wandb_available(): import wandb - # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.13.0.dev0") @@ -57,7 +56,8 @@ def log_validation_images_to_tracker(images: List[np.array], label: str, validat tracker.log( { "validation": [ - wandb.Image(image, caption=f"{label}_{i}: {validation_prompt}") for i, image in enumerate(images) + wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}") for i, image in + enumerate(images) ] } ) @@ -103,7 +103,8 @@ def generate_validation_images(text_encoder: object, tokenizer: object, unet: ob pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - generator = None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed) + generator = None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed( + arguments.seed) images_sets = [] for vp, nvi, vnp, vis, vgs in zip(arguments.validation_prompt, arguments.validation_number_images, @@ -1090,11 +1091,12 @@ def main(args): if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: - save_path = join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") + save_path = join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") - if args.validation_steps and any(args.validation_prompt) and global_step % args.validation_steps == 0: + if args.validation_steps and any(args.validation_prompt) and \ + global_step % args.validation_steps == 0: images_set = generate_validation_images( text_encoder, tokenizer, @@ -1107,7 +1109,8 @@ def main(args): for images, validation_prompt in zip(images_set, args.validation_prompt): if len(images) > 0: label = str(uuid.uuid1())[:8] # generate an id for different set of images - log_validation_images_to_tracker(images, label, validation_prompt, accelerator, global_step) + log_validation_images_to_tracker(images, label, validation_prompt, accelerator, + global_step) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From e32515001f036adc60c2b3deddd2846cac8f4891 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 20:38:34 -0300 Subject: [PATCH 17/28] - Added train step to image's name for validation - reformatted code --- .../train_multi_subject_dreambooth.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 77a858a57cb6..c173a370d1e4 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -22,6 +22,8 @@ from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from PIL import Image +from torch import dtype +from torch.nn import Module from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm @@ -65,8 +67,8 @@ def log_validation_images_to_tracker(images: List[np.array], label: str, validat # TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings` # argument is implemented. -def generate_validation_images(text_encoder: object, tokenizer: object, unet: object, vae: object, - arguments: argparse.Namespace, accelerator: Accelerator, weight_dtype: str): +def generate_validation_images(text_encoder: Module, tokenizer: Module, unet: Module, vae: Module, + arguments: argparse.Namespace, accelerator: Accelerator, weight_dtype: dtype): logger.info(f"Running validation images.") pipeline_args = {} @@ -799,9 +801,9 @@ def main(args): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process - ): + for example in tqdm(sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process): images = pipeline(example["prompt"]).images for ii, image in enumerate(images): @@ -829,6 +831,7 @@ def main(args): ).repo_id # Load the tokenizer + tokenizer = None if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: @@ -1035,24 +1038,24 @@ def main(args): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() + time_steps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + time_steps = time_steps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, time_steps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") From 1338dcd23ec896ea2ac8351cd614a3cc0190ca28 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:18:26 -0300 Subject: [PATCH 18/28] Updated README.md file. --- .../multi_subject_dreambooth/README.md | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md index cf7dd31d0797..d1a7705cfebb 100644 --- a/examples/research_projects/multi_subject_dreambooth/README.md +++ b/examples/research_projects/multi_subject_dreambooth/README.md @@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used. +**Important**: New parameters are added to the script, making possible to validate the progress of the training by +generating images at specified steps. Taking also into account that a comma separated list in a text field for a prompt +it's never a good idea (simply because it is very common in prompts to have them as part of a regular text) we +introduce the `concept_list` parameter: allowing to specify a json-like file where you can define the different +configuration for each subject that you want to train. + +An example of how to generate the file: +```python +import json + +# here we are using parameters for prior-preservation and validation as well. +concepts_list = [ + { + "instance_prompt": "drawing of a t@y meme", + "class_prompt": "drawing of a meme", + "instance_data_dir": "/some_folder/meme_toy", + "class_data_dir": "/data/meme", + "validation_prompt": "drawing of a t@y meme about football in Uruguay", + "validation_negative_prompt": "black and white" + }, + { + "instance_prompt": "drawing of a sks sir", + "class_prompt": "drawing of a sir", + "instance_data_dir": "/some_other_folder/sir_sks", + "class_data_dir": "/data/sir", + "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest", + "validation_negative_prompt": "an old man", + "validation_guidance_scale": 20, + "validation_number_images": 3, + "validation_inference_steps": 10 + } +] + +with open("concepts_list.json", "w") as f: + json.dump(concepts_list, f, indent=4) +``` +And then just point to the file when executing the script: + +```bash +# exports... +accelerate launch train_multi_subject_dreambooth.py \ +# more parameters... +--concepts_list="concepts_list.json" +``` + +You can use the helper from the script to get a better sense of each parameter. + ### Inference Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. From e0dfaa44645394b07300af16efcc6b558c065b5c Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:20:28 -0300 Subject: [PATCH 19/28] reverted back original script of train_dreambooth.py --- examples/dreambooth/train_dreambooth.py | 93 +++++++++++++++++++------ 1 file changed, 70 insertions(+), 23 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 6bfc84ba7d9a..85b2534c6e0d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -20,6 +20,7 @@ import logging import math import os +import shutil import warnings from pathlib import Path @@ -58,7 +59,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.17.0.dev0") +check_min_version("0.18.0.dev0") logger = get_logger(__name__) @@ -114,16 +115,17 @@ def log_validation( pipeline_args = {} - if text_encoder is not None: - pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) - if vae is not None: pipeline_args["vae"] = vae + if text_encoder is not None: + text_encoder = accelerator.unwrap_model(text_encoder) + # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, + text_encoder=text_encoder, unet=accelerator.unwrap_model(unet), revision=args.revision, torch_dtype=weight_dtype, @@ -156,10 +158,16 @@ def log_validation( # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -175,8 +183,7 @@ def log_validation( ) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + torch.cuda.empty_cache() return images @@ -526,6 +533,19 @@ def parse_args(input_args=None): parser.add_argument( "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -752,13 +772,12 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, - project_dir=logging_dir, project_config=accelerator_project_config, ) @@ -835,10 +854,7 @@ def main(args): image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) - # Clean up the memory deleting one-time-use variables. del pipeline - del sample_dataloader - del sample_dataset if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -1075,8 +1091,8 @@ def compute_text_embeddings(prompt): unet, optimizer, train_dataloader, lr_scheduler ) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -1098,7 +1114,7 @@ def compute_text_embeddings(prompt): args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. - # The trackers initialize automatically on the main process. + # The trackers initializes automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("dreambooth", config=vars(args)) @@ -1173,7 +1189,7 @@ def compute_text_embeddings(prompt): ) else: noise = torch.randn_like(model_input) - bsz = model_input.shape[0] + bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device @@ -1195,8 +1211,18 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + # Predict the noise residual - model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) @@ -1243,12 +1269,33 @@ def compute_text_embeddings(prompt): global_step += 1 if accelerator.is_main_process: - images = [] if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + images = [] + if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( text_encoder, @@ -1270,7 +1317,7 @@ def compute_text_embeddings(prompt): if global_step >= args.max_train_steps: break - # Create the pipeline using the trained modules and save it. + # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline_args = {} @@ -1325,4 +1372,4 @@ def compute_text_embeddings(prompt): if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file From 22a202d10f6347b5bb1d4f0a6913c29b289f89c9 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:49:46 -0300 Subject: [PATCH 20/28] reverted back original script of train_dreambooth.py --- examples/dreambooth/train_dreambooth.py | 81 +++++-------------------- 1 file changed, 15 insertions(+), 66 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 85b2534c6e0d..0e79c1348760 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -20,7 +20,6 @@ import logging import math import os -import shutil import warnings from pathlib import Path @@ -59,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.18.0.dev0") +check_min_version("0.17.0.dev0") logger = get_logger(__name__) @@ -115,17 +114,16 @@ def log_validation( pipeline_args = {} + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + if vae is not None: pipeline_args["vae"] = vae - if text_encoder is not None: - text_encoder = accelerator.unwrap_model(text_encoder) - # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, - text_encoder=text_encoder, unet=accelerator.unwrap_model(unet), revision=args.revision, torch_dtype=weight_dtype, @@ -158,16 +156,10 @@ def log_validation( # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - if args.validation_images is None: - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) - else: - for image in args.validation_images: - image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -533,19 +525,6 @@ def parse_args(input_args=None): parser.add_argument( "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" ) - parser.add_argument( - "--validation_images", - required=False, - default=None, - nargs="+", - help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", - ) - parser.add_argument( - "--class_labels_conditioning", - required=False, - default=None, - help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", - ) if input_args is not None: args = parser.parse_args(input_args) @@ -772,12 +751,13 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, + logging_dir=logging_dir, project_config=accelerator_project_config, ) @@ -1091,8 +1071,8 @@ def compute_text_embeddings(prompt): unet, optimizer, train_dataloader, lr_scheduler ) - # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -1189,7 +1169,7 @@ def compute_text_embeddings(prompt): ) else: noise = torch.randn_like(model_input) - bsz, channels, height, width = model_input.shape + bsz = model_input.shape[0] # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device @@ -1211,18 +1191,8 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if accelerator.unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) - - if args.class_labels_conditioning == "timesteps": - class_labels = timesteps - else: - class_labels = None - # Predict the noise residual - model_pred = unet( - noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels - ).sample + model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) @@ -1269,33 +1239,12 @@ def compute_text_embeddings(prompt): global_step += 1 if accelerator.is_main_process: + images = [] if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - images = [] - if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( text_encoder, From 3c75d2212ff48721557e819c145bb2d15af8fd3d Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:55:19 -0300 Subject: [PATCH 21/28] left one blank line at the eof --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0e79c1348760..37b06acb6977 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1321,4 +1321,4 @@ def compute_text_embeddings(prompt): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) From f7aa125f6716a8f01adb3e6c4c190a65faa01924 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:56:51 -0300 Subject: [PATCH 22/28] reverted back setup.py --- setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cc3e67fad5ce..684487a06554 100644 --- a/setup.py +++ b/setup.py @@ -69,10 +69,8 @@ import os import re - -from setuptools import find_packages, setup - from distutils.core import Command +from setuptools import find_packages, setup # IMPORTANT: From 3767a628092e9caef600cafebd9a093df636bdd8 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Thu, 15 Jun 2023 21:57:49 -0300 Subject: [PATCH 23/28] reverted back setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 684487a06554..a972df80b509 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ import os import re from distutils.core import Command + from setuptools import find_packages, setup From 6f0d794064875a2eacb878cc01a0d18f64e9abf9 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Fri, 16 Jun 2023 12:30:22 -0300 Subject: [PATCH 24/28] added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter. --- .../train_multi_subject_dreambooth.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index c173a370d1e4..41d195f39446 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -701,6 +701,7 @@ def main(args): for concept in concepts_list: instance_data_dir.append(concept['instance_data_dir']) instance_prompt.append(concept['instance_prompt']) + if args.with_prior_preservation: try: class_data_dir.append(concept['class_data_dir']) @@ -708,6 +709,14 @@ def main(args): except KeyError: raise KeyError("`class_data_dir` or `class_prompt` not found in concepts_list while using " "`with_prior_preservation`.") + else: + if 'class_data_dir' in concept: + warnings.warn("Ignoring `class_data_dir` key, to use it you need to enable " + "`with_prior_preservation`.") + if 'class_prompt' in concept: + warnings.warn("Ignoring `class_prompt` key, to use it you need to enable " + "`with_prior_preservation`.") + if args.validation_steps: args.validation_prompt.append(concept.get('validation_prompt', None)) args.validation_number_images.append(concept.get('validation_number_images', 4)) From ef99e82ad0b21127cc02d2153be043db7a4d9a87 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Wed, 28 Jun 2023 12:22:05 -0300 Subject: [PATCH 25/28] Ran black formatter. --- .../train_multi_subject_dreambooth.py | 205 ++++++++++-------- 1 file changed, 116 insertions(+), 89 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 41d195f39446..cdc45249c8b4 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -17,11 +17,11 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers +from PIL import Image from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder -from PIL import Image from torch import dtype from torch.nn import Module from torch.utils.data import Dataset @@ -30,7 +30,13 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -44,11 +50,10 @@ logger = get_logger(__name__) -def log_validation_images_to_tracker(images: List[np.array], label: str, validation_prompt: str, - accelerator: Accelerator, epoch: int): - logger.info( - f"Logging images to tracker for validation prompt: {validation_prompt}." - ) +def log_validation_images_to_tracker( + images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int +): + logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.") for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -58,8 +63,8 @@ def log_validation_images_to_tracker(images: List[np.array], label: str, validat tracker.log( { "validation": [ - wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}") for i, image in - enumerate(images) + wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}") + for i, image in enumerate(images) ] } ) @@ -67,8 +72,15 @@ def log_validation_images_to_tracker(images: List[np.array], label: str, validat # TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings` # argument is implemented. -def generate_validation_images(text_encoder: Module, tokenizer: Module, unet: Module, vae: Module, - arguments: argparse.Namespace, accelerator: Accelerator, weight_dtype: dtype): +def generate_validation_images( + text_encoder: Module, + tokenizer: Module, + unet: Module, + vae: Module, + arguments: argparse.Namespace, + accelerator: Accelerator, + weight_dtype: dtype, +): logger.info(f"Running validation images.") pipeline_args = {} @@ -105,13 +117,18 @@ def generate_validation_images(text_encoder: Module, tokenizer: Module, unet: Mo pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - generator = None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed( - arguments.seed) + generator = ( + None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed) + ) images_sets = [] - for vp, nvi, vnp, vis, vgs in zip(arguments.validation_prompt, arguments.validation_number_images, - arguments.validation_negative_prompt, arguments.validation_inference_steps, - arguments.validation_guidance_scale): + for vp, nvi, vnp, vis, vgs in zip( + arguments.validation_prompt, + arguments.validation_number_images, + arguments.validation_negative_prompt, + arguments.validation_inference_steps, + arguments.validation_guidance_scale, + ): images = [] if vp is not None: logger.info( @@ -119,11 +136,7 @@ def generate_validation_images(text_encoder: Module, tokenizer: Module, unet: Mo f"guidance scale: {vgs}." ) - pipeline_args = {"prompt": vp, - "negative_prompt": vnp, - "num_inference_steps": vis, - "guidance_scale": vgs - } + pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs} # run inference # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a @@ -393,37 +406,37 @@ def parse_args(input_args=None): type=str, default=None, help="A prompt that is used during validation to verify that the model is learning. You can use commas to " - "define multiple negative prompts. This parameter can be defined also within the file given by " - "`concepts_list` parameter in the respective subject.", + "define multiple negative prompts. This parameter can be defined also within the file given by " + "`concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_number_images", type=int, default=4, help="Number of images that should be generated during validation with the validation parameters given. This " - "can be defined within the file given by `concepts_list` parameter in the respective subject.", + "can be defined within the file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_negative_prompt", type=str, default=None, help="A negative prompt that is used during validation to verify that the model is learning. You can use commas" - " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can " - "be defined also within the file given by `concepts_list` parameter in the respective subject.", + " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can " + "be defined also within the file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_inference_steps", type=int, default=25, help="Number of inference steps (denoising steps) to run during validation. This can be defined within the " - "file given by `concepts_list` parameter in the respective subject.", + "file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--validation_guidance_scale", type=float, default=7.5, help="To control how much the image generation process follows the text prompt. This can be defined within the " - "file given by `concepts_list` parameter in the respective subject.", + "file given by `concepts_list` parameter in the respective subject.", ) parser.add_argument( "--mixed_precision", @@ -464,7 +477,7 @@ def parse_args(input_args=None): type=str, default=None, help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt," - " class_prompt, etc.", + " class_prompt, etc.", ) if input_args: @@ -473,22 +486,27 @@ def parse_args(input_args=None): args = parser.parse_args() if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt): - raise ValueError("You must specify either instance parameters (data directory, prompt, etc.) or use " - "the `concept_list` parameter and specify them within the file.") + raise ValueError( + "You must specify either instance parameters (data directory, prompt, etc.) or use " + "the `concept_list` parameter and specify them within the file." + ) if args.concepts_list: if args.instance_prompt: raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.") if args.instance_data_dir: - raise ValueError("If you are using `concepts_list` parameter, define the instance data directory within " - "the file.") + raise ValueError( + "If you are using `concepts_list` parameter, define the instance data directory within " "the file." + ) if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt): - raise ValueError("If you are using `concepts_list` parameter, define validation parameters for " - "each subject within the file:\n - `validation_prompt`." - "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." - "\n - `validation_number_images`.\n - `validation_prompt`." - "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " - "that needs to be defined outside the file.") + raise ValueError( + "If you are using `concepts_list` parameter, define validation parameters for " + "each subject within the file:\n - `validation_prompt`." + "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`." + "\n - `validation_number_images`.\n - `validation_prompt`." + "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one " + "that needs to be defined outside the file." + ) env_local_rank = int(environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -502,19 +520,23 @@ def parse_args(input_args=None): raise ValueError("You must specify prompt for class images.") else: if args.class_data_dir: - raise ValueError(f"If you are using `concepts_list` parameter, define the class data directory within " - f"the file.") + raise ValueError( + f"If you are using `concepts_list` parameter, define the class data directory within " f"the file." + ) if args.class_prompt: - raise ValueError(f"If you are using `concepts_list` parameter, define the class prompt within " - f"the file.") + raise ValueError( + f"If you are using `concepts_list` parameter, define the class prompt within " f"the file." + ) else: # logger is not available yet if not args.class_data_dir: warnings.warn( - "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`.") + "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`." + ) if not args.class_prompt: warnings.warn( - "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`.") + "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`." + ) return args @@ -526,14 +548,14 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - tokenizer, - class_data_root=None, - class_prompt=None, - size=512, - center_crop=False, + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -699,30 +721,34 @@ def main(args): args.validation_guidance_scale = [] for concept in concepts_list: - instance_data_dir.append(concept['instance_data_dir']) - instance_prompt.append(concept['instance_prompt']) + instance_data_dir.append(concept["instance_data_dir"]) + instance_prompt.append(concept["instance_prompt"]) if args.with_prior_preservation: try: - class_data_dir.append(concept['class_data_dir']) - class_prompt.append(concept['class_prompt']) + class_data_dir.append(concept["class_data_dir"]) + class_prompt.append(concept["class_prompt"]) except KeyError: - raise KeyError("`class_data_dir` or `class_prompt` not found in concepts_list while using " - "`with_prior_preservation`.") + raise KeyError( + "`class_data_dir` or `class_prompt` not found in concepts_list while using " + "`with_prior_preservation`." + ) else: - if 'class_data_dir' in concept: - warnings.warn("Ignoring `class_data_dir` key, to use it you need to enable " - "`with_prior_preservation`.") - if 'class_prompt' in concept: - warnings.warn("Ignoring `class_prompt` key, to use it you need to enable " - "`with_prior_preservation`.") + if "class_data_dir" in concept: + warnings.warn( + "Ignoring `class_data_dir` key, to use it you need to enable " "`with_prior_preservation`." + ) + if "class_prompt" in concept: + warnings.warn( + "Ignoring `class_prompt` key, to use it you need to enable " "`with_prior_preservation`." + ) if args.validation_steps: - args.validation_prompt.append(concept.get('validation_prompt', None)) - args.validation_number_images.append(concept.get('validation_number_images', 4)) - args.validation_negative_prompt.append(concept.get('validation_negative_prompt', None)) - args.validation_inference_steps.append(concept.get('validation_inference_steps', 25)) - args.validation_guidance_scale.append(concept.get('validation_guidance_scale', 7.5)) + args.validation_prompt.append(concept.get("validation_prompt", None)) + args.validation_number_images.append(concept.get("validation_number_images", 4)) + args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None)) + args.validation_inference_steps.append(concept.get("validation_inference_steps", 25)) + args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5)) else: # Parse instance and class inputs, and double check that lengths match instance_data_dir = args.instance_data_dir.split(",") @@ -752,8 +778,9 @@ def main(args): negative_validation_prompts.append(None) args.validation_negative_prompt = negative_validation_prompts - assert num_of_validation_prompts == len(negative_validation_prompts), \ - "The length of negative prompts for validation is greater than the number of validation prompts." + assert num_of_validation_prompts == len( + negative_validation_prompts + ), "The length of negative prompts for validation is greater than the number of validation prompts." args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts @@ -810,15 +837,15 @@ def main(args): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): images = pipeline(example["prompt"]).images for ii, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() image_filename = ( - class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg" + class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg" ) image.save(image_filename) @@ -886,7 +913,7 @@ def main(args): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs @@ -1047,7 +1074,9 @@ def main(args): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - time_steps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + time_steps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) time_steps = time_steps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -1107,22 +1136,20 @@ def main(args): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.validation_steps and any(args.validation_prompt) and \ - global_step % args.validation_steps == 0: + if ( + args.validation_steps + and any(args.validation_prompt) + and global_step % args.validation_steps == 0 + ): images_set = generate_validation_images( - text_encoder, - tokenizer, - unet, - vae, - args, - accelerator, - weight_dtype + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype ) for images, validation_prompt in zip(images_set, args.validation_prompt): if len(images) > 0: label = str(uuid.uuid1())[:8] # generate an id for different set of images - log_validation_images_to_tracker(images, label, validation_prompt, accelerator, - global_step) + log_validation_images_to_tracker( + images, label, validation_prompt, accelerator, global_step + ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 42e3d11e1760d496dd2e97e10d938468ceedc8c5 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Wed, 28 Jun 2023 14:23:43 -0300 Subject: [PATCH 26/28] fixed a few strings --- .../train_multi_subject_dreambooth.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index e67069d039dd..1170197b04ac 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -496,7 +496,7 @@ def parse_args(input_args=None): raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.") if args.instance_data_dir: raise ValueError( - "If you are using `concepts_list` parameter, define the instance data directory within " "the file." + "If you are using `concepts_list` parameter, define the instance data directory within the file." ) if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt): raise ValueError( @@ -521,11 +521,11 @@ def parse_args(input_args=None): else: if args.class_data_dir: raise ValueError( - f"If you are using `concepts_list` parameter, define the class data directory within " f"the file." + "If you are using `concepts_list` parameter, define the class data directory within the file." ) if args.class_prompt: raise ValueError( - f"If you are using `concepts_list` parameter, define the class prompt within " f"the file." + "If you are using `concepts_list` parameter, define the class prompt within the file." ) else: # logger is not available yet @@ -544,7 +544,7 @@ def parse_args(input_args=None): class DreamBoothDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. + It pre-processes the images and then tokenizes prompts. """ def __init__( @@ -735,11 +735,11 @@ def main(args): else: if "class_data_dir" in concept: warnings.warn( - "Ignoring `class_data_dir` key, to use it you need to enable " "`with_prior_preservation`." + "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`." ) if "class_prompt" in concept: warnings.warn( - "Ignoring `class_prompt` key, to use it you need to enable " "`with_prior_preservation`." + "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`." ) if args.validation_steps: From 5c775cab1091ac598eaf9e0004136a9b0e2978cc Mon Sep 17 00:00:00 2001 From: mrepetto Date: Wed, 28 Jun 2023 14:37:25 -0300 Subject: [PATCH 27/28] fixed import sort with isort and removed fstrings without placeholder --- .../train_multi_subject_dreambooth.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 1170197b04ac..e6903b257e7f 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -11,25 +11,20 @@ from pathlib import Path from typing import List -import datasets import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint -import transformers -from PIL import Image -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder from torch import dtype from torch.nn import Module from torch.utils.data import Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig +import datasets import diffusers +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -40,6 +35,12 @@ from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import create_repo, upload_folder +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + if is_wandb_available(): import wandb @@ -81,7 +82,7 @@ def generate_validation_images( accelerator: Accelerator, weight_dtype: dtype, ): - logger.info(f"Running validation images.") + logger.info("Running validation images.") pipeline_args = {} From 2753220ebaa296d15791ef66184e2bf67e5c4c47 Mon Sep 17 00:00:00 2001 From: mrepetto Date: Wed, 28 Jun 2023 14:51:07 -0300 Subject: [PATCH 28/28] fixed import order with ruff (since with isort wasn't ok) --- .../train_multi_subject_dreambooth.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index e6903b257e7f..c75a0a9acc64 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -11,20 +11,25 @@ from pathlib import Path from typing import List +import datasets import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from PIL import Image from torch import dtype from torch.nn import Module from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig -import datasets import diffusers -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -35,11 +40,6 @@ from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from huggingface_hub import create_repo, upload_folder -from PIL import Image -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig if is_wandb_available():