From cb2acab2b1feb0bf92e94dd0a63f1d38dd42337e Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 15 Dec 2022 15:44:53 +0900 Subject: [PATCH 1/4] add func logging interim generation --- examples/dreambooth/train_dreambooth.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index adc1359a5e56..34c7d6026523 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -159,6 +159,14 @@ def parse_args(input_args=None): " training using `--resume_from_checkpoint`." ), ) + parser.add_argument( + "--generating_progress_steps", + type=int, + default=500, + help=( + "Save an image generated from the model every X steps." + ), + ) parser.add_argument( "--resume_from_checkpoint", type=str, @@ -739,6 +747,22 @@ def main(args): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if global_step % args.generating_progress_steps == 0: + if accelerator.is_main_process: + g = torch.Generator(device=accelerator.device.type) + g.manual_seed(42) + + # generate and save the image + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + torch_dtype=weight_dtype, + ).to(accelerator.device.type) + + image = pipeline(args.instance_prompt, generator=g).images[0] + image.save(f"{args.output_dir}/logs/image_{global_step}.png") + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) From 84cdc3e3a3a96d51b2835cf34dd965b7adfae9ad Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 15 Dec 2022 16:08:59 +0900 Subject: [PATCH 2/4] formatted with black --- examples/dreambooth/train_dreambooth.py | 38 ++++++------------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 34c7d6026523..a17367cdab2d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -34,9 +34,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - revision=revision, + pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] @@ -102,10 +100,7 @@ def parse_args(input_args=None): help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior preservation loss.", + "--with_prior_preservation", default=False, action="store_true", help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( @@ -163,9 +158,7 @@ def parse_args(input_args=None): "--generating_progress_steps", type=int, default=500, - help=( - "Save an image generated from the model every X steps." - ), + help=("Save an image generated from the model every X steps."), ) parser.add_argument( "--resume_from_checkpoint", @@ -483,17 +476,10 @@ def main(args): # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=args.revision, - use_fast=False, - ) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False,) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class @@ -501,19 +487,11 @@ def main(args): # Load models and create wrapper for stable diffusion text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision,) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, ) if is_xformers_available(): From c09698a903a945bd781a77832f936a3e42442314 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 15 Dec 2022 16:58:23 +0900 Subject: [PATCH 3/4] improve generator seeding --- examples/dreambooth/train_dreambooth.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a17367cdab2d..847e8d7bb121 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -417,6 +417,11 @@ def main(args): if args.seed is not None: set_seed(args.seed) + g = torch.Generator(device=accelerator.device.type) + g.manual_seed(args.seed) + else: + g = torch.Generator(device=accelerator.device.type) + g.manual_seed(42) if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) @@ -727,9 +732,6 @@ def main(args): if global_step % args.generating_progress_steps == 0: if accelerator.is_main_process: - g = torch.Generator(device=accelerator.device.type) - g.manual_seed(42) - # generate and save the image pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, From 4258eb82e250620f7a6fcfe4a111aa2fa0e865b4 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 15 Dec 2022 20:15:38 +0900 Subject: [PATCH 4/4] align code style (maybe) --- examples/dreambooth/train_dreambooth.py | 34 ++++++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 847e8d7bb121..447adeb9beee 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -34,7 +34,9 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, ) model_class = text_encoder_config.architectures[0] @@ -100,7 +102,10 @@ def parse_args(input_args=None): help="The prompt to specify images in the same class as provided instance images.", ) parser.add_argument( - "--with_prior_preservation", default=False, action="store_true", help="Flag to add prior preservation loss.", + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", ) parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") parser.add_argument( @@ -481,10 +486,17 @@ def main(args): # Load the tokenizer if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False,) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) # import correct text encoder class @@ -492,11 +504,19 @@ def main(args): # Load models and create wrapper for stable diffusion text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision,) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, ) if is_xformers_available():