From 6dc9f544a3b68d21af7e91b6e3158ee310b493cf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 5 Mar 2023 19:56:40 +0000 Subject: [PATCH 1/4] add intermediate logging --- examples/dreambooth/train_dreambooth.py | 79 ++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a69ab2d4c045..7255a4eba090 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -40,18 +40,65 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.15.0.dev0") logger = get_logger(__name__) +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + vae=vae, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -306,6 +353,28 @@ def parse_args(input_args=None): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) parser.add_argument( "--mixed_precision", type=str, @@ -508,6 +577,10 @@ def main(args): project_config=accelerator_project_config, ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. @@ -920,6 +993,8 @@ def load_model_hook(models, input_dir): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 7ae4077b7af3659a221a249f551b47f69df19f64 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 5 Mar 2023 20:03:19 +0000 Subject: [PATCH 2/4] make style --- examples/dreambooth/train_dreambooth.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 7255a4eba090..1fd5db392398 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -40,7 +40,13 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -374,7 +380,7 @@ def parse_args(input_args=None): " `args.validation_prompt` multiple times: `args.num_validation_images`" " and logging the images." ), - ) + ) parser.add_argument( "--mixed_precision", type=str, From 49bcaf1ae034918277a34bb667ac0807099c77ff Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 5 Mar 2023 20:08:35 +0000 Subject: [PATCH 3/4] fix --- examples/dreambooth/train_dreambooth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 1fd5db392398..1e845b495720 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Optional +import numpy as np import accelerate import torch import torch.nn.functional as F From e0c2b58e1b083436896459669c2e4df8f52d865a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 5 Mar 2023 20:09:41 +0000 Subject: [PATCH 4/4] fix --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 1e845b495720..4d921900c059 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -23,8 +23,8 @@ from pathlib import Path from typing import Optional -import numpy as np import accelerate +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint