From 45c5897fdf4f3eed47700e220d8183fe79e35cc4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Sep 2023 10:02:31 +0100 Subject: [PATCH 01/22] fix: how print training resume logs. --- examples/text_to_image/train_text_to_image.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b00884bfb7ea..8e515afee3c1 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -577,9 +577,10 @@ def deepspeed_zero_init_disabled_context_manager(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) - # Freeze vae and text_encoder + # Freeze vae and text_encoder and set unet to trainable vae.requires_grad_(False) text_encoder.requires_grad_(False) + unet.train() # Create EMA for the unet. if args.use_ema: @@ -878,29 +879,29 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): - unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() From 84847064242758a361137acdf567e14c0a19eb03 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Sep 2023 22:55:37 +0100 Subject: [PATCH 02/22] propagate changes to text-to-image scripts. 
--- .../text_to_image/train_text_to_image_lora.py | 22 ++++++++-------- .../text_to_image/train_text_to_image_sdxl.py | 25 ++++++++++--------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b9830a83ae8a..5f69fe69e2af 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -428,7 +428,6 @@ def main(): # freeze parameters of models to save more memory unet.requires_grad_(False) vae.requires_grad_(False) - text_encoder.requires_grad_(False) # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision @@ -718,24 +717,23 @@ def collate_fn(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 299d0f0d7523..457a5c98beb0 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -588,6 +588,8 @@ def main(args): vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) + # Set unet as trainable. + unet.train() # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -927,24 +929,23 @@ def collate_fn(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): - unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Sample noise that we'll add to the latents model_input = batch["model_input"].to(accelerator.device) From 1c1f995b019867221f4c30a8a66fc533f6bea3bb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Sep 2023 23:05:02 +0100 Subject: [PATCH 03/22] propagate changes to instructpix2pix. --- .../train_instruct_pix2pix.py | 27 ++++++++++--------- .../train_instruct_pix2pix_sdxl.py | 14 ++++------ .../text_to_image/train_text_to_image_lora.py | 1 + .../train_text_to_image_lora_sdxl.py | 23 ++++++++-------- .../text_to_image/train_text_to_image_sdxl.py | 1 + 5 files changed, 34 insertions(+), 32 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 5f8a2d9ee150..b59cc3234ed5 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -464,6 +464,9 @@ def main(): vae.requires_grad_(False) text_encoder.requires_grad_(False) + # Set UNet to trainable. + unet.train() + # Create EMA for the unet. if args.use_ema: ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) @@ -756,29 +759,29 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): - unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 4d0b9bef55f1..1a65b4162fec 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -726,6 +726,9 @@ def preprocess_images(examples): text_encoder_1.requires_grad_(False) text_encoder_2.requires_grad_(False) + # Set UNet to trainable. + unet.train() + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt def encode_prompt(text_encoders, tokenizers, prompt): prompt_embeds_list = [] @@ -938,24 +941,17 @@ def collate_fn(examples): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + pass # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): - unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 5f69fe69e2af..981d19831205 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -712,6 +712,7 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 45ae1cc9ef7a..6b2132a8ae0b 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -970,18 +970,25 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() @@ -990,12 +997,6 @@ def collate_fn(examples): text_encoder_two.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Convert images to latent space if args.pretrained_vae_model_name_or_path is not None: diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 457a5c98beb0..6b4992ee7c93 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -924,6 +924,7 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) From b9af9560b3ef811d27ad78bce915606ac3bfd6a2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 23 Sep 2023 16:29:05 +0100 Subject: [PATCH 04/22] propagate changes to dreambooth --- examples/dreambooth/train_dreambooth.py | 24 +++++++++---------- examples/dreambooth/train_dreambooth_lora.py | 24 +++++++++---------- .../dreambooth/train_dreambooth_lora_sdxl.py | 23 +++++++++--------- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 6f815c0f85f4..c8bc365751af 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1201,30 +1201,30 @@ def compute_text_embeddings(prompt): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: text_encoder.train() for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index dc90d10f2b26..5bb6c78b7b74 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1108,30 +1108,30 @@ def compute_text_embeddings(prompt): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: text_encoder.train() for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 24dbf4313662..ac59bba6c847 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1048,18 +1048,25 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() @@ -1067,12 +1074,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_one.train() text_encoder_two.train() for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) From 648ba502cbd1a17db3e6e3dd210e90a148d28b8b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 23 Sep 2023 16:35:58 +0100 Subject: [PATCH 05/22] propagate changes to custom diffusion and instructpix2pix --- .../train_custom_diffusion.py | 24 +++++++++---------- .../train_instruct_pix2pix_sdxl.py | 16 +++++++++---- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 60d8d6723dcf..3288fe3258ac 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -1075,30 +1075,30 @@ def main(args): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.modifier_token is not None: text_encoder.train() for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 1a65b4162fec..e2d9b2105160 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -936,18 +936,24 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: - pass - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): train_loss = 0.0 From 945177df54f8664f694b8d4544170a69cc4abb1e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 23 Sep 2023 16:49:16 +0100 Subject: [PATCH 06/22] propagate changes to kandinsky --- .../train_text_to_image_decoder.py | 26 +++++++++++-------- .../train_text_to_image_lora_decoder.py | 22 ++++++++-------- .../train_text_to_image_lora_prior.py | 25 ++++++++++-------- .../train_text_to_image_prior.py | 26 ++++++++++--------- 4 files changed, 54 insertions(+), 45 deletions(-) diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 364ed7e03189..2387ba4b27cc 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -512,6 +512,9 @@ def deepspeed_zero_init_disabled_context_manager(): vae.requires_grad_(False) image_encoder.requires_grad_(False) + # Set unet to trainable. + unet.train() + # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet") @@ -751,27 +754,28 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): - unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Convert images to latent space images = batch["pixel_values"].to(weight_dtype) diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 9d96a936d0ca..e9694fc74586 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -602,29 +602,29 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) for epoch in range(first_epoch, args.num_train_epochs): unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(unet): # Convert images to latent space images = batch["pixel_values"].to(weight_dtype) diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index e4aec111b8f7..5e1512a25ac1 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -619,30 +619,33 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + clip_mean = clip_mean.to(weight_dtype).to(accelerator.device) clip_std = clip_std.to(weight_dtype).to(accelerator.device) + for epoch in range(first_epoch, args.num_train_epochs): prior.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(prior): # Convert images to latent space text_input_ids, text_mask, clip_images = ( diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index d451e1bfe40d..4b58400e6365 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -517,6 +517,9 @@ def deepspeed_zero_init_disabled_context_manager(): text_encoder.requires_grad_(False) image_encoder.requires_grad_(False) + # Set prior to trainable. + prior.train() + # Create EMA for the prior. if args.use_ema: ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") @@ -765,32 +768,31 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) clip_mean = clip_mean.to(weight_dtype).to(accelerator.device) clip_std = clip_std.to(weight_dtype).to(accelerator.device) for epoch in range(first_epoch, args.num_train_epochs): - prior.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(prior): # Convert images to latent space text_input_ids, text_mask, clip_images = ( From c8f42492461dd4c1bb8dd5b37f1f4d47aec10f33 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 23 Sep 2023 16:55:06 +0100 Subject: [PATCH 07/22] propagate changes to textual inv. --- .../textual_inversion/textual_inversion.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 2e6f9a7d9522..01830751ffe2 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -809,18 +809,25 @@ def main(): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None + initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) # keep original embeddings as reference orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() @@ -828,12 +835,6 @@ def main(): for epoch in range(first_epoch, args.num_train_epochs): text_encoder.train() for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - with accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() From 601817c6519db83ab9284c961f7d3a295baa76c5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:39:19 +0530 Subject: [PATCH 08/22] debug --- examples/test_examples.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/test_examples.py b/examples/test_examples.py index 89e866231e89..de0b0d4cf4fe 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1303,6 +1303,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple run_command(self._launch_args + resume_run_args) # check checkpoint directories exist + print({x for x in os.listdir(tmpdir) if "checkpoint" in x}) self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, From bf666fd29caa50f41811a3f97c7fac73eaef788a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:41:48 +0530 Subject: [PATCH 09/22] fix: checkpointing. --- examples/test_examples.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index de0b0d4cf4fe..da7c54bf71f6 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1303,10 +1303,9 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple run_command(self._launch_args + resume_run_args) # check checkpoint directories exist - print({x for x in os.listdir(tmpdir) if "checkpoint" in x}) self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + {"checkpoint-12", "checkpoint-8", "checkpoint-10"}, ) def test_dreambooth_checkpointing_checkpoints_total_limit(self): From 5e3e2ec09cbd9f839b76cd588a564f1109e807b2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:46:05 +0530 Subject: [PATCH 10/22] debug --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++++ examples/test_examples.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index b59cc3234ed5..9d1187f25149 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -753,6 +753,7 @@ def collate_fn(examples): dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None + print(f"Path found: {path}") if path is None: accelerator.print( @@ -762,8 +763,10 @@ def collate_fn(examples): initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") + print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) + 
print(f"Global step: {global_step}") initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch @@ -771,6 +774,7 @@ def collate_fn(examples): else: initial_global_step = 0 + print(f"Initial global step: {initial_global_step}") progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, diff --git a/examples/test_examples.py b/examples/test_examples.py index da7c54bf71f6..89e866231e89 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1305,7 +1305,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple # check checkpoint directories exist self.assertEqual( {x for x in os.listdir(tmpdir) if "checkpoint" in x}, - {"checkpoint-12", "checkpoint-8", "checkpoint-10"}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, ) def test_dreambooth_checkpointing_checkpoints_total_limit(self): From eb5c2d44828b7184253550a22d41ec750f5d3bd1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:50:59 +0530 Subject: [PATCH 11/22] debug --- examples/test_examples.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 89e866231e89..75f87961a3fc 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -27,11 +27,14 @@ from accelerate.utils import write_basic_config from diffusers import DiffusionPipeline, UNet2DConditionModel +import torch +import diffusers +log = logging.getLogger("test") +log.setLevel(logging.DEBUG) +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s') -logging.basicConfig(level=logging.DEBUG) - -logger = logging.getLogger() +log.info(f'loaded: torch={torch.__version__} diffusers={diffusers.__version__}') # These utils relate to ensuring the right error message is received when running scripts From c750e0ab7ce583cd2407f0fd91bfe9aa82eff90a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:52:27 +0530 Subject: [PATCH 12/22] debug --- examples/test_examples.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 75f87961a3fc..2f00657a4987 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -59,8 +59,8 @@ def run_command(command: List[str], return_stdout=False): ) from e -stream_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(stream_handler) +# stream_handler = logging.StreamHandler(sys.stdout) +# logger.addHandler(stream_handler) class ExamplesTestsAccelerate(unittest.TestCase): From ed386805093a8e6a2d190ebc5b524885cc3606f4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 08:54:52 +0530 Subject: [PATCH 13/22] back to the square --- examples/test_examples.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 2f00657a4987..89e866231e89 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -27,14 +27,11 @@ from accelerate.utils import write_basic_config from diffusers import DiffusionPipeline, UNet2DConditionModel -import torch -import diffusers -log = logging.getLogger("test") -log.setLevel(logging.DEBUG) -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s') -log.info(f'loaded: torch={torch.__version__} diffusers={diffusers.__version__}') +logging.basicConfig(level=logging.DEBUG) + +logger = 
logging.getLogger() # These utils relate to ensuring the right error message is received when running scripts @@ -59,8 +56,8 @@ def run_command(command: List[str], return_stdout=False): ) from e -# stream_handler = logging.StreamHandler(sys.stdout) -# logger.addHandler(stream_handler) +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) class ExamplesTestsAccelerate(unittest.TestCase): From 653b3c5952e39daa15f20e0565198ccfb08bc737 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 09:08:29 +0530 Subject: [PATCH 14/22] debug --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 9d1187f25149..7c5c1afa31cf 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -747,13 +747,13 @@ def collate_fn(examples): if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) + print(f"Path found: {path}") else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None - print(f"Path found: {path}") if path is None: accelerator.print( From 03f5b9140438375c1737ea69865a35c386a709bc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 09:26:13 +0530 Subject: [PATCH 15/22] debug --- examples/dreambooth/train_dreambooth.py | 2 ++ examples/instruct_pix2pix/train_instruct_pix2pix.py | 9 ++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c8bc365751af..33dc05e07d6a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1345,11 +1345,13 @@ def compute_text_embeddings(prompt): checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + print(f"All checkpoints: {checkpoints}") # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] + print(f"To be removed checkpoints: {removing_checkpoints}") logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 7c5c1afa31cf..76a4d716f509 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -747,7 +747,6 @@ def collate_fn(examples): if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) - print(f"Path found: {path}") else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) @@ -763,10 +762,8 @@ def collate_fn(examples): initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") - print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = 
int(path.split("-")[1]) - print(f"Global step: {global_step}") initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch @@ -868,18 +865,20 @@ def collate_fn(examples): accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + print(f"All checkpoints: {checkpoints}") # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] + print(f"To be removed checkpoints: {removing_checkpoints}") logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" From 095c662c5944ac5e9958001c67b6c4caf4246d7f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 09:46:41 +0530 Subject: [PATCH 16/22] change condition order. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 76a4d716f509..7b91a6c4e98b 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -865,8 +865,8 @@ def collate_fn(examples): accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) From b3f290765d298f1809926f35d3521fcb4977df31 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 09:54:50 +0530 Subject: [PATCH 17/22] debug --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 7b91a6c4e98b..02a0af16b7b1 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -772,6 +772,8 @@ def collate_fn(examples): initial_global_step = 0 print(f"Initial global step: {initial_global_step}") + checkpoints = os.listdir(args.output_dir) + print([d for d in checkpoints if d.startswith("checkpoint")]) progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, From cfc42e3830d789620ef6621672c2eb9ec9109a33 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 10:00:57 +0530 Subject: [PATCH 18/22] debug --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 33dc05e07d6a..960a68505699 100644 --- 
a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1219,7 +1219,7 @@ def compute_text_embeddings(prompt): # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) - + print([d for d in checkpoints if d.startswith("checkpoint")]) for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: From d9da100a5432a20c9994ecbe386edb0967f26d75 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 10:05:58 +0530 Subject: [PATCH 19/22] debug --- examples/dreambooth/train_dreambooth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 960a68505699..c739800ccfe8 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1219,6 +1219,7 @@ def compute_text_embeddings(prompt): # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) + checkpoints = os.listdir(args.output_dir) print([d for d in checkpoints if d.startswith("checkpoint")]) for epoch in range(first_epoch, args.num_train_epochs): unet.train() From 88a055b451e3f0240a947e149c9393f35e2acf18 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 10:14:39 +0530 Subject: [PATCH 20/22] debug --- examples/dreambooth/train_dreambooth.py | 1 + examples/instruct_pix2pix/train_instruct_pix2pix.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c739800ccfe8..75bba00ada7f 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1391,6 +1391,7 @@ def compute_text_embeddings(prompt): break # Create the pipeline using using the trained modules and save it. + print(f"Final global step: {global_step}.") accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline_args = {} diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 02a0af16b7b1..c0b65c79d07a 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -900,7 +900,8 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - + + print(f"Final global step: {global_step}.") if accelerator.is_main_process: if ( (args.val_image_url is not None) From d6bc2bededb2433f2f45b419a4c13e7bb3b7bb3d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 10:47:00 +0530 Subject: [PATCH 21/22] revert to original --- .../train_instruct_pix2pix.py | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index c0b65c79d07a..2f54a8319e28 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -464,9 +464,6 @@ def main(): vae.requires_grad_(False) text_encoder.requires_grad_(False) - # Set UNet to trainable. - unet.train() - # Create EMA for the unet. if args.use_ema: ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) @@ -759,32 +756,29 @@ def collate_fn(examples): f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None - initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - initial_global_step = global_step + resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - else: - initial_global_step = 0 - - print(f"Initial global step: {initial_global_step}") - checkpoints = os.listdir(args.output_dir) - print([d for d in checkpoints if d.startswith("checkpoint")]) - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. - disable=not accelerator.is_local_main_process, - ) + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): + unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + with accelerator.accumulate(unet): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. @@ -874,13 +868,11 @@ def collate_fn(examples): checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - print(f"All checkpoints: {checkpoints}") # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] - print(f"To be removed checkpoints: {removing_checkpoints}") logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" @@ -900,8 +892,7 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - - print(f"Final global step: {global_step}.") + if accelerator.is_main_process: if ( (args.val_image_url is not None) @@ -1015,4 +1006,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() + main() \ No newline at end of file From 2847be07df6d223487996bf013a8f5a909a50ca8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 27 Sep 2023 10:48:29 +0530 Subject: [PATCH 22/22] clean --- examples/dreambooth/train_dreambooth.py | 6 +----- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 75bba00ada7f..c8bc365751af 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1219,8 +1219,7 @@ def compute_text_embeddings(prompt): # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) - checkpoints = os.listdir(args.output_dir) - print([d for d in checkpoints if d.startswith("checkpoint")]) + for epoch in range(first_epoch, args.num_train_epochs): unet.train() if args.train_text_encoder: @@ -1346,13 +1345,11 @@ def compute_text_embeddings(prompt): checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - print(f"All checkpoints: {checkpoints}") # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] - print(f"To be removed checkpoints: {removing_checkpoints}") logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" @@ -1391,7 +1388,6 @@ def compute_text_embeddings(prompt): break # Create the pipeline using using the trained modules and save it. - print(f"Final global step: {global_step}.") accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline_args = {} diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 2f54a8319e28..5f8a2d9ee150 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -1006,4 +1006,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() \ No newline at end of file + main()
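
For reference, the refactor that every script in this patch series converges on can be summarized in one small helper: derive `initial_global_step` from the checkpoint directory name, pass it to tqdm via `initial=`, and drop the old per-batch skip loop built on `resume_step`. The sketch below is a condensed illustration of that pattern, not code taken from the diffusers repository; the helper name `build_resume_state` and its signature are invented for this example.

# Condensed sketch of the resume pattern adopted across the training scripts.
# Assumptions: `args` has resume_from_checkpoint, output_dir, and max_train_steps;
# `accelerator` behaves like accelerate.Accelerator (print, load_state,
# is_local_main_process). The helper name is illustrative only.
import os

from tqdm.auto import tqdm


def build_resume_state(args, accelerator, num_update_steps_per_epoch):
    """Return (global_step, first_epoch, progress_bar) for a possibly resumed run."""
    global_step = 0
    first_epoch = 0
    initial_global_step = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Pick the most recent "checkpoint-<step>" directory in the output dir.
            dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    return global_step, first_epoch, progress_bar

Because tqdm is told where the run starts via `initial=`, the training loop no longer needs to iterate over and discard already-seen batches, which is what the removed `resume_step` skip block used to do.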