diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index fc358783b5f9..21ab38a3d7a9 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -18,6 +18,7 @@ import math import os import random +import shutil from pathlib import Path import accelerate @@ -307,11 +308,7 @@ def parse_args(input_args=None): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more details" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -716,9 +713,7 @@ def collate_fn(examples): def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1060,6 +1055,26 @@ def load_model_hook(models, input_dir): if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 421532602137..e0ec56eca1f3 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -21,6 +21,7 @@ import math import os import random +import shutil import warnings from pathlib import Path @@ -446,11 +447,7 @@ def parse_args(input_args=None): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -637,9 +634,7 @@ def parse_args(input_args=None): def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1171,6 +1166,26 @@ def main(args): if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 695b0a0423a6..797cfbd0e5d7 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -20,6 +20,7 @@ import logging import math import os +import shutil import warnings from pathlib import Path @@ -771,9 +772,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1270,12 +1269,33 @@ def compute_text_embeddings(prompt): global_step += 1 if accelerator.is_main_process: - images = [] if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + images = [] + if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( text_encoder, diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b4f099fc2f58..72fcfa648b48 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -20,6 +20,7 @@ import logging import math import os +import shutil import warnings from pathlib import Path @@ -276,11 +277,7 @@ def parse_args(input_args=None): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -653,9 +650,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -1221,6 +1216,26 @@ def compute_text_embeddings(prompt): if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 08dd5cd42701..e84698a8f215 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -20,6 +20,7 @@ import logging import math import os +import shutil from pathlib import Path import accelerate @@ -327,11 +328,7 @@ def parse_args(): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -387,9 +384,7 @@ def main(): ), ) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -867,6 +862,26 @@ def collate_fn(examples): if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/test_examples.py b/examples/test_examples.py index 59c96f44fe93..d11841350064 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -435,8 +435,10 @@ def test_text_to_image_checkpointing(self): pipe(prompt, num_inference_steps=2) # check checkpoint directories exist - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") @@ -474,12 +476,15 @@ def test_text_to_image_checkpointing(self): pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) pipe(prompt, num_inference_steps=2) - # check old checkpoints do not exist - self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) - - # check new checkpoints exist - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + { + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + "checkpoint-4", + "checkpoint-6", + }, + ) def test_text_to_image_checkpointing_use_ema(self): pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" @@ -516,8 +521,10 @@ def test_text_to_image_checkpointing_use_ema(self): pipe(prompt, num_inference_steps=2) # check checkpoint directories exist - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) # check can run an intermediate checkpoint unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") @@ -556,9 +563,773 @@ def test_text_to_image_checkpointing_use_ema(self): pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) pipe(prompt, num_inference_steps=2) - # check old checkpoints do not exist - self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + { + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + "checkpoint-4", + "checkpoint-6", + }, + ) + + def test_text_to_image_checkpointing_checkpoints_total_limit(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" - # check new checkpoints exist - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) - self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 9, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4, 6, 8 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 9 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + # resume and we should try to checkpoint at 10, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 11 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" + prompt = "a prompt" + + with tempfile.TemporaryDirectory() as tmpdir: + # Run training script with checkpointing + # max_train_steps == 9, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4, 6, 8 + + initial_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 9 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + # resume and we should try to checkpoint at 10, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/text_to_image/train_text_to_image_lora.py + --pretrained_model_name_or_path {pretrained_model_name_or_path} + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 11 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + --seed=0 + --num_validation_images=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None + ) + pipe.load_lora_weights(tmpdir) + pipe(prompt, num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + def test_unconditional_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + initial_run_args = f""" + examples/unconditional_image_generation/train_unconditional.py + --dataset_name hf-internal-testing/dummy_image_class_data + --model_config_name_or_path diffusers/ddpm_dummy + --resolution 64 + --output_dir {tmpdir} + --train_batch_size 1 + --num_epochs 1 + --gradient_accumulation_steps 1 + --ddpm_num_inference_steps 2 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --checkpointing_steps=2 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + initial_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + initial_run_args = f""" + examples/unconditional_image_generation/train_unconditional.py + --dataset_name hf-internal-testing/dummy_image_class_data + --model_config_name_or_path diffusers/ddpm_dummy + --resolution 64 + --output_dir {tmpdir} + --train_batch_size 1 + --num_epochs 1 + --gradient_accumulation_steps 1 + --ddpm_num_inference_steps 2 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --checkpointing_steps=1 + """.split() + + run_command(self._launch_args + initial_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"}, + ) + + resume_run_args = f""" + examples/unconditional_image_generation/train_unconditional.py + --dataset_name hf-internal-testing/dummy_image_class_data + --model_config_name_or_path diffusers/ddpm_dummy + --resolution 64 + --output_dir {tmpdir} + --train_batch_size 1 + --num_epochs 2 + --gradient_accumulation_steps 1 + --ddpm_num_inference_steps 2 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --resume_from_checkpoint=checkpoint-6 + --checkpointing_steps=2 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, + ) + + def test_textual_inversion_checkpointing(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --validation_prompt + --validation_steps 1 + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 3 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + test_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-3"}, + ) + + def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --validation_prompt + --validation_steps 1 + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 3 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + """.split() + + run_command(self._launch_args + test_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-1", "checkpoint-2", "checkpoint-3"}, + ) + + resume_run_args = f""" + examples/textual_inversion/textual_inversion.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --train_data_dir docs/source/en/imgs + --learnable_property object + --placeholder_token + --initializer_token a + --validation_prompt + --validation_steps 1 + --save_steps 1 + --num_vectors 2 + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --checkpointing_steps=1 + --resume_from_checkpoint=checkpoint-3 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-3", "checkpoint-4"}, + ) + + def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/instruct_pix2pix/train_instruct_pix2pix.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/instructpix2pix-10-samples + --resolution=64 + --random_flip + --train_batch_size=1 + --max_train_steps=7 + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/instruct_pix2pix/train_instruct_pix2pix.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/instructpix2pix-10-samples + --resolution=64 + --random_flip + --train_batch_size=1 + --max_train_steps=9 + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + test_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + resume_run_args = f""" + examples/instruct_pix2pix/train_instruct_pix2pix.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/instructpix2pix-10-samples + --resolution=64 + --random_flip + --train_batch_size=1 + --max_train_steps=11 + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + def test_dreambooth_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + resume_run_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=11 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + resume_run_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt=prompt + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=11 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) + + def test_controlnet_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/controlnet/train_controlnet.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/controlnet/train_controlnet.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + resume_run_args = f""" + examples/controlnet/train_controlnet.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet + --max_train_steps=11 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-8", "checkpoint-10", "checkpoint-12"}, + ) + + def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/custom_diffusion/train_custom_diffusion.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt= + --resolution=64 + --train_batch_size=1 + --modifier_token= + --dataloader_num_workers=0 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/custom_diffusion/train_custom_diffusion.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt= + --resolution=64 + --train_batch_size=1 + --modifier_token= + --dataloader_num_workers=0 + --max_train_steps=9 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) + + resume_run_args = f""" + examples/custom_diffusion/train_custom_diffusion.py + --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir=docs/source/en/imgs + --output_dir={tmpdir} + --instance_prompt= + --resolution=64 + --train_batch_size=1 + --modifier_token= + --dataloader_num_workers=0 + --max_train_steps=11 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-8 + --checkpoints_total_limit=3 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8", "checkpoint-10"}, + ) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 3fe72b90b24a..2ec2702e439a 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -18,6 +18,7 @@ import math import os import random +import shutil from pathlib import Path import accelerate @@ -362,11 +363,7 @@ def parse_args(): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -427,9 +424,7 @@ def main(): ) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -909,6 +904,26 @@ def collate_fn(examples): if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 4a39f37a2896..7c2601d8e9b5 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -19,6 +19,7 @@ import math import os import random +import shutil from pathlib import Path import datasets @@ -327,11 +328,7 @@ def parse_args(): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -368,9 +365,7 @@ def main(): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -809,6 +804,26 @@ def collate_fn(examples): if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -903,18 +918,19 @@ def collate_fn(examples): if accelerator.is_main_process: for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + if len(images) != 0: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) accelerator.end_training() diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 8c44247a75b5..14b0997862d2 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -18,6 +18,7 @@ import math import os import random +import shutil import warnings from pathlib import Path @@ -394,11 +395,7 @@ def parse_args(): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -566,9 +563,7 @@ def __getitem__(self, i): def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -887,6 +882,26 @@ def main(): if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index b07143f8b267..d6e4b17ba889 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -3,6 +3,7 @@ import logging import math import os +import shutil from pathlib import Path from typing import Optional @@ -245,11 +246,7 @@ def parse_args(): "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more docs" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -287,9 +284,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir - ) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -607,6 +602,26 @@ def transform_images(examples): global_step += 1 if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path)