diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index eaa0ebd80666..69dfd241395b 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -221,8 +221,12 @@ Instead, only a subset of these activations (the checkpoints) are stored and the
 ### 8-bit-Adam Optimizer
 When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
 Make sure to install `bitsandbytes` if you want to do so.
-### latent caching
+### Latent caching
 When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory.
-to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`
+To enable `latent_caching`, simply pass `--cache_latents`.
+### Precision of saved LoRA layers
+By default, the trained transformer layers are saved in the precision dtype in which training was performed. For example, when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers are saved in `torch.bfloat16` as well.
+This reduces memory requirements significantly without a noticeable loss in quality. If you do wish to save the final layers in float32, at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
+
 ## Other notes
 Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py
index b77f84447aaa..d197c8187b87 100644
--- a/examples/dreambooth/test_dreambooth_lora_flux.py
+++ b/examples/dreambooth/test_dreambooth_lora_flux.py
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self):
             )
             self.assertTrue(starts_with_expected_prefix)
 
+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
     def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index bd5b46cc9fa9..6091622719ee 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -15,7 +15,6 @@
 
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -56,6 +55,7 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )
@@ -600,6 +600,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -620,6 +626,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to the precision dtype used for training, to save memory."
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1422,12 +1437,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
-        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1457,6 +1467,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_block_out_channels = vae.config.block_out_channels
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            clear_objs_and_retain_memory([vae])
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1579,7 +1604,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             if args.train_text_encoder:
                 models_to_accumulate.extend([text_encoder_one])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
@@ -1613,11 +1637,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     )
 
                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
@@ -1789,15 +1817,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
-                    gc.collect()
+                    clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
 
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(torch.float32)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         if args.train_text_encoder:
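For reference, here is a minimal, illustrative sketch of the latent-caching pattern this diff introduces (not part of the diff itself): images are encoded once with the VAE, only the latent distributions are kept, and the VAE is then freed. The `black-forest-labs/FLUX.1-dev` checkpoint name and the dummy image batches are stand-ins for whatever `--pretrained_model_name_or_path` and the real training dataloader provide.

```python
# Sketch of what --cache_latents does: cache VAE latent distributions up front,
# free the VAE, then sample from the cached distributions inside the training loop.
import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Assumed checkpoint for illustration only.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype
).to(device)

# Stand-in for the real train_dataloader: two batches of images normalized to [-1, 1].
dummy_batches = [{"pixel_values": torch.rand(1, 3, 512, 512) * 2 - 1} for _ in range(2)]

# 1) Encode every batch once and keep only the latent distributions.
latents_cache = []
with torch.no_grad():
    for batch in dummy_batches:
        pixel_values = batch["pixel_values"].to(device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# 2) The VAE is no longer needed for training (only for validation), so free it
#    after stashing the config values the training loop still needs.
shift, scale = vae.config.shift_factor, vae.config.scaling_factor
del vae
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 3) Inside the training loop, draw a fresh latent sample from the cached distribution.
model_input = latents_cache[0].sample()
model_input = (model_input - shift) * scale
```

Caching the `latent_dist` objects rather than pre-sampled latents matches the diff: a new latent is sampled at every step, so the encoder's sampling noise is preserved even though the encoder itself is no longer resident in memory.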