From eb7b9c462d4eb675225ad3ea32c5227dfd8b34bf Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 20 Jun 2024 17:27:03 +0530
Subject: [PATCH 1/5] add a test suite for SD3 DreamBooth

---
 examples/dreambooth/test_dreambooth_sd3.py | 203 +++++++++++++++++++++
 1 file changed, 203 insertions(+)
 create mode 100644 examples/dreambooth/test_dreambooth_sd3.py

diff --git a/examples/dreambooth/test_dreambooth_sd3.py b/examples/dreambooth/test_dreambooth_sd3.py
new file mode 100644
index 000000000000..01ad16bbd6e2
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_sd3.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, SD3Transformer2DModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBooth(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    instance_prompt = "photo"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_sd3.py"
+
+    def test_dreambooth(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+    def test_dreambooth_checkpointing(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 4, checkpointing_steps == 2
+            # Should create checkpoints at steps 2, 4
+
+            initial_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 4
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            # check can run the original fully trained output pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # check checkpoint directories exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+            # check can run an intermediate checkpoint
+            transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+            # Run training script for 6 total steps resuming from checkpoint 4
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 6
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --seed=0
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            # check can run new fully trained pipeline
+            pipe = DiffusionPipeline.from_pretrained(tmpdir)
+            pipe(self.instance_prompt, num_inference_steps=1)
+
+            # check old checkpoints do not exist
+            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+            # check new checkpoints exist
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-2", "checkpoint-4"},
+            )
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

From dc85afdf4dab882e786211f5d03df83584ab80da Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 20 Jun 2024 17:42:46 +0530
Subject: [PATCH 2/5] lora suite

---
 .../dreambooth/test_dreambooth_lora_sd3.py | 71 +++++++++++++++++++
 examples/dreambooth/test_dreambooth_sd3.py |  2 +-
 2 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 examples/dreambooth/test_dreambooth_lora_sd3.py

diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
new file mode 100644
index 000000000000..924e5d39861f
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASD3(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    instance_prompt = "photo"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
+
+    def test_dreambooth_lora_sd3(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
\ No newline at end of file
diff --git a/examples/dreambooth/test_dreambooth_sd3.py b/examples/dreambooth/test_dreambooth_sd3.py
index 01ad16bbd6e2..19fb7243cefd 100644
--- a/examples/dreambooth/test_dreambooth_sd3.py
+++ b/examples/dreambooth/test_dreambooth_sd3.py
@@ -33,7 +33,7 @@
 logger.addHandler(stream_handler)
 
 
-class DreamBooth(ExamplesTestsAccelerate):
+class DreamBoothSD3(ExamplesTestsAccelerate):
     instance_data_dir = "docs/source/en/imgs"
     instance_prompt = "photo"
     pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"

From 1a34e7caec66b6220a3eae0c38be6dd25e2149d1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 20 Jun 2024 17:44:11 +0530
Subject: [PATCH 3/5] style

---
 examples/dreambooth/test_dreambooth_lora_sd3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index 924e5d39861f..8a12a5eed1de 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -68,4 +68,4 @@ def test_dreambooth_lora_sd3(self):
             # when not training the text encoder, all the parameters in the state dict should start
             # with `"transformer"` in their names.
             starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
-            self.assertTrue(starts_with_transformer)
\ No newline at end of file
+            self.assertTrue(starts_with_transformer)

From 24e518062df63d0c1477fd120edcd35aa33fedb9 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 21 Jun 2024 10:48:54 +0530
Subject: [PATCH 4/5] add checkpointing tests for LoRA

---
 .../dreambooth/test_dreambooth_lora_sd3.py | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index 8a12a5eed1de..1dfd45e97464 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -69,3 +69,64 @@ def test_dreambooth_lora_sd3(self):
             # with `"transformer"` in their names.
             starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
             self.assertTrue(starts_with_transformer)
 
+
+    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --instance_prompt={self.instance_prompt}
+                --resolution=64
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

From 9a8870eb227904cc238a5e70a3fc0ef4e1707cd2 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 2 Jul 2024 06:49:15 +0530
Subject: [PATCH 5/5] add test to cover train_text_encoder.
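
The new test mirrors the transformer-only case above: with --train_text_encoder,
every key in the saved LoRA state dict should be namespaced under either the
transformer or one of the text encoders. As a minimal standalone sketch of the
same verification outside the test harness (the weight path below is
illustrative, assuming a file produced by a local training run):

    import safetensors.torch

    # Load the serialized LoRA weights written by the training script.
    state_dict = safetensors.torch.load_file("pytorch_lora_weights.safetensors")

    # startswith() accepts a tuple, so "text_encoder" also matches keys saved
    # for the second text encoder (e.g. "text_encoder_2...").
    assert all(k.startswith(("transformer", "text_encoder")) for k in state_dict)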
---
 .../dreambooth/test_dreambooth_lora_sd3.py | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index 1dfd45e97464..518738b78246 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -70,6 +70,39 @@ def test_dreambooth_lora_sd3(self):
         starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
         self.assertTrue(starts_with_transformer)
 
+    def test_dreambooth_lora_text_encoder_sd3(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --train_text_encoder
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            starts_with_expected_prefix = all(
+                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+            )
+            self.assertTrue(starts_with_expected_prefix)
+
     def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""