From 90686c2e772ee285aeea737bea416a30a8d84adc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 12 Aug 2024 17:30:28 +0300 Subject: [PATCH 01/20] add ostris trainer to README & add cache latents of vae --- examples/dreambooth/README_flux.md | 2 +- .../dreambooth/train_dreambooth_lora_flux.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index fab382c0894c..fd3961b1fd79 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) & [`ostris` guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3629fcca4dd0..3122ad2e13c7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -599,6 +599,12 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) parser.add_argument( "--report_to", type=str, @@ -1456,6 +1462,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=torch.float32 + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1610,7 +1631,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) From 7b12ed2c82e34722dada6375ea4d39fd54deed9f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 12 Aug 2024 17:54:16 +0300 Subject: [PATCH 02/20] add ostris trainer to README & add cache latents of vae --- examples/dreambooth/train_dreambooth_lora_flux.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3122ad2e13c7..7acceed0bc66 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1462,12 +1462,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels if args.cache_latents: latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=torch.float32 + accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) @@ -1475,6 +1478,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): del vae if torch.cuda.is_available(): torch.cuda.empty_cache() + gc.collect() # Scheduler and math around the number of training steps. @@ -1599,7 +1603,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1634,11 +1637,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.cache_latents: model_input = latents_cache[step].sample() else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], From 17dca18cf1c99359d6e6c075b5ab739dac02fe19 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 12 Aug 2024 15:01:55 +0000 Subject: [PATCH 03/20] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 7acceed0bc66..ebf2d3dfcdff 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1480,7 +1480,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): torch.cuda.empty_cache() gc.collect() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 8b314e9ad672ff823fd47ee0ed7f6bb2022e434a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 13 Aug 2024 17:10:13 +0300 Subject: [PATCH 04/20] readme --- examples/dreambooth/README_flux.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index fd3961b1fd79..e668e156a0ef 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,7 +8,9 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) & [`ostris` guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) +> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: +> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] From df54cd8dc0606ce2013cce8ebdcd9f9180c89459 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 10:35:38 +0300 Subject: [PATCH 05/20] add test for latent caching --- examples/dreambooth/test_dreambooth_flux.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_flux.py b/examples/dreambooth/test_dreambooth_flux.py index 2d5703d2a24a..c7a83157b091 100644 --- a/examples/dreambooth/test_dreambooth_flux.py +++ b/examples/dreambooth/test_dreambooth_flux.py @@ -62,6 +62,30 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --cache_latents + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing From e0e0319960d0608d220c76c8b4f8b1d59cc84b86 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 10:57:20 +0300 Subject: [PATCH 06/20] add ostris noise scheduler https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 --- .../dreambooth/train_dreambooth_lora_flux.py | 102 +++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 33f5a89a0741..317a230458ef 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1002,6 +1002,106 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids +# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: +# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # generate the multiplier based on cosmap loss weighing + # this is only used on linear timesteps for now + + # cosine map weighing is higher in the middle and lower at the ends + # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 + # cosmap_weighing = 2 / (math.pi * bot) + + # sigma sqrt weighing is significantly higher at the end and lower at the beginning + sigma_sqrt_weighing = (self.sigmas ** -2.0).float() + # clip at 1e4 (1e6 is too high) + sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) + # bring to a mean of 1 + sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + + self.linear_timesteps = timesteps + # self.linear_timesteps_weights = cosmap_weighing + self.linear_timesteps_weights = sigma_sqrt_weighing + + # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = ((1 - t) * 1000) + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps def main(args): if args.report_to == "wandb" and args.hub_token is not None: @@ -1133,7 +1233,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) From 18aa3697c8e0bc477696d4d22123a8f44d0e43ee Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 08:00:15 +0000 Subject: [PATCH 07/20] style --- .../dreambooth/train_dreambooth_lora_flux.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 317a230458ef..6e46948d7ec4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1002,6 +1002,7 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids + # CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: # https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): @@ -1020,14 +1021,14 @@ def __init__(self, *args, **kwargs): # cosmap_weighing = 2 / (math.pi * bot) # sigma sqrt weighing is significantly higher at the end and lower at the beginning - sigma_sqrt_weighing = (self.sigmas ** -2.0).float() + sigma_sqrt_weighing = (self.sigmas**-2.0).float() # clip at 1e4 (1e6 is too high) sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) # bring to a mean of 1 sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") self.linear_timesteps = timesteps # self.linear_timesteps_weights = cosmap_weighing @@ -1058,10 +1059,10 @@ def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Ten return sigma def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, ) -> torch.Tensor: ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 ## Add noise according to flow matching. @@ -1094,7 +1095,7 @@ def set_train_timesteps(self, num_timesteps, device, linear=False): t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) # Scale and reverse the values to go from 1000 to 0 - timesteps = ((1 - t) * 1000) + timesteps = (1 - t) * 1000 # Sort the timesteps in descending order timesteps, _ = torch.sort(timesteps, descending=True) @@ -1103,6 +1104,7 @@ def set_train_timesteps(self, num_timesteps, device, linear=False): return timesteps + def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( From f97d53d009089c661a2176a18d08fa8bbf008b52 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 11:02:08 +0300 Subject: [PATCH 08/20] fix import --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6e46948d7ec4..6b315cc3dd25 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -44,7 +44,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast - +from typing import Union import diffusers from diffusers import ( AutoencoderKL, From 0156becdf99ad64bc718b30048115de0a6c75dcf Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 08:04:12 +0000 Subject: [PATCH 09/20] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6b315cc3dd25..d3675fd56fa2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -25,6 +25,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Union import numpy as np import torch @@ -44,7 +45,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast -from typing import Union + import diffusers from diffusers import ( AutoencoderKL, From c4c2c48d248447cb31ad9e0ab0a783ad4f7e777e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 11:59:36 +0300 Subject: [PATCH 10/20] fix tests --- examples/dreambooth/test_dreambooth_flux.py | 24 -------------- .../dreambooth/test_dreambooth_lora_flux.py | 31 +++++++++++++++++++ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_flux.py b/examples/dreambooth/test_dreambooth_flux.py index c7a83157b091..2d5703d2a24a 100644 --- a/examples/dreambooth/test_dreambooth_flux.py +++ b/examples/dreambooth/test_dreambooth_flux.py @@ -62,30 +62,6 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - def test_dreambooth_latent_caching(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - {self.script_path} - --pretrained_model_name_or_path {self.pretrained_model_name_or_path} - --instance_data_dir {self.instance_data_dir} - --instance_prompt {self.instance_prompt} - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --cache_latents - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - def test_dreambooth_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index b77f84447aaa..1dd9d73a2ca6 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -102,7 +102,38 @@ def test_dreambooth_lora_text_encoder_flux(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From d514c7bac0e7d23630b15626ebafb46854066ac3 Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 12:35:29 +0000 Subject: [PATCH 11/20] style --- examples/dreambooth/test_dreambooth_lora_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index 1dd9d73a2ca6..d197c8187b87 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -102,6 +102,7 @@ def test_dreambooth_lora_text_encoder_flux(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -134,6 +135,7 @@ def test_dreambooth_lora_latent_caching(self): # with `"transformer"` in their names. starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From d5c2a36af11d8c55b26d0357d81213d17639bb2f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 16 Aug 2024 17:57:59 +0300 Subject: [PATCH 12/20] --change upcasting of transformer? --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d3675fd56fa2..d0238c403725 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1923,7 +1923,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(torch.float32) + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From fbacbb564cf49ce43453cecbf23fadf37f5bf332 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Sep 2024 12:00:34 +0300 Subject: [PATCH 13/20] update readme according to main --- examples/dreambooth/README_flux.md | 50 ++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index b816ae6153e0..eaa0ebd80666 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -9,10 +9,9 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: +> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: > 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) -> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) - +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] > **Gated model** @@ -103,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-4 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -123,15 +124,23 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. -> [!TIP] -> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. - ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +### Prodigy Optimizer +Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. +By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). + +to use prodigy, specify +```bash +--optimizer="prodigy" +``` +> [!TIP] +> When using prodigy it's generally good practice to set- `--learning_rate=1.0` + To perform DreamBooth with LoRA, run: ```bash @@ -147,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -165,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] +> This is still an experimental feature. > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. @@ -183,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -194,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +## Memory Optimizations +As mentioned, Flux Dreambooth LoRA training is very memory intensive Here are some options (some still experimental) for a more memory efficient training. +### Image Resolution +An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. +Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. +### Gradient Checkpointing and Accumulation +* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. +by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. +* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. +Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. +### 8-bit-Adam Optimizer +When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. +Make sure to install `bitsandbytes` if you want to do so. +### latent caching +When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. +to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents` ## Other notes -Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file From feae3dc88849e63291f21280782c03964e5bec8d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 18:19:16 +0300 Subject: [PATCH 14/20] keep only latent caching --- .../dreambooth/train_dreambooth_lora_flux.py | 104 +----------------- 1 file changed, 1 insertion(+), 103 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b87bf6620aed..b86d6f150db2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1004,108 +1004,6 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids -# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: -# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 -class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - with torch.no_grad(): - # create weights for timesteps - num_timesteps = 1000 - - # generate the multiplier based on cosmap loss weighing - # this is only used on linear timesteps for now - - # cosine map weighing is higher in the middle and lower at the ends - # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 - # cosmap_weighing = 2 / (math.pi * bot) - - # sigma sqrt weighing is significantly higher at the end and lower at the beginning - sigma_sqrt_weighing = (self.sigmas**-2.0).float() - # clip at 1e4 (1e6 is too high) - sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) - # bring to a mean of 1 - sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() - - # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") - - self.linear_timesteps = timesteps - # self.linear_timesteps_weights = cosmap_weighing - self.linear_timesteps_weights = sigma_sqrt_weighing - - # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') - pass - - def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: - # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] - - # Get the weights for the timesteps - weights = self.linear_timesteps_weights[step_indices].flatten() - - return weights - - def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: - sigmas = self.sigmas.to(device=device, dtype=dtype) - schedule_timesteps = self.timesteps.to(device) - timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - - return sigma - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise - t_01 = (timesteps / 1000).to(original_samples.device) - noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise - return noisy_model_input - - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: - return sample - - def set_train_timesteps(self, num_timesteps, device, linear=False): - if linear: - timesteps = torch.linspace(1000, 0, num_timesteps, device=device) - self.timesteps = timesteps - return timesteps - else: - # distribute them closer to center. Inference distributes them as a bias toward first - # Generate values from 0 to 1 - t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) - - # Scale and reverse the values to go from 1000 to 0 - timesteps = (1 - t) * 1000 - - # Sort the timesteps in descending order - timesteps, _ = torch.sort(timesteps, descending=True) - - self.timesteps = timesteps.to(device=device) - - return timesteps - - def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1236,7 +1134,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) From b53ae0be72ecccaef1da4b6457692553fdb6b764 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 18:42:11 +0300 Subject: [PATCH 15/20] add configurable param for final saving of trained layers- --upcast_before_saving --- examples/dreambooth/README_flux.md | 9 +++++++-- examples/dreambooth/train_dreambooth_lora_flux.py | 14 +++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index eaa0ebd80666..abb53cc6ae46 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -221,8 +221,13 @@ Instead, only a subset of these activations (the checkpoints) are stored and the ### 8-bit-Adam Optimizer When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. -### latent caching +### Latent caching When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. -to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents` +to enable `latent_caching` simply pass `--cache_latents`. +### Precision of saved LoRA layers +By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `bf16` as well. +This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, +you can do so by passing `--upcast_before_saving` + ## Other notes Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b86d6f150db2..1451d46eda58 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -626,6 +626,15 @@ def parse_args(input_args=None): " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) parser.add_argument( "--prior_generation_precision", type=str, @@ -1823,7 +1832,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(weight_dtype) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From 5cdb4f5b86fc922704137736aab9fad80dcdbd6a Mon Sep 17 00:00:00 2001 From: Linoy Date: Fri, 13 Sep 2024 16:46:35 +0000 Subject: [PATCH 16/20] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 1451d46eda58..6ac1a351fa22 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -25,7 +25,6 @@ import warnings from contextlib import nullcontext from pathlib import Path -from typing import Union import numpy as np import torch From e047ae2fdd791aae597fc48a42f5a6596b493074 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Sat, 14 Sep 2024 20:57:50 +0300 Subject: [PATCH 17/20] Update examples/dreambooth/README_flux.md Co-authored-by: Sayak Paul --- examples/dreambooth/README_flux.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index abb53cc6ae46..04fae3838f72 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -225,7 +225,7 @@ Make sure to install `bitsandbytes` if you want to do so. When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. to enable `latent_caching` simply pass `--cache_latents`. ### Precision of saved LoRA layers -By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `bf16` as well. +By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well. This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving` From a882c418d9f34108f6c10dbe9635f42fffbc64d9 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Sat, 14 Sep 2024 20:58:28 +0300 Subject: [PATCH 18/20] Update examples/dreambooth/README_flux.md Co-authored-by: Sayak Paul --- examples/dreambooth/README_flux.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 04fae3838f72..69dfd241395b 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -226,8 +226,7 @@ When training w/o validation runs, we can pre-encode the training images with th to enable `latent_caching` simply pass `--cache_latents`. ### Precision of saved LoRA layers By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well. -This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, -you can do so by passing `--upcast_before_saving` +This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`. ## Other notes Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file From 75058d7f96c7607cefe8a2a261d50764ded48749 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 14 Sep 2024 21:03:52 +0300 Subject: [PATCH 19/20] use clear_objs_and_retain_memory from utilities --- .../dreambooth/train_dreambooth_lora_flux.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6ac1a351fa22..a98617419947 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -58,6 +58,7 @@ cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + clear_objs_and_retain_memory ) from diffusers.utils import ( check_min_version, @@ -1436,12 +1437,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - del tokenizers, text_encoders - # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection - del text_encoder_one, text_encoder_two - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two]) # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1484,10 +1480,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) if args.validation_prompt is None: - del vae - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clear_objs_and_retain_memory([vae]) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1823,9 +1816,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): epoch=epoch, ) if not args.train_text_encoder: - del text_encoder_one, text_encoder_two - torch.cuda.empty_cache() - gc.collect() + clear_objs_and_retain_memory([text_encoder_one, text_encoder_two]) # Save the lora layers accelerator.wait_for_everyone() From 88c02752b727917685701d623139aa71a8f2443e Mon Sep 17 00:00:00 2001 From: Linoy Date: Sun, 15 Sep 2024 11:57:02 +0000 Subject: [PATCH 20/20] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c94f73b18583..6091622719ee 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -15,7 +15,6 @@ import argparse import copy -import gc import itertools import logging import math @@ -56,9 +55,9 @@ from diffusers.training_utils import ( _set_state_dict_into_text_encoder, cast_training_params, + clear_objs_and_retain_memory, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, - clear_objs_and_retain_memory ) from diffusers.utils import ( check_min_version,