From 90686c2e772ee285aeea737bea416a30a8d84adc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 12 Aug 2024 17:30:28 +0300 Subject: [PATCH 01/82] add ostris trainer to README & add cache latents of vae --- examples/dreambooth/README_flux.md | 2 +- .../dreambooth/train_dreambooth_lora_flux.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index fab382c0894c..fd3961b1fd79 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) & [`ostris` guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3629fcca4dd0..3122ad2e13c7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -599,6 +599,12 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) parser.add_argument( "--report_to", type=str, @@ -1456,6 +1462,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=torch.float32 + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1610,7 +1631,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) From 7b12ed2c82e34722dada6375ea4d39fd54deed9f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 12 Aug 2024 17:54:16 +0300 Subject: [PATCH 02/82] add ostris trainer to README & add cache latents of vae --- examples/dreambooth/train_dreambooth_lora_flux.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3122ad2e13c7..7acceed0bc66 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1462,12 +1462,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels if args.cache_latents: latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=torch.float32 + accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) @@ -1475,6 +1478,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): del vae if torch.cuda.is_available(): torch.cuda.empty_cache() + gc.collect() # Scheduler and math around the number of training steps. @@ -1599,7 +1603,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1634,11 +1637,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.cache_latents: model_input = latents_cache[step].sample() else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], From 17dca18cf1c99359d6e6c075b5ab739dac02fe19 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 12 Aug 2024 15:01:55 +0000 Subject: [PATCH 03/82] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 7acceed0bc66..ebf2d3dfcdff 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1480,7 +1480,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): torch.cuda.empty_cache() gc.collect() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 8b314e9ad672ff823fd47ee0ed7f6bb2022e434a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 13 Aug 2024 17:10:13 +0300 Subject: [PATCH 04/82] readme --- examples/dreambooth/README_flux.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index fd3961b1fd79..e668e156a0ef 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,7 +8,9 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) & [`ostris` guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) +> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: +> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] From df54cd8dc0606ce2013cce8ebdcd9f9180c89459 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 10:35:38 +0300 Subject: [PATCH 05/82] add test for latent caching --- examples/dreambooth/test_dreambooth_flux.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_flux.py b/examples/dreambooth/test_dreambooth_flux.py index 2d5703d2a24a..c7a83157b091 100644 --- a/examples/dreambooth/test_dreambooth_flux.py +++ b/examples/dreambooth/test_dreambooth_flux.py @@ -62,6 +62,30 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --cache_latents + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing From e0e0319960d0608d220c76c8b4f8b1d59cc84b86 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 10:57:20 +0300 Subject: [PATCH 06/82] add ostris noise scheduler https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 --- .../dreambooth/train_dreambooth_lora_flux.py | 102 +++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 33f5a89a0741..317a230458ef 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1002,6 +1002,106 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids +# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: +# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # generate the multiplier based on cosmap loss weighing + # this is only used on linear timesteps for now + + # cosine map weighing is higher in the middle and lower at the ends + # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 + # cosmap_weighing = 2 / (math.pi * bot) + + # sigma sqrt weighing is significantly higher at the end and lower at the beginning + sigma_sqrt_weighing = (self.sigmas ** -2.0).float() + # clip at 1e4 (1e6 is too high) + sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) + # bring to a mean of 1 + sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + + self.linear_timesteps = timesteps + # self.linear_timesteps_weights = cosmap_weighing + self.linear_timesteps_weights = sigma_sqrt_weighing + + # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = ((1 - t) * 1000) + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps def main(args): if args.report_to == "wandb" and args.hub_token is not None: @@ -1133,7 +1233,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) From 18aa3697c8e0bc477696d4d22123a8f44d0e43ee Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 08:00:15 +0000 Subject: [PATCH 07/82] style --- .../dreambooth/train_dreambooth_lora_flux.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 317a230458ef..6e46948d7ec4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1002,6 +1002,7 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids + # CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: # https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): @@ -1020,14 +1021,14 @@ def __init__(self, *args, **kwargs): # cosmap_weighing = 2 / (math.pi * bot) # sigma sqrt weighing is significantly higher at the end and lower at the beginning - sigma_sqrt_weighing = (self.sigmas ** -2.0).float() + sigma_sqrt_weighing = (self.sigmas**-2.0).float() # clip at 1e4 (1e6 is too high) sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) # bring to a mean of 1 sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") self.linear_timesteps = timesteps # self.linear_timesteps_weights = cosmap_weighing @@ -1058,10 +1059,10 @@ def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Ten return sigma def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, ) -> torch.Tensor: ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 ## Add noise according to flow matching. @@ -1094,7 +1095,7 @@ def set_train_timesteps(self, num_timesteps, device, linear=False): t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) # Scale and reverse the values to go from 1000 to 0 - timesteps = ((1 - t) * 1000) + timesteps = (1 - t) * 1000 # Sort the timesteps in descending order timesteps, _ = torch.sort(timesteps, descending=True) @@ -1103,6 +1104,7 @@ def set_train_timesteps(self, num_timesteps, device, linear=False): return timesteps + def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( From f97d53d009089c661a2176a18d08fa8bbf008b52 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 11:02:08 +0300 Subject: [PATCH 08/82] fix import --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6e46948d7ec4..6b315cc3dd25 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -44,7 +44,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast - +from typing import Union import diffusers from diffusers import ( AutoencoderKL, From 0156becdf99ad64bc718b30048115de0a6c75dcf Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 08:04:12 +0000 Subject: [PATCH 09/82] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6b315cc3dd25..d3675fd56fa2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -25,6 +25,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Union import numpy as np import torch @@ -44,7 +45,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast -from typing import Union + import diffusers from diffusers import ( AutoencoderKL, From c4c2c48d248447cb31ad9e0ab0a783ad4f7e777e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 14 Aug 2024 11:59:36 +0300 Subject: [PATCH 10/82] fix tests --- examples/dreambooth/test_dreambooth_flux.py | 24 -------------- .../dreambooth/test_dreambooth_lora_flux.py | 31 +++++++++++++++++++ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_flux.py b/examples/dreambooth/test_dreambooth_flux.py index c7a83157b091..2d5703d2a24a 100644 --- a/examples/dreambooth/test_dreambooth_flux.py +++ b/examples/dreambooth/test_dreambooth_flux.py @@ -62,30 +62,6 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - def test_dreambooth_latent_caching(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - {self.script_path} - --pretrained_model_name_or_path {self.pretrained_model_name_or_path} - --instance_data_dir {self.instance_data_dir} - --instance_prompt {self.instance_prompt} - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --cache_latents - --max_train_steps 2 - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - def test_dreambooth_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: # Run training script with checkpointing diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index b77f84447aaa..1dd9d73a2ca6 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -102,7 +102,38 @@ def test_dreambooth_lora_text_encoder_flux(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From d514c7bac0e7d23630b15626ebafb46854066ac3 Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 14 Aug 2024 12:35:29 +0000 Subject: [PATCH 11/82] style --- examples/dreambooth/test_dreambooth_lora_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index 1dd9d73a2ca6..d197c8187b87 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -102,6 +102,7 @@ def test_dreambooth_lora_text_encoder_flux(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -134,6 +135,7 @@ def test_dreambooth_lora_latent_caching(self): # with `"transformer"` in their names. starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From d5c2a36af11d8c55b26d0357d81213d17639bb2f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 16 Aug 2024 17:57:59 +0300 Subject: [PATCH 12/82] --change upcasting of transformer? --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d3675fd56fa2..d0238c403725 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1923,7 +1923,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(torch.float32) + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From fbacbb564cf49ce43453cecbf23fadf37f5bf332 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Sep 2024 12:00:34 +0300 Subject: [PATCH 13/82] update readme according to main --- examples/dreambooth/README_flux.md | 50 ++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index b816ae6153e0..eaa0ebd80666 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -9,10 +9,9 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: +> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: > 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) -> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) - +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] > **Gated model** @@ -103,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-4 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -123,15 +124,23 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. -> [!TIP] -> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. - ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +### Prodigy Optimizer +Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. +By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). + +to use prodigy, specify +```bash +--optimizer="prodigy" +``` +> [!TIP] +> When using prodigy it's generally good practice to set- `--learning_rate=1.0` + To perform DreamBooth with LoRA, run: ```bash @@ -147,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -165,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] +> This is still an experimental feature. > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. @@ -183,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -194,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +## Memory Optimizations +As mentioned, Flux Dreambooth LoRA training is very memory intensive Here are some options (some still experimental) for a more memory efficient training. +### Image Resolution +An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. +Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. +### Gradient Checkpointing and Accumulation +* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. +by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. +* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. +Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. +### 8-bit-Adam Optimizer +When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. +Make sure to install `bitsandbytes` if you want to do so. +### latent caching +When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. +to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents` ## Other notes -Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file From 44c534efed0d62da022cf513530a15ee06721f82 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Sep 2024 14:21:16 +0300 Subject: [PATCH 14/82] add pivotal tuning for CLIP --- .../dreambooth/train_dreambooth_lora_flux.py | 376 +++++++++++++++--- 1 file changed, 318 insertions(+), 58 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b87bf6620aed..ca55b8316346 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -83,17 +83,43 @@ def save_model_card( images=None, base_model: str = None, train_text_encoder=False, + train_text_encoder_ti=False, + token_abstraction_dict=None, instance_prompt=None, validation_prompt=None, repo_folder=None, ): widget_dict = [] + trigger_str = f"You should use {instance_prompt} to trigger the image generation." + if images is not None: for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" + if train_text_encoder_ti: + embeddings_filename = f"{repo_folder}_emb" + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + """ + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + """ + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" + to trigger concept `{key}` → use `{tokens}` in your prompt \n + """ model_description = f""" # Flux DreamBooth LoRA - {repo_id} @@ -108,9 +134,11 @@ def save_model_card( Was LoRA for the text encoder enabled? {train_text_encoder}. +Pivotal tuning was enabled: {train_text_encoder_ti}. + ## Trigger words -You should use `{instance_prompt}` to trigger the image generation. +{trigger_str} ## Download model @@ -121,8 +149,10 @@ def save_model_card( ```py from diffusers import AutoPipelineForText2Image import torch +{diffusers_imports_pivotal} pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +{diffusers_example_pivotal} image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` @@ -311,6 +341,23 @@ def parse_args(input_args=None): required=True, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) + parser.add_argument( + "--token_abstraction", + type=str, + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " + "'TOK,TOK2,TOK3' etc.", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + type=int, + default=2, + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) parser.add_argument( "--class_prompt", type=str, @@ -401,6 +448,25 @@ def parse_args(input_args=None): action="store_true", help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform text encoder tuning"), + ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -649,6 +715,12 @@ def parse_args(input_args=None): if args.dataset_name is not None and args.instance_data_dir is not None: raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -667,6 +739,103 @@ def parse_args(input_args=None): return args +# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." + + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() + + print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + self.embeddings_settings[ + f"original_embeddings_{idx}" + ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + inu = torch.ones((len(tokenizer),), dtype=torch.bool) + inu[self.train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = inu + + print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + # text_encoder_0 - CLIP ViT-L/14, for now only optimizing and saving embeddings for CLIP (text_encoder_two - T5, remains untouched) + idx_to_text_encoder_name = {0: "clip_l"} + for idx, text_encoder in enumerate(self.text_encoders): + assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( + self.tokenizers[0] + ), "Tokenizers should be the same." + new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings + class DreamBoothDataset(Dataset): """ @@ -679,6 +848,8 @@ def __init__( instance_data_root, instance_prompt, class_prompt, + train_text_encoder_ti, + token_abstraction_dict=None, # token mapping for textual inversion class_data_root=None, class_num=None, size=1024, @@ -691,7 +862,8 @@ def __init__( self.instance_prompt = instance_prompt self.custom_instance_prompts = None self.class_prompt = class_prompt - + self.token_abstraction_dict = token_abstraction_dict + self.train_text_encoder_ti = train_text_encoder_ti # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: @@ -818,11 +990,15 @@ def __getitem__(self, index): if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] if caption: + if self.train_text_encoder_ti: + # replace instances of --token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) example["instance_prompt"] = caption else: example["instance_prompt"] = self.instance_prompt - else: # custom prompts were provided, but length does not match size of image dataset + else: # the given instance prompt is used for all images example["instance_prompt"] = self.instance_prompt if self.class_data_root: @@ -871,7 +1047,7 @@ def __getitem__(self, index): return example -def tokenize_prompt(tokenizer, prompt, max_sequence_length): +def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False): text_inputs = tokenizer( prompt, padding="max_length", @@ -879,6 +1055,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length): truncation=True, return_length=False, return_overflowing_tokens=False, + add_special_tokens=add_special_tokens, return_tensors="pt", ) text_input_ids = text_inputs.input_ids @@ -985,7 +1162,7 @@ def encode_prompt( prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, + text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None ) prompt_embeds = _encode_prompt_with_t5( @@ -995,7 +1172,7 @@ def encode_prompt( prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, - text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, + text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1251,6 +1428,37 @@ def main(args): args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) + if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + logger.info(f"list of token identifiers: {token_abstraction_list}") + + token_abstraction_dict = {} + token_idx = 0 + for i, token in enumerate(token_abstraction_list): + token_abstraction_dict[token] = [ + f"" for j in range(args.num_new_tokens_per_abstraction) + ] + token_idx += args.num_new_tokens_per_abstraction - 1 + + # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. + for token_abs, token_replacement in token_abstraction_dict.items(): + args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + if args.validation_prompt: + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + embedding_handler = TokenEmbeddingsHandler( + [text_encoder_one], [tokenizer_one] + ) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + # We only train the additional adapter LoRA layers transformer.requires_grad_(False) vae.requires_grad_(False) @@ -1289,6 +1497,7 @@ def main(args): target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) transformer.add_adapter(transformer_lora_config) + if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, @@ -1297,6 +1506,18 @@ def main(args): target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] # for now only for CLIP + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1313,7 +1534,8 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1325,6 +1547,8 @@ def save_model_hook(models, weights, output_dir): transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") def load_model_hook(models, input_dir): transformer_ = None @@ -1391,16 +1615,22 @@ def load_model_hook(models, input_dir): cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + if args.train_text_encoder: text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - if args.train_text_encoder: + if not freeze_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ @@ -1483,6 +1713,8 @@ def load_model_hook(models, input_dir): train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, class_prompt=args.class_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, @@ -1499,7 +1731,7 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - if not args.train_text_encoder: + if freeze_text_encoder: tokenizers = [tokenizer_one, tokenizer_two] text_encoders = [text_encoder_one, text_encoder_two] @@ -1516,20 +1748,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + if freeze_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - if not args.train_text_encoder: + if freeze_text_encoder: class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers ) # Clear the memory here - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + if freeze_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection del text_encoder_one, text_encoder_two @@ -1537,12 +1769,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if torch.cuda.is_available(): torch.cuda.empty_cache() + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + add_special_tokens = True if args.train_text_encoder_ti else False + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. if not train_dataset.custom_instance_prompts: - if not args.train_text_encoder: + if freeze_text_encoder: prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids @@ -1553,12 +1788,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) tokens_two = tokenize_prompt( tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) class_tokens_two = tokenize_prompt( tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length ) @@ -1600,7 +1835,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # Prepare everything with our `accelerator`. - if args.train_text_encoder: + if not freeze_text_encoder: ( transformer, text_encoder_one, @@ -1693,50 +1928,61 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigma = sigma.unsqueeze(-1) return sigma + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + # flag used for textual inversion + pivoted = False + for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - if args.train_text_encoder: - text_encoder_one.train() - # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + print("PIVOT HALFWAY", epoch) + # stopping optimization of text_encoder params + # this flag is used to reset the optimizer to optimize only on unet params + pivoted = True + + else: + # still optimizing the text encoder + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + if args.train_text_encoder: + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): - models_to_accumulate = [transformer] - if args.train_text_encoder: - models_to_accumulate.extend([text_encoder_one]) - with accelerator.accumulate(models_to_accumulate): + if pivoted: + # stopping optimization of text_encoder params + # re setting the optimizer to optimize only on unet params + optimizer.param_groups[1]["lr"] = 0.0 + + with accelerator.accumulate(transformer): prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - if not args.train_text_encoder: + if freeze_text_encoder: prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens) tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length - ) - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], - max_sequence_length=args.max_sequence_length, - device=accelerator.device, - prompt=prompts, - ) - else: - if args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], - max_sequence_length=args.max_sequence_length, - device=accelerator.device, - prompt=args.instance_prompt, + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length, ) + if not freeze_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=None, + ) + # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() @@ -1846,7 +2092,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters()) - if args.train_text_encoder + if not freeze_text_encoder else transformer.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -1855,6 +2101,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): lr_scheduler.step() optimizer.zero_grad() + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + embedding_handler.retract_embeddings() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) @@ -1896,7 +2146,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline - if not args.train_text_encoder: + if freeze_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1916,10 +2166,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, ) - if not args.train_text_encoder: + if freeze_text_encoder: del text_encoder_one, text_encoder_two torch.cuda.empty_cache() gc.collect() + else: + del text_encoder_two + torch.cuda.empty_cache() + gc.collect() # Save the lora layers accelerator.wait_for_everyone() @@ -1940,6 +2194,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_lora_layers=text_encoder_lora_layers, ) + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + # Final inference # Load previous pipeline pipeline = FluxPipeline.from_pretrained( @@ -1964,16 +2222,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): is_final_validation=True, ) + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) if args.push_to_hub: - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, From 087b9825531cc20ab9ebd9973ddd4487ecdafac7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Sep 2024 17:34:38 +0300 Subject: [PATCH 15/82] fix imports, encode_prompt call,add TextualInversionLoaderMixin to FluxPipeline for inference --- examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++-- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index ca55b8316346..0635b11b1b1b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -21,11 +21,12 @@ import math import os import random +import re import shutil import warnings from contextlib import nullcontext from pathlib import Path -from typing import Union +from typing import Union, List import numpy as np import torch @@ -40,6 +41,7 @@ from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose +from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms from torchvision.transforms.functional import crop @@ -1980,7 +1982,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_input_ids_list=[tokens_one, tokens_two], max_sequence_length=args.max_sequence_length, device=accelerator.device, - prompt=None, + prompt=prompts, ) # Convert images to latent space diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index bb214885da1c..f299622bc6c4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -137,7 +137,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin,): r""" The Flux pipeline for text-to-image generation. From d9c3e45aa5f11fc62016745233a0af93906e45ed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Sep 2024 17:44:42 +0300 Subject: [PATCH 16/82] TextualInversionLoaderMixin support for FluxPipeline for inference --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index f299622bc6c4..9a6050c763a9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -255,6 +255,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", From b4328f868f567120f264c04a81fd6b323bbe0cc2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 16:59:48 +0300 Subject: [PATCH 17/82] move changes to advanced flux script, revert canonical --- .../README_flux.md | 0 .../train_dreambooth_lora_flux_advanced.py | 2251 +++++++++++++++++ .../dreambooth/train_dreambooth_lora_flux.py | 522 +--- 3 files changed, 2316 insertions(+), 457 deletions(-) create mode 100644 examples/advanced_diffusion_training/README_flux.md create mode 100644 examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..c102d1a13613 --- /dev/null +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -0,0 +1,2251 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import itertools +import logging +import math +import os +import random +import re +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Union, List + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import load_file, save_file +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + train_text_encoder_ti=False, + token_abstraction_dict=None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + trigger_str = f"You should use {instance_prompt} to trigger the image generation." + + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" + if train_text_encoder_ti: + embeddings_filename = f"{repo_folder}_emb" + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + """ + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + """ + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" + to trigger concept `{key}` → use `{tokens}` in your prompt \n + """ + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +Pivotal tuning was enabled: {train_text_encoder_ti}. + +## Trigger words + +{trigger_str} + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +{diffusers_imports_pivotal} +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +{diffusers_example_pivotal} +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--token_abstraction", + type=str, + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " + "'TOK,TOK2,TOK3' etc.", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + type=int, + default=2, + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform text encoder tuning"), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + +# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." + + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() + + print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + self.embeddings_settings[ + f"original_embeddings_{idx}" + ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + inu = torch.ones((len(tokenizer),), dtype=torch.bool) + inu[self.train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = inu + + print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + # text_encoder_0 - CLIP ViT-L/14, for now only optimizing and saving embeddings for CLIP (text_encoder_two - T5, remains untouched) + idx_to_text_encoder_name = {0: "clip_l"} + for idx, text_encoder in enumerate(self.text_encoders): + assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( + self.tokenizers[0] + ), "Tokenizers should be the same." + new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + train_text_encoder_ti, + token_abstraction_dict=None, # token mapping for textual inversion + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + self.token_abstraction_dict = token_abstraction_dict + self.train_text_encoder_ti = train_text_encoder_ti + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + if self.train_text_encoder_ti: + # replace instances of --token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # the given instance prompt is used for all images + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None + ) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: +# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # generate the multiplier based on cosmap loss weighing + # this is only used on linear timesteps for now + + # cosine map weighing is higher in the middle and lower at the ends + # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 + # cosmap_weighing = 2 / (math.pi * bot) + + # sigma sqrt weighing is significantly higher at the end and lower at the beginning + sigma_sqrt_weighing = (self.sigmas**-2.0).float() + # clip at 1e4 (1e6 is too high) + sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) + # bring to a mean of 1 + sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") + + self.linear_timesteps = timesteps + # self.linear_timesteps_weights = cosmap_weighing + self.linear_timesteps_weights = sigma_sqrt_weighing + + # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = (1 - t) * 1000 + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + logger.info(f"list of token identifiers: {token_abstraction_list}") + + token_abstraction_dict = {} + token_idx = 0 + for i, token in enumerate(token_abstraction_list): + token_abstraction_dict[token] = [ + f"" for j in range(args.num_new_tokens_per_abstraction) + ] + token_idx += args.num_new_tokens_per_abstraction - 1 + + # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. + for token_abs, token_replacement in token_abstraction_dict.items(): + args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + if args.validation_prompt: + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + embedding_handler = TokenEmbeddingsHandler( + [text_encoder_one], [tokenizer_one] + ) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] # for now only for CLIP + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if not freeze_text_encoder: + # different learning rate for text encoder and unet + text_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] + else: + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if freeze_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if freeze_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection + del text_encoder_one, text_encoder_two + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + add_special_tokens = True if args.train_text_encoder_ti else False + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) + # we need to tokenize and encode the batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_two = tokenize_prompt( + tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length + ) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + class_tokens_two = tokenize_prompt( + tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length + ) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if not freeze_text_encoder: + ( + transformer, + text_encoder_one, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora-advanced" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + # flag used for textual inversion + pivoted = False + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + print("PIVOT HALFWAY", epoch) + # stopping optimization of text_encoder params + # this flag is used to reset the optimizer to optimize only on unet params + pivoted = True + + else: + # still optimizing the text encoder + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + if args.train_text_encoder: + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + + for step, batch in enumerate(train_dataloader): + if pivoted: + # stopping optimization of text_encoder params + # re setting the optimizer to optimize only on unet params + optimizer.param_groups[1]["lr"] = 0.0 + + with accelerator.accumulate(transformer): + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_two = tokenize_prompt( + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length, + ) + + if not freeze_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=prompts, + ) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + if not freeze_text_encoder + else transformer.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + embedding_handler.retract_embeddings() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if freeze_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + if freeze_text_encoder: + del text_encoder_one, text_encoder_two + torch.cuda.empty_cache() + gc.collect() + else: + del text_encoder_two + torch.cuda.empty_cache() + gc.collect() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + else: + text_encoder_lora_layers = None + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + + # Final inference + # Load previous pipeline + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 0635b11b1b1b..7e2c60e90054 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -21,12 +21,10 @@ import math import os import random -import re import shutil import warnings from contextlib import nullcontext from pathlib import Path -from typing import Union, List import numpy as np import torch @@ -41,7 +39,6 @@ from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose -from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms from torchvision.transforms.functional import crop @@ -85,43 +82,17 @@ def save_model_card( images=None, base_model: str = None, train_text_encoder=False, - train_text_encoder_ti=False, - token_abstraction_dict=None, instance_prompt=None, validation_prompt=None, repo_folder=None, ): widget_dict = [] - trigger_str = f"You should use {instance_prompt} to trigger the image generation." - if images is not None: for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) widget_dict.append( {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) - diffusers_imports_pivotal = "" - diffusers_example_pivotal = "" - if train_text_encoder_ti: - embeddings_filename = f"{repo_folder}_emb" - ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) - trigger_str = ( - "To trigger image generation of trained concept(or concepts) replace each concept identifier " - "in you prompt with the new inserted tokens:\n" - ) - diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download - from safetensors.torch import load_file - """ - diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") - state_dict = load_file(embedding_path) - pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) - """ - if token_abstraction_dict: - for key, value in token_abstraction_dict.items(): - tokens = "".join(value) - trigger_str += f""" - to trigger concept `{key}` → use `{tokens}` in your prompt \n - """ model_description = f""" # Flux DreamBooth LoRA - {repo_id} @@ -136,11 +107,9 @@ def save_model_card( Was LoRA for the text encoder enabled? {train_text_encoder}. -Pivotal tuning was enabled: {train_text_encoder_ti}. - ## Trigger words -{trigger_str} +You should use `{instance_prompt}` to trigger the image generation. ## Download model @@ -151,10 +120,8 @@ def save_model_card( ```py from diffusers import AutoPipelineForText2Image import torch -{diffusers_imports_pivotal} pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') -{diffusers_example_pivotal} image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` @@ -343,23 +310,6 @@ def parse_args(input_args=None): required=True, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) - parser.add_argument( - "--token_abstraction", - type=str, - default="TOK", - help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " - "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " - "'TOK,TOK2,TOK3' etc.", - ) - - parser.add_argument( - "--num_new_tokens_per_abstraction", - type=int, - default=2, - help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " - "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " - "tokens - ", - ) parser.add_argument( "--class_prompt", type=str, @@ -450,25 +400,6 @@ def parse_args(input_args=None): action="store_true", help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) - parser.add_argument( - "--train_text_encoder_ti", - action="store_true", - help=("Whether to use textual inversion"), - ) - - parser.add_argument( - "--train_text_encoder_ti_frac", - type=float, - default=0.5, - help=("The percentage of epochs to perform textual inversion"), - ) - - parser.add_argument( - "--train_text_encoder_frac", - type=float, - default=1.0, - help=("The percentage of epochs to perform text encoder tuning"), - ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -668,12 +599,6 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument( - "--cache_latents", - action="store_true", - default=False, - help="Cache the VAE latents", - ) parser.add_argument( "--report_to", type=str, @@ -717,12 +642,6 @@ def parse_args(input_args=None): if args.dataset_name is not None and args.instance_data_dir is not None: raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") - if args.train_text_encoder and args.train_text_encoder_ti: - raise ValueError( - "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " - "For full LoRA text encoder training check --train_text_encoder, for textual " - "inversion training check `--train_text_encoder_ti`" - ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -741,103 +660,6 @@ def parse_args(input_args=None): return args -# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py -class TokenEmbeddingsHandler: - def __init__(self, text_encoders, tokenizers): - self.text_encoders = text_encoders - self.tokenizers = tokenizers - - self.train_ids: Optional[torch.Tensor] = None - self.inserting_toks: Optional[List[str]] = None - self.embeddings_settings = {} - - def initialize_new_tokens(self, inserting_toks: List[str]): - idx = 0 - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." - assert all( - isinstance(tok, str) for tok in inserting_toks - ), "All elements in inserting_toks should be strings." - - self.inserting_toks = inserting_toks - special_tokens_dict = {"additional_special_tokens": self.inserting_toks} - tokenizer.add_special_tokens(special_tokens_dict) - text_encoder.resize_token_embeddings(len(tokenizer)) - - self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) - - # random initialization of new tokens - std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() - - print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") - - text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( - torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) - .to(device=self.device) - .to(dtype=self.dtype) - * std_token_embedding - ) - self.embeddings_settings[ - f"original_embeddings_{idx}" - ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() - self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding - - inu = torch.ones((len(tokenizer),), dtype=torch.bool) - inu[self.train_ids] = False - - self.embeddings_settings[f"index_no_updates_{idx}"] = inu - - print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) - - idx += 1 - - def save_embeddings(self, file_path: str): - assert self.train_ids is not None, "Initialize new tokens before saving embeddings." - tensors = {} - # text_encoder_0 - CLIP ViT-L/14, for now only optimizing and saving embeddings for CLIP (text_encoder_two - T5, remains untouched) - idx_to_text_encoder_name = {0: "clip_l"} - for idx, text_encoder in enumerate(self.text_encoders): - assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( - self.tokenizers[0] - ), "Tokenizers should be the same." - new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] - - # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), - # Note: When loading with diffusers, any name can work - simply specify in inference - tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings - # tensors[f"text_encoders_{idx}"] = new_token_embeddings - - save_file(tensors, file_path) - - @property - def dtype(self): - return self.text_encoders[0].dtype - - @property - def device(self): - return self.text_encoders[0].device - - @torch.no_grad() - def retract_embeddings(self): - for idx, text_encoder in enumerate(self.text_encoders): - index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] - text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( - self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] - .to(device=text_encoder.device) - .to(dtype=text_encoder.dtype) - ) - - # for the parts that were updated, we need to normalize them - # to have the same std as before - std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] - - index_updates = ~index_no_updates - new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] - off_ratio = std_token_embedding / new_embeddings.std() - - new_embeddings = new_embeddings * (off_ratio**0.1) - text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings - class DreamBoothDataset(Dataset): """ @@ -850,8 +672,6 @@ def __init__( instance_data_root, instance_prompt, class_prompt, - train_text_encoder_ti, - token_abstraction_dict=None, # token mapping for textual inversion class_data_root=None, class_num=None, size=1024, @@ -864,8 +684,7 @@ def __init__( self.instance_prompt = instance_prompt self.custom_instance_prompts = None self.class_prompt = class_prompt - self.token_abstraction_dict = token_abstraction_dict - self.train_text_encoder_ti = train_text_encoder_ti + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: @@ -992,15 +811,11 @@ def __getitem__(self, index): if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] if caption: - if self.train_text_encoder_ti: - # replace instances of --token_abstraction in caption with the new tokens: "" etc. - for token_abs, token_replacement in self.token_abstraction_dict.items(): - caption = caption.replace(token_abs, "".join(token_replacement)) example["instance_prompt"] = caption else: example["instance_prompt"] = self.instance_prompt - else: # the given instance prompt is used for all images + else: # custom prompts were provided, but length does not match size of image dataset example["instance_prompt"] = self.instance_prompt if self.class_data_root: @@ -1049,7 +864,7 @@ def __getitem__(self, index): return example -def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False): +def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", @@ -1057,7 +872,6 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F truncation=True, return_length=False, return_overflowing_tokens=False, - add_special_tokens=add_special_tokens, return_tensors="pt", ) text_input_ids = text_inputs.input_ids @@ -1164,7 +978,7 @@ def encode_prompt( prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( @@ -1174,7 +988,7 @@ def encode_prompt( prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, - text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None + text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1183,108 +997,6 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids -# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: -# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 -class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - with torch.no_grad(): - # create weights for timesteps - num_timesteps = 1000 - - # generate the multiplier based on cosmap loss weighing - # this is only used on linear timesteps for now - - # cosine map weighing is higher in the middle and lower at the ends - # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 - # cosmap_weighing = 2 / (math.pi * bot) - - # sigma sqrt weighing is significantly higher at the end and lower at the beginning - sigma_sqrt_weighing = (self.sigmas**-2.0).float() - # clip at 1e4 (1e6 is too high) - sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) - # bring to a mean of 1 - sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() - - # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") - - self.linear_timesteps = timesteps - # self.linear_timesteps_weights = cosmap_weighing - self.linear_timesteps_weights = sigma_sqrt_weighing - - # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') - pass - - def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: - # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] - - # Get the weights for the timesteps - weights = self.linear_timesteps_weights[step_indices].flatten() - - return weights - - def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: - sigmas = self.sigmas.to(device=device, dtype=dtype) - schedule_timesteps = self.timesteps.to(device) - timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - - return sigma - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise - t_01 = (timesteps / 1000).to(original_samples.device) - noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise - return noisy_model_input - - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: - return sample - - def set_train_timesteps(self, num_timesteps, device, linear=False): - if linear: - timesteps = torch.linspace(1000, 0, num_timesteps, device=device) - self.timesteps = timesteps - return timesteps - else: - # distribute them closer to center. Inference distributes them as a bias toward first - # Generate values from 0 to 1 - t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) - - # Scale and reverse the values to go from 1000 to 0 - timesteps = (1 - t) * 1000 - - # Sort the timesteps in descending order - timesteps, _ = torch.sort(timesteps, descending=True) - - self.timesteps = timesteps.to(device=device) - - return timesteps - - def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1415,7 +1127,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) @@ -1430,37 +1142,6 @@ def main(args): args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) - if args.train_text_encoder_ti: - # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, - # TOK2" -> ["TOK", "TOK2"] etc. - token_abstraction_list = "".join(args.token_abstraction.split()).split(",") - logger.info(f"list of token identifiers: {token_abstraction_list}") - - token_abstraction_dict = {} - token_idx = 0 - for i, token in enumerate(token_abstraction_list): - token_abstraction_dict[token] = [ - f"" for j in range(args.num_new_tokens_per_abstraction) - ] - token_idx += args.num_new_tokens_per_abstraction - 1 - - # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. - for token_abs, token_replacement in token_abstraction_dict.items(): - args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) - if args.with_prior_preservation: - args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) - if args.validation_prompt: - args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) - - # initialize the new tokens for textual inversion - embedding_handler = TokenEmbeddingsHandler( - [text_encoder_one], [tokenizer_one] - ) - inserting_toks = [] - for new_tok in token_abstraction_dict.values(): - inserting_toks.extend(new_tok) - embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) - # We only train the additional adapter LoRA layers transformer.requires_grad_(False) vae.requires_grad_(False) @@ -1499,7 +1180,6 @@ def main(args): target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) transformer.add_adapter(transformer_lora_config) - if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, @@ -1508,18 +1188,6 @@ def main(args): target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) - # if we use textual inversion, we freeze all parameters except for the token embeddings - # in text encoder - elif args.train_text_encoder_ti: - text_lora_parameters_one = [] # for now only for CLIP - for name, param in text_encoder_one.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param.data = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_one.append(param) - else: - param.requires_grad = False def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1536,8 +1204,7 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1549,8 +1216,6 @@ def save_model_hook(models, weights, output_dir): transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, ) - if args.train_text_encoder_ti: - embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") def load_model_hook(models, input_dir): transformer_ = None @@ -1617,22 +1282,16 @@ def load_model_hook(models, input_dir): cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - if args.train_text_encoder: text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) - # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training - freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) - # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - if not freeze_text_encoder: + if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder - if args.adam_weight_decay_text_encoder - else args.adam_weight_decay, + "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ @@ -1715,8 +1374,6 @@ def load_model_hook(models, input_dir): train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, class_prompt=args.class_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, @@ -1733,7 +1390,7 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - if freeze_text_encoder: + if not args.train_text_encoder: tokenizers = [tokenizer_one, tokenizer_two] text_encoders = [text_encoder_one, text_encoder_two] @@ -1750,20 +1407,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # If no type of tuning is done on the text_encoder and custom instance prompts are NOT # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. - if freeze_text_encoder and not train_dataset.custom_instance_prompts: + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - if freeze_text_encoder: + if not args.train_text_encoder: class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers ) # Clear the memory here - if freeze_text_encoder and not train_dataset.custom_instance_prompts: + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection del text_encoder_one, text_encoder_two @@ -1771,15 +1428,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if torch.cuda.is_available(): torch.cuda.empty_cache() - # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion - add_special_tokens = True if args.train_text_encoder_ti else False - # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. if not train_dataset.custom_instance_prompts: - if freeze_text_encoder: + if not args.train_text_encoder: prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds text_ids = instance_text_ids @@ -1790,36 +1444,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) tokens_two = tokenize_prompt( tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) class_tokens_two = tokenize_prompt( tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - vae_config_shift_factor = vae.config.shift_factor - vae_config_scaling_factor = vae.config.scaling_factor - vae_config_block_out_channels = vae.config.block_out_channels - if args.cache_latents: - latents_cache = [] - for batch in tqdm(train_dataloader, desc="Caching latents"): - with torch.no_grad(): - batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=weight_dtype - ) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) - - if args.validation_prompt is None: - del vae - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1837,7 +1473,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # Prepare everything with our `accelerator`. - if not freeze_text_encoder: + if args.train_text_encoder: ( transformer, text_encoder_one, @@ -1930,71 +1566,57 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigma = sigma.unsqueeze(-1) return sigma - if args.train_text_encoder: - num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) - elif args.train_text_encoder_ti: # args.train_text_encoder_ti - num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) - # flag used for textual inversion - pivoted = False - for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - # if performing any kind of optimization of text_encoder params - if args.train_text_encoder or args.train_text_encoder_ti: - if epoch == num_train_epochs_text_encoder: - print("PIVOT HALFWAY", epoch) - # stopping optimization of text_encoder params - # this flag is used to reset the optimizer to optimize only on unet params - pivoted = True - - else: - # still optimizing the text encoder - text_encoder_one.train() - # set top parameter requires_grad = True for gradient checkpointing works - if args.train_text_encoder: - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + if args.train_text_encoder: + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): - if pivoted: - # stopping optimization of text_encoder params - # re setting the optimizer to optimize only on unet params - optimizer.param_groups[1]["lr"] = 0.0 - - with accelerator.accumulate(transformer): + models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder_one]) + with accelerator.accumulate(models_to_accumulate): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - if freeze_text_encoder: + if not args.train_text_encoder: prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length, + tokenizer_two, prompts, max_sequence_length=args.max_sequence_length + ) + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=prompts, + ) + else: + if args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=args.instance_prompt, ) - - if not freeze_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], - max_sequence_length=args.max_sequence_length, - device=accelerator.device, - prompt=prompts, - ) # Convert images to latent space - if args.cache_latents: - model_input = latents_cache[step].sample() - else: - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], @@ -2094,7 +1716,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters()) - if not freeze_text_encoder + if args.train_text_encoder else transformer.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -2103,10 +1725,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): lr_scheduler.step() optimizer.zero_grad() - # every step, we reset the embeddings to the original embeddings. - if args.train_text_encoder_ti: - embedding_handler.retract_embeddings() - # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) @@ -2148,7 +1766,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline - if freeze_text_encoder: + if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -2168,20 +1786,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, ) - if freeze_text_encoder: + if not args.train_text_encoder: del text_encoder_one, text_encoder_two torch.cuda.empty_cache() gc.collect() - else: - del text_encoder_two - torch.cuda.empty_cache() - gc.collect() # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(weight_dtype) + transformer = transformer.to(torch.float32) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: @@ -2196,10 +1810,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_lora_layers=text_encoder_lora_layers, ) - if args.train_text_encoder_ti: - embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" - embedding_handler.save_embeddings(embeddings_path) - # Final inference # Load previous pipeline pipeline = FluxPipeline.from_pretrained( @@ -2224,18 +1834,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): is_final_validation=True, ) - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - ) if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, @@ -2248,4 +1856,4 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file From 7b7a6718e81b6613401dca9e43fc8f04d3cbf102 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 17:45:57 +0300 Subject: [PATCH 18/82] add latent caching to canonical script --- .../dreambooth/train_dreambooth_lora_flux.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 7e2c60e90054..49f8fbb94b5f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -25,6 +25,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Union import numpy as np import torch @@ -599,6 +600,12 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) parser.add_argument( "--report_to", type=str, @@ -1456,6 +1463,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1578,7 +1603,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1612,11 +1636,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], @@ -1795,7 +1823,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(torch.float32) + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From 2bb4ce1b16b9390585dad61b46969950fd0d74cc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 19:50:48 +0300 Subject: [PATCH 19/82] revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160 --- .../dreambooth/test_dreambooth_lora_flux.py | 35 +---------------- .../dreambooth/train_dreambooth_lora_flux.py | 38 +++---------------- 2 files changed, 6 insertions(+), 67 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index d197c8187b87..11de0b72cc0a 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -103,39 +103,6 @@ def test_dreambooth_lora_text_encoder_flux(self): ) self.assertTrue(starts_with_expected_prefix) - def test_dreambooth_lora_latent_caching(self): - with tempfile.TemporaryDirectory() as tmpdir: - test_args = f""" - {self.script_path} - --pretrained_model_name_or_path {self.pretrained_model_name_or_path} - --instance_data_dir {self.instance_data_dir} - --instance_prompt {self.instance_prompt} - --resolution 64 - --train_batch_size 1 - --gradient_accumulation_steps 1 - --max_train_steps 2 - --cache_latents - --learning_rate 5.0e-04 - --scale_lr - --lr_scheduler constant - --lr_warmup_steps 0 - --output_dir {tmpdir} - """.split() - - run_command(self._launch_args + test_args) - # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - # make sure the state_dict has the correct naming in the parameters. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) - self.assertTrue(is_lora) - - # when not training the text encoder, all the parameters in the state dict should start - # with `"transformer"` in their names. - starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) - self.assertTrue(starts_with_transformer) - def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -195,4 +162,4 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult run_command(self._launch_args + resume_run_args) - self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 49f8fbb94b5f..7e2c60e90054 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -25,7 +25,6 @@ import warnings from contextlib import nullcontext from pathlib import Path -from typing import Union import numpy as np import torch @@ -600,12 +599,6 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument( - "--cache_latents", - action="store_true", - default=False, - help="Cache the VAE latents", - ) parser.add_argument( "--report_to", type=str, @@ -1463,24 +1456,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - vae_config_shift_factor = vae.config.shift_factor - vae_config_scaling_factor = vae.config.scaling_factor - vae_config_block_out_channels = vae.config.block_out_channels - if args.cache_latents: - latents_cache = [] - for batch in tqdm(train_dataloader, desc="Caching latents"): - with torch.no_grad(): - batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=weight_dtype - ) - latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) - - if args.validation_prompt is None: - del vae - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1603,6 +1578,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1636,15 +1612,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Convert images to latent space - if args.cache_latents: - model_input = latents_cache[step].sample() - else: - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], @@ -1823,7 +1795,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(weight_dtype) + transformer = transformer.to(torch.float32) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From dc9be5b976cebabf9fa2010a304bd821dce6d46b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 19:51:39 +0300 Subject: [PATCH 20/82] revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160 --- examples/dreambooth/test_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index 11de0b72cc0a..b77f84447aaa 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -162,4 +162,4 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult run_command(self._launch_args + resume_run_args) - self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) \ No newline at end of file + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 7e2c60e90054..48d669418fd8 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1856,4 +1856,4 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) From a25cb907b53cc02a73af396eb5a738936ac37845 Mon Sep 17 00:00:00 2001 From: Linoy Date: Fri, 13 Sep 2024 16:55:09 +0000 Subject: [PATCH 21/82] style --- .../train_dreambooth_lora_flux_advanced.py | 49 +++++++++++-------- src/diffusers/pipelines/flux/pipeline_flux.py | 7 ++- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index c102d1a13613..50ab8b62fbd7 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -26,7 +26,7 @@ import warnings from contextlib import nullcontext from pathlib import Path -from typing import Union, List +from typing import List, Union import numpy as np import torch @@ -41,7 +41,7 @@ from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file from torch.utils.data import Dataset from torchvision import transforms from torchvision.transforms.functional import crop @@ -106,9 +106,9 @@ def save_model_card( embeddings_filename = f"{repo_folder}_emb" ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) trigger_str = ( - "To trigger image generation of trained concept(or concepts) replace each concept identifier " - "in you prompt with the new inserted tokens:\n" - ) + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download from safetensors.torch import load_file """ @@ -348,8 +348,8 @@ def parse_args(input_args=None): type=str, default="TOK", help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " - "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " - "'TOK,TOK2,TOK3' etc.", + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " + "'TOK,TOK2,TOK3' etc.", ) parser.add_argument( @@ -357,8 +357,8 @@ def parse_args(input_args=None): type=int, default=2, help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " - "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " - "tokens - ", + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", ) parser.add_argument( "--class_prompt", @@ -741,6 +741,7 @@ def parse_args(input_args=None): return args + # Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py class TokenEmbeddingsHandler: def __init__(self, text_encoders, tokenizers): @@ -1000,7 +1001,7 @@ def __getitem__(self, index): else: example["instance_prompt"] = self.instance_prompt - else: # the given instance prompt is used for all images + else: # the given instance prompt is used for all images example["instance_prompt"] = self.instance_prompt if self.class_data_root: @@ -1164,7 +1165,7 @@ def encode_prompt( prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None + text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None, ) prompt_embeds = _encode_prompt_with_t5( @@ -1174,7 +1175,7 @@ def encode_prompt( prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, - text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None + text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1453,9 +1454,7 @@ def main(args): args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) # initialize the new tokens for textual inversion - embedding_handler = TokenEmbeddingsHandler( - [text_encoder_one], [tokenizer_one] - ) + embedding_handler = TokenEmbeddingsHandler([text_encoder_one], [tokenizer_one]) inserting_toks = [] for new_tok in token_abstraction_dict.values(): inserting_toks.extend(new_tok) @@ -1511,7 +1510,7 @@ def main(args): # if we use textual inversion, we freeze all parameters except for the token embeddings # in text encoder elif args.train_text_encoder_ti: - text_lora_parameters_one = [] # for now only for CLIP + text_lora_parameters_one = [] # for now only for CLIP for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 @@ -1536,7 +1535,7 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(unwrap_model(text_encoder_one))): - if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers + if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1790,12 +1789,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_one = tokenize_prompt( + tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens + ) tokens_two = tokenize_prompt( tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length ) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens) + class_tokens_one = tokenize_prompt( + tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens + ) class_tokens_two = tokenize_prompt( tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length ) @@ -1970,9 +1973,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens) + tokens_one = tokenize_prompt( + tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens + ) tokens_two = tokenize_prompt( - tokenizer_two, prompts, max_sequence_length=args.max_sequence_length, + tokenizer_two, + prompts, + max_sequence_length=args.max_sequence_length, ) if not freeze_text_encoder: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 9a6050c763a9..02e378a99a41 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -137,7 +137,12 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin,): +class FluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): r""" The Flux pipeline for text-to-image generation. From fd75b67813d3989332256eaf712eee74711e72fd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 13 Sep 2024 23:18:37 +0300 Subject: [PATCH 22/82] remove redundant line and change code block placement to align with logic --- .../train_dreambooth_lora_flux_advanced.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 50ab8b62fbd7..4b9b828f9425 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1507,18 +1507,7 @@ def main(args): target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) - # if we use textual inversion, we freeze all parameters except for the token embeddings - # in text encoder - elif args.train_text_encoder_ti: - text_lora_parameters_one = [] # for now only for CLIP - for name, param in text_encoder_one.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param.data = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_one.append(param) - else: - param.requires_grad = False + def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1619,6 +1608,18 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] # for now only for CLIP + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) @@ -1693,10 +1694,9 @@ def load_model_hook(models, input_dir): f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " f"When using prodigy only learning_rate is used as the initial learning rate." ) - # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # changes the learning rate of text_encoder_parameters_one to be # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate - params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1960,7 +1960,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): if pivoted: # stopping optimization of text_encoder params - # re setting the optimizer to optimize only on unet params + # re setting the optimizer to optimize only on transformer params optimizer.param_groups[1]["lr"] = 0.0 with accelerator.accumulate(transformer): From 238ed7049a96418b1ffdef650977310d0bafab36 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 16 Sep 2024 09:58:06 +0300 Subject: [PATCH 23/82] add initializer_token arg --- .../train_dreambooth_lora_flux_advanced.py | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 4b9b828f9425..02eb1b3fe120 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -360,6 +360,16 @@ def parse_args(input_args=None): "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " "tokens - ", ) + parser.add_argument( + "--initializer_token", + type=str, + default="random", + help="the token (or tokens) to use to initialize the new inserted tokens when training with " + "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " + "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens" + "to do so, please specify the initializer tokens in a comma seperated string - e.g. 'random,dog,illustration'." + " such that the order of the initializers matches the order of the identifiers specified in --token_abstraction." + ) parser.add_argument( "--class_prompt", type=str, @@ -763,8 +773,10 @@ def initialize_new_tokens(self, inserting_toks: List[str]): self.inserting_toks = inserting_toks special_tokens_dict = {"additional_special_tokens": self.inserting_toks} tokenizer.add_special_tokens(special_tokens_dict) + # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) + # Convert the token abstractions to ids self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) # random initialization of new tokens @@ -772,21 +784,33 @@ def initialize_new_tokens(self, inserting_toks: List[str]): print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") - text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( - torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) - .to(device=self.device) - .to(dtype=self.dtype) - * std_token_embedding - ) + if args.initializer_token.lower == "random": + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + else: + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + initializer_token_id = token_ids[0] + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + text_encoder.text_model.embeddings.token_embedding.weight.data)[initializer_token_id].clone() + self.embeddings_settings[ f"original_embeddings_{idx}" ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding - inu = torch.ones((len(tokenizer),), dtype=torch.bool) - inu[self.train_ids] = False + # makes sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[self.train_ids] = False - self.embeddings_settings[f"index_no_updates_{idx}"] = inu + self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) From 4bf3a138cfc7ebce1714f37d97dfd47199fc6990 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 16 Sep 2024 11:21:34 +0300 Subject: [PATCH 24/82] add transformer frac for range support from pure textual inversion to the orig pivotal tuning --- .../train_dreambooth_lora_flux_advanced.py | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 02eb1b3fe120..436d96089dc1 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -479,6 +479,13 @@ def parse_args(input_args=None): default=1.0, help=("The percentage of epochs to perform text encoder tuning"), ) + parser.add_argument( + "--train_transformer_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform transformer tuning"), + ) + parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -733,6 +740,10 @@ def parse_args(input_args=None): "For full LoRA text encoder training check --train_text_encoder, for textual " "inversion training check `--train_text_encoder_ti`" ) + if args.train_transformer_frac == 0 and not (args.train_text_encoder or args.train_text_encoder_ti): + raise ValueError( + "--train_transformer_frac must be > 0 if text_encoder training / textual inversion is not enabled" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -1959,21 +1970,28 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) elif args.train_text_encoder_ti: # args.train_text_encoder_ti num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) - # flag used for textual inversion - pivoted = False + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) + # flag used for textual inversion + pivoted_te = False + pivoted_tr = False for epoch in range(first_epoch, args.num_train_epochs): - transformer.train() + if epoch == num_train_epochs_transformer: + print("PIVOT TRANSFORMER", epoch) + # stopping optimization of transformer params + pivoted_tr = True + else: + transformer.train() # if performing any kind of optimization of text_encoder params if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: - print("PIVOT HALFWAY", epoch) + print("PIVOT TE", epoch) # stopping optimization of text_encoder params - # this flag is used to reset the optimizer to optimize only on unet params - pivoted = True - + # this flag is used to reset the optimizer to optimize only on transformer params + pivoted_te = True else: # still optimizing the text encoder text_encoder_one.train() @@ -1982,10 +2000,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): - if pivoted: + if pivoted_te: # stopping optimization of text_encoder params # re setting the optimizer to optimize only on transformer params optimizer.param_groups[1]["lr"] = 0.0 + elif pivoted_tr: + optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): prompts = batch["prompts"] From 62b8ab8a5e1740977d9fce8e7d7ab8ef54d4c960 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 17 Sep 2024 11:47:48 +0300 Subject: [PATCH 25/82] support pure textual inversion - wip --- .../train_dreambooth_lora_flux_advanced.py | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 436d96089dc1..e05b9821f2ba 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -86,6 +86,7 @@ def save_model_card( base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, + pure_textual_inversion=False, token_abstraction_dict=None, instance_prompt=None, validation_prompt=None, @@ -100,8 +101,11 @@ def save_model_card( widget_dict.append( {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} ) + diffusers_load_lora = "" diffusers_imports_pivotal = "" diffusers_example_pivotal = "" + if not pure_textual_inversion: + diffusers_load_lora = f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')""" if train_text_encoder_ti: embeddings_filename = f"{repo_folder}_emb" ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) @@ -153,7 +157,7 @@ def save_model_card( import torch {diffusers_imports_pivotal} pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') -pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +{diffusers_load_lora} {diffusers_example_pivotal} image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` @@ -1659,6 +1663,10 @@ def load_model_hook(models, input_dir): # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + # if --train_text_encoder_ti and train_transformer_frac == 0 where essntially performing textual inversion + # and not training transformer LoRA layers + freeze_transformer = args.train_text_encoder_ti and int(args.train_transformer_frac) == 0 + # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} if not freeze_text_encoder: @@ -1670,12 +1678,19 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } + if not freeze_transformer: + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] + else: + params_to_optimize = [ + text_parameters_one_with_lr + ] + else: params_to_optimize = [ transformer_parameters_with_lr, - text_parameters_one_with_lr, ] - else: - params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): @@ -2241,11 +2256,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: text_encoder_lora_layers = None - FluxPipeline.save_lora_weights( - save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, - ) + if not freeze_transformer: + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) if args.train_text_encoder_ti: embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" @@ -2259,8 +2275,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): variant=args.variant, torch_dtype=weight_dtype, ) - # load attention processors - pipeline.load_lora_weights(args.output_dir) + if not freeze_transformer: + # load attention processors + pipeline.load_lora_weights(args.output_dir) + if args.train_text_encoder_ti: + # load embeddings + pass + # run inference images = [] @@ -2281,6 +2302,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, train_text_encoder_ti=args.train_text_encoder_ti, + pure_textual_inversion=freeze_transformer, token_abstraction_dict=train_dataset.token_abstraction_dict, instance_prompt=args.instance_prompt, validation_prompt=args.validation_prompt, From 30cc65171a7be82a7df18b19091ff269391391a2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 17 Sep 2024 13:46:14 +0300 Subject: [PATCH 26/82] adjustments to support pure textual inversion and transformer optimization in only part of the epochs --- .../train_dreambooth_lora_flux_advanced.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e05b9821f2ba..2de41221c4a0 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -744,10 +744,13 @@ def parse_args(input_args=None): "For full LoRA text encoder training check --train_text_encoder, for textual " "inversion training check `--train_text_encoder_ti`" ) - if args.train_transformer_frac == 0 and not (args.train_text_encoder or args.train_text_encoder_ti): + if args.train_transformer_frac < 1 and not args.train_text_encoder_ti: raise ValueError( - "--train_transformer_frac must be > 0 if text_encoder training / textual inversion is not enabled" + "--train_transformer_frac must be > 0 if text_encoder training / textual inversion is not enabled." ) + if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: + raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. This contradicts with --max_train_steps, please specify different values or set both to 1.") + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -1665,7 +1668,7 @@ def load_model_hook(models, input_dir): # if --train_text_encoder_ti and train_transformer_frac == 0 where essntially performing textual inversion # and not training transformer LoRA layers - freeze_transformer = args.train_text_encoder_ti and int(args.train_transformer_frac) == 0 + pure_textual_inversion = args.train_text_encoder_ti and int(args.train_transformer_frac) == 0 # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} @@ -1678,15 +1681,17 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - if not freeze_transformer: + if not pure_textual_inversion: params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, ] + te_idx = 1 else: params_to_optimize = [ text_parameters_one_with_lr ] + te_idx = 0 else: params_to_optimize = [ transformer_parameters_with_lr, @@ -1994,19 +1999,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pivoted_te = False pivoted_tr = False for epoch in range(first_epoch, args.num_train_epochs): - if epoch == num_train_epochs_transformer: - print("PIVOT TRANSFORMER", epoch) - # stopping optimization of transformer params - pivoted_tr = True - else: - transformer.train() + transformer.train() # if performing any kind of optimization of text_encoder params if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: + # flag to stop text encoder optimization print("PIVOT TE", epoch) - # stopping optimization of text_encoder params - # this flag is used to reset the optimizer to optimize only on transformer params pivoted_te = True + if epoch == num_train_epochs_transformer: + # flag to stop transformer optimization + print("PIVOT TRANSFORMER", epoch) + pivoted_tr = True else: # still optimizing the text encoder text_encoder_one.train() @@ -2017,9 +2020,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): if pivoted_te: # stopping optimization of text_encoder params - # re setting the optimizer to optimize only on transformer params - optimizer.param_groups[1]["lr"] = 0.0 - elif pivoted_tr: + optimizer.param_groups[te_idx]["lr"] = 0.0 + elif pivoted_tr and not pure_textual_inversion: optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): @@ -2256,7 +2258,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: text_encoder_lora_layers = None - if not freeze_transformer: + if not pure_textual_inversion: FluxPipeline.save_lora_weights( save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, @@ -2275,7 +2277,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): variant=args.variant, torch_dtype=weight_dtype, ) - if not freeze_transformer: + if not pure_textual_inversion: # load attention processors pipeline.load_lora_weights(args.output_dir) if args.train_text_encoder_ti: @@ -2302,7 +2304,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, train_text_encoder_ti=args.train_text_encoder_ti, - pure_textual_inversion=freeze_transformer, + pure_textual_inversion=pure_textual_inversion, token_abstraction_dict=train_dataset.token_abstraction_dict, instance_prompt=args.instance_prompt, validation_prompt=args.validation_prompt, From 35ac0f739ae44aaa9e4f3f013402421daec566e3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 17 Sep 2024 14:29:21 +0300 Subject: [PATCH 27/82] fix logic when using initializer token --- .../train_dreambooth_lora_flux_advanced.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 2de41221c4a0..ebf6d6756cd7 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -371,8 +371,6 @@ def parse_args(input_args=None): help="the token (or tokens) to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens" - "to do so, please specify the initializer tokens in a comma seperated string - e.g. 'random,dog,illustration'." - " such that the order of the initializers matches the order of the identifiers specified in --token_abstraction." ) parser.add_argument( "--class_prompt", @@ -749,7 +747,8 @@ def parse_args(input_args=None): "--train_transformer_frac must be > 0 if text_encoder training / textual inversion is not enabled." ) if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: - raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. This contradicts with --max_train_steps, please specify different values or set both to 1.") + raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " + "This contradicts with --max_train_steps, please specify different values or set both to 1.") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -816,8 +815,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]): if len(token_ids) > 1: raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] - text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( - text_encoder.text_model.embeddings.token_embedding.weight.data)[initializer_token_id].clone() + for token_id in self.train_ids: + text_encoder.text_model.embeddings.token_embedding.weight.data[token_id] = ( + text_encoder.text_model.embeddings.token_embedding.weight.data)[initializer_token_id].clone() self.embeddings_settings[ f"original_embeddings_{idx}" From 7e91489128a2bd6c432d84e6abed11589e7df149 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 17 Sep 2024 14:36:59 +0300 Subject: [PATCH 28/82] fix pure_textual_inversion_condition --- .../train_dreambooth_lora_flux_advanced.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ebf6d6756cd7..a57dd26d76a8 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1668,7 +1668,7 @@ def load_model_hook(models, input_dir): # if --train_text_encoder_ti and train_transformer_frac == 0 where essntially performing textual inversion # and not training transformer LoRA layers - pure_textual_inversion = args.train_text_encoder_ti and int(args.train_transformer_frac) == 0 + pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} @@ -2022,6 +2022,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: + print("PIVOT TRANSFORMER HELOOOO") optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): From 8775bea3a386a2036d7849109794afd1775fbdfe Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 17 Sep 2024 14:53:02 +0300 Subject: [PATCH 29/82] fix ti/pivotal loading of last validation run --- .../train_dreambooth_lora_flux_advanced.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index a57dd26d76a8..eef25a9cb371 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2283,8 +2283,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline.load_lora_weights(args.output_dir) if args.train_text_encoder_ti: # load embeddings - pass - + tokens = list(itertools.chain.from_iterable(train_dataset.token_abstraction_dict.values())) + embedding_path = hf_hub_download(repo_id=repo_id, + filename=f"{args.output_dir}_emb.safetensors", + repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=tokens, + text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) # run inference images = [] From 00bbb58e81f3ab39c8a77c09ed83a13067c4914d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 18 Sep 2024 10:18:55 +0300 Subject: [PATCH 30/82] remove embeddings loading for ti in final training run (to avoid adding huggingface hub dependency) --- .../train_dreambooth_lora_flux_advanced.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index eef25a9cb371..aa2282bd34a3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2161,11 +2161,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters()) - if not freeze_text_encoder - else transformer.parameters() - ) + if not freeze_text_encoder: + if pure_textual_inversion: + params_to_clip = (text_encoder_one.parameters()) + else: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder_one.parameters())) + else: + params_to_clip = (transformer.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -2281,15 +2284,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if not pure_textual_inversion: # load attention processors pipeline.load_lora_weights(args.output_dir) - if args.train_text_encoder_ti: - # load embeddings - tokens = list(itertools.chain.from_iterable(train_dataset.token_abstraction_dict.values())) - embedding_path = hf_hub_download(repo_id=repo_id, - filename=f"{args.output_dir}_emb.safetensors", - repo_type="model") - state_dict = load_file(embedding_path) - pipeline.load_textual_inversion(state_dict["clip_l"], token=tokens, - text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) # run inference images = [] From 67e1bf78a974a0e59c51abdac556407d7eb270e7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 18 Sep 2024 12:08:43 +0300 Subject: [PATCH 31/82] support pivotal for t5 --- .../train_dreambooth_lora_flux_advanced.py | 73 ++++++++++++++----- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index aa2282bd34a3..be1c2a94ca0d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -119,6 +119,7 @@ def save_model_card( diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") state_dict = load_file(embedding_path) pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) """ if token_abstraction_dict: for key, value in token_abstraction_dict.items(): @@ -837,8 +838,8 @@ def initialize_new_tokens(self, inserting_toks: List[str]): def save_embeddings(self, file_path: str): assert self.train_ids is not None, "Initialize new tokens before saving embeddings." tensors = {} - # text_encoder_0 - CLIP ViT-L/14, for now only optimizing and saving embeddings for CLIP (text_encoder_two - T5, remains untouched) - idx_to_text_encoder_name = {0: "clip_l"} + # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl + idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} for idx, text_encoder in enumerate(self.text_encoders): assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( self.tokenizers[0] @@ -1496,7 +1497,7 @@ def main(args): args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) # initialize the new tokens for textual inversion - embedding_handler = TokenEmbeddingsHandler([text_encoder_one], [tokenizer_one]) + embedding_handler = TokenEmbeddingsHandler([text_encoder_one,text_encoder_two], [tokenizer_one, tokenizer_two]) inserting_toks = [] for new_tok in token_abstraction_dict.values(): inserting_toks.extend(new_tok) @@ -1662,6 +1663,15 @@ def load_model_hook(models, input_dir): text_lora_parameters_one.append(param) else: param.requires_grad = False + text_lora_parameters_two = [] # for now only for CLIP + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) @@ -1681,17 +1691,33 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - if not pure_textual_inversion: + if args.train_text_encoder: params_to_optimize = [ - transformer_parameters_with_lr, - text_parameters_one_with_lr, - ] + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] te_idx = 1 - else: - params_to_optimize = [ - text_parameters_one_with_lr - ] - te_idx = 0 + elif args.train_text_encoder_ti: + text_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + if pure_textual_inversion: + params_to_optimize = [ + text_parameters_one_with_lr, + text_parameters_two_with_lr + ] + te_idx = 0 + else: + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + text_parameters_two_with_lr + ] + te_idx = 1 else: params_to_optimize = [ transformer_parameters_with_lr, @@ -1899,12 +1925,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ( transformer, text_encoder_one, + text_encoder_two, optimizer, train_dataloader, lr_scheduler, ) = accelerator.prepare( transformer, text_encoder_one, + text_encoder_two, optimizer, train_dataloader, lr_scheduler, @@ -2012,15 +2040,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pivoted_tr = True else: # still optimizing the text encoder - text_encoder_one.train() - # set top parameter requires_grad = True for gradient checkpointing works if args.train_text_encoder: + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + else: # textual inversion / pivotal tuning + text_encoder_one.train() + text_encoder_two.train() + for step, batch in enumerate(train_dataloader): if pivoted_te: # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 + if args.train_text_encoder_ti: + optimizer.param_groups[te_idx+1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: print("PIVOT TRANSFORMER HELOOOO") optimizer.param_groups[0]["lr"] = 0.0 @@ -2162,11 +2196,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: if not freeze_text_encoder: - if pure_textual_inversion: - params_to_clip = (text_encoder_one.parameters()) - else: + if args.train_text_encoder: params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters())) + elif pure_textual_inversion: + params_to_clip = (text_encoder_one.parameters(), text_encoder_two.parameters()) + else: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())) else: params_to_clip = (transformer.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -2244,7 +2281,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): del text_encoder_one, text_encoder_two torch.cuda.empty_cache() gc.collect() - else: + elif args.train_text_encoder: del text_encoder_two torch.cuda.empty_cache() gc.collect() From e00b30fd7735d2c238978eed48f9681993ba1abd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 18 Sep 2024 14:53:27 +0300 Subject: [PATCH 32/82] adapt pivotal for T5 encoder --- .../train_dreambooth_lora_flux_advanced.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index be1c2a94ca0d..b59a3bef0aba 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -798,13 +798,15 @@ def initialize_new_tokens(self, inserting_toks: List[str]): self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) # random initialization of new tokens - std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() + embeds = text_encoder.text_model.embeddings.token_embedding if idx==0 else text_encoder.encoder.embed_tokens + std_token_embedding =embeds.weight.data.std() print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") if args.initializer_token.lower == "random": - text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( - torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + hidden_size = text_encoder.text_model.config.hidden_size if idx ==0 else text_encoder.encoder.config.hidden_size + embeds.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), hidden_size) .to(device=self.device) .to(dtype=self.dtype) * std_token_embedding @@ -817,12 +819,12 @@ def initialize_new_tokens(self, inserting_toks: List[str]): raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] for token_id in self.train_ids: - text_encoder.text_model.embeddings.token_embedding.weight.data[token_id] = ( - text_encoder.text_model.embeddings.token_embedding.weight.data)[initializer_token_id].clone() + embeds.weight.data[token_id] = ( + embeds.weight.data)[initializer_token_id].clone() self.embeddings_settings[ f"original_embeddings_{idx}" - ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + ] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding # makes sure we don't update any embedding weights besides the newly added token @@ -841,10 +843,11 @@ def save_embeddings(self, file_path: str): # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} for idx, text_encoder in enumerate(self.text_encoders): - assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( + embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + assert embeds.weight.data.shape[0] == len( self.tokenizers[0] ), "Tokenizers should be the same." - new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] + new_token_embeddings = embeds.weight.data[self.train_ids] # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), # Note: When loading with diffusers, any name can work - simply specify in inference @@ -864,8 +867,9 @@ def device(self): @torch.no_grad() def retract_embeddings(self): for idx, text_encoder in enumerate(self.text_encoders): + embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] - text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( + embeds.weight.data[index_no_updates] = ( self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] .to(device=text_encoder.device) .to(dtype=text_encoder.dtype) @@ -876,11 +880,11 @@ def retract_embeddings(self): std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] index_updates = ~index_no_updates - new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] + new_embeddings = embeds.weight.data[index_updates] off_ratio = std_token_embedding / new_embeddings.std() new_embeddings = new_embeddings * (off_ratio**0.1) - text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings + embeds.weight.data[index_updates] = new_embeddings class DreamBoothDataset(Dataset): From f4d6e9aff5698e5c0810c659f2c2836ce8dd7df7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 18 Sep 2024 15:12:07 +0300 Subject: [PATCH 33/82] adapt pivotal for T5 encoder and support in flux pipeline --- .../train_dreambooth_lora_flux_advanced.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index b59a3bef0aba..63ef9c8a377d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -845,7 +845,7 @@ def save_embeddings(self, file_path: str): for idx, text_encoder in enumerate(self.text_encoders): embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens assert embeds.weight.data.shape[0] == len( - self.tokenizers[0] + self.tokenizers[idx] ), "Tokenizers should be the same." new_token_embeddings = embeds.weight.data[self.train_ids] diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 02e378a99a41..1424965a4baa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -217,6 +217,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", From a01e566dcc18554d745a0d6186db26831800c248 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 19 Sep 2024 14:53:12 +0300 Subject: [PATCH 34/82] t5 pivotal support + support fo pivotal for clip only or both --- .../train_dreambooth_lora_flux_advanced.py | 118 +++++++++++------- 1 file changed, 76 insertions(+), 42 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 63ef9c8a377d..e40676f9e641 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -86,6 +86,7 @@ def save_model_card( base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, + enable_t5_ti = False, pure_textual_inversion=False, token_abstraction_dict=None, instance_prompt=None, @@ -116,11 +117,17 @@ def save_model_card( diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download from safetensors.torch import load_file """ - diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + if enable_t5_ti: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") state_dict = load_file(embedding_path) pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) """ + else: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + """ if token_abstraction_dict: for key, value in token_abstraction_dict.items(): tokens = "".join(value) @@ -466,7 +473,12 @@ def parse_args(input_args=None): parser.add_argument( "--train_text_encoder_ti", action="store_true", - help=("Whether to use textual inversion"), + help=("Whether to use pivotal tuning / textual inversion"), + ) + parser.add_argument( + "--enable_t5_ti", + action="store_true", + help=("Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)"), ) parser.add_argument( @@ -750,6 +762,8 @@ def parse_args(input_args=None): if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " "This contradicts with --max_train_steps, please specify different values or set both to 1.") + if args.enable_t5_ti and not args.train_text_encoder_ti: + warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -1501,7 +1515,9 @@ def main(args): args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) # initialize the new tokens for textual inversion - embedding_handler = TokenEmbeddingsHandler([text_encoder_one,text_encoder_two], [tokenizer_one, tokenizer_two]) + text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one] + tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one] + embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers) inserting_toks = [] for new_tok in token_abstraction_dict.values(): inserting_toks.extend(new_tok) @@ -1658,7 +1674,7 @@ def load_model_hook(models, input_dir): # if we use textual inversion, we freeze all parameters except for the token embeddings # in text encoder elif args.train_text_encoder_ti: - text_lora_parameters_one = [] # for now only for CLIP + text_lora_parameters_one = [] # CLIP for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 @@ -1667,15 +1683,16 @@ def load_model_hook(models, input_dir): text_lora_parameters_one.append(param) else: param.requires_grad = False - text_lora_parameters_two = [] # for now only for CLIP - for name, param in text_encoder_two.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param.data = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_two.append(param) - else: - param.requires_grad = False + if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) @@ -1695,13 +1712,21 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - if args.train_text_encoder: - params_to_optimize = [ + if not args.enable_t5_ti: + # pure textual inversion - only clip + if pure_textual_inversion: + params_to_optimize = [ + text_parameters_one_with_lr, + ] + te_idx = 0 + else: # regular te training or regular pivotal for clip + params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, ] - te_idx = 1 - elif args.train_text_encoder_ti: + te_idx = 1 + elif args.enable_t5_ti: + # pivotal tuning of clip & t5 text_parameters_two_with_lr = { "params": text_lora_parameters_two, "weight_decay": args.adam_weight_decay_text_encoder @@ -1709,13 +1734,14 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } + # pure textual inversion - only clip & t5 if pure_textual_inversion: params_to_optimize = [ text_parameters_one_with_lr, text_parameters_two_with_lr ] te_idx = 0 - else: + else: # regular pivotal tuning of clip & t5 params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, @@ -1856,7 +1882,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): torch.cuda.empty_cache() # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion - add_special_tokens = True if args.train_text_encoder_ti else False + add_special_tokens_clip = True if args.train_text_encoder_ti else False + add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1875,17 +1902,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # we need to tokenize and encode the batch prompts on all training steps else: tokens_one = tokenize_prompt( - tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens + tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip ) tokens_two = tokenize_prompt( - tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length + tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 ) if args.with_prior_preservation: class_tokens_one = tokenize_prompt( - tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens + tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip ) class_tokens_two = tokenize_prompt( - tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length + tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) @@ -1926,21 +1953,27 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Prepare everything with our `accelerator`. if not freeze_text_encoder: - ( - transformer, - text_encoder_one, - text_encoder_two, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - transformer, - text_encoder_one, - text_encoder_two, - optimizer, - train_dataloader, - lr_scheduler, - ) + if args.enable_t5_ti: + ( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler + ) + else: transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler @@ -2044,7 +2077,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pivoted_tr = True else: # still optimizing the text encoder - if args.train_text_encoder: + if args.train_text_encoder or not args.enable_t5_ti: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) @@ -2057,7 +2090,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if pivoted_te: # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 - if args.train_text_encoder_ti: + if args.train_text_encoder_ti and args.enable_t5_ti: optimizer.param_groups[te_idx+1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: print("PIVOT TRANSFORMER HELOOOO") @@ -2074,12 +2107,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) else: tokens_one = tokenize_prompt( - tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens + tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip ) tokens_two = tokenize_prompt( tokenizer_two, prompts, - max_sequence_length=args.max_sequence_length, + max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 ) if not freeze_text_encoder: @@ -2345,6 +2378,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, train_text_encoder_ti=args.train_text_encoder_ti, + enable_t5_ti = args.enable_t5_ti, pure_textual_inversion=pure_textual_inversion, token_abstraction_dict=train_dataset.token_abstraction_dict, instance_prompt=args.instance_prompt, From f74c4be6e6af6c6fcaee10d701d40962a25f0282 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 20 Sep 2024 15:21:41 +0300 Subject: [PATCH 35/82] fix param chaining --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e40676f9e641..11ede2d7645e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2237,7 +2237,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters())) elif pure_textual_inversion: - params_to_clip = (text_encoder_one.parameters(), text_encoder_two.parameters()) + params_to_clip = (itertools.chain(text_encoder_one.parameters(), text_encoder_two.parameters())) else: params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())) From c597bd86c26c89300b35ca4402c76f3843ae8116 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 20 Sep 2024 15:27:24 +0300 Subject: [PATCH 36/82] fix param chaining --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 11ede2d7645e..aee3bb45c4b5 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2242,7 +2242,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): params_to_clip = ( itertools.chain(transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())) else: - params_to_clip = (transformer.parameters()) + params_to_clip = (itertools.chain(transformer.parameters())) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() From 81fe4079acb175300b4330361483a87fbb1c390a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 21 Sep 2024 10:49:59 +0300 Subject: [PATCH 37/82] README first draft --- .../README_flux.md | 226 ++++++++++++++++++ .../train_dreambooth_lora_flux_advanced.py | 2 +- 2 files changed, 227 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index e69de29bb2d1..c14cfca44275 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -0,0 +1,226 @@ +# Advanced diffusion training examples + +## Train Dreambooth LoRA with Flux.1 Dev +> [!TIP] +> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). +> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗 + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. + +LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen* +In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: +- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114) +- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter. +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in +the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + +The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with +advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl), +[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️ + +> [!NOTE] +> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳 +> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora) + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/advanced_diffusion_training` folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. a notebook + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + +### Pivotal Tuning +**Training with text encoder(s)** + +Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization +available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported. +[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning - +we insert new tokens into the text encoders of the model, instead of reusing existing ones. +We then optimize the newly-inserted token embeddings to represent the new concept. + +To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`). +Please keep the following points in mind: + +* Flux uses two text encoders - [CLIP]() & [T5]() , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only. +To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`. +* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. +* pure textual inversion +* token initializer + +### 3D icon example + +Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./3d_icon" +snapshot_download( + "LinoyTsaban/3d_icon", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +Let's review some of the advanced features we're going to be using for this example: +- **custom captions**: +To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by +```bash +pip install datasets +``` + +Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt") + +``` +--dataset_name=./3d_icon +--caption_column=prompt +``` + +You can also load a dataset straight from by specifying it's name in `dataset_name`. +Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset. + +- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer +- **pivotal tuning** + +**Now, we can launch training:** + +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=3 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --train_text_encoder_ti_frac=0.5\ + --snr_gamma=5.0 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 40GB A100 GPU. + + +### Inference + +Once training is done, we can perform inference like so: +1. starting with loading the transformer lora weights +```python +import torch +from huggingface_hub import hf_hub_download, upload_file +from diffusers import AutoPipelineForText2Image +from safetensors.torch import load_file + +username = "linoyts" +repo_id = f"{username}/3d-icon-Flux-LoRA" + +pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') + + +pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors") +``` +2. now we load the pivotal tuning embeddings + +```python +text_encoders = [pipe.text_encoder, pipe.text_encoder_2] +tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + +embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model") + +state_dict = load_file(embedding_path) +# load embeddings of text_encoder 1 (CLIP ViT-L/14) +pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) +# load embeddings of text_encoder 2 (T5 XXL) +pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) +``` + +3. let's generate images + +```python +instance_token = "" +prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}" + +image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0] +image.save("llama.png") +``` + + + +### Comfy UI / AUTOMATIC1111 Inference +The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! + +**AUTOMATIC1111 / SD.Next** \ +In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time. +- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. +- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory. + +You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls `. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`. + +**ComfyUI** \ +In ComfyUI we will load a LoRA and a textual embedding at the same time. +- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/). +- diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index aee3bb45c4b5..62fecb1ac71a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -757,7 +757,7 @@ def parse_args(input_args=None): ) if args.train_transformer_frac < 1 and not args.train_text_encoder_ti: raise ValueError( - "--train_transformer_frac must be > 0 if text_encoder training / textual inversion is not enabled." + "--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled." ) if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " From c086a14c956ee769c2ba65bce2d6d9fcdd4310b2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 23 Sep 2024 13:59:51 +0300 Subject: [PATCH 38/82] readme --- .../advanced_diffusion_training/README_flux.md | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index c14cfca44275..fb5fdc9a7844 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -77,13 +77,13 @@ We then optimize the newly-inserted token embeddings to represent the new concep To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`). Please keep the following points in mind: -* Flux uses two text encoders - [CLIP]() & [T5]() , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only. +* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only. To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`. * When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. * pure textual inversion * token initializer -### 3D icon example +## Training examples Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon. @@ -120,6 +120,7 @@ Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-caption - **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer - **pivotal tuning** +### Example #1: Pivotal tuning **Now, we can launch training:** ```bash @@ -163,8 +164,11 @@ To better track our training experiments, we're using the following flags in the Our experiments were conducted on a single 40GB A100 GPU. +### Example #2: Pivotal tuning with T5 -### Inference +### Example #3: Textual Inversion + +### Inference - pivotal tuning Once training is done, we can perform inference like so: 1. starting with loading the transformer lora weights @@ -182,7 +186,8 @@ pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors") ``` -2. now we load the pivotal tuning embeddings +2. now we load the pivotal tuning embeddings +💡note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder ```python text_encoders = [pipe.text_encoder, pipe.text_encoder_2] @@ -193,7 +198,7 @@ embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_em state_dict = load_file(embedding_path) # load embeddings of text_encoder 1 (CLIP ViT-L/14) pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) -# load embeddings of text_encoder 2 (T5 XXL) +# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti` pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) ``` @@ -207,7 +212,7 @@ image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"sca image.save("llama.png") ``` - +### Inference - pure textual inversion ### Comfy UI / AUTOMATIC1111 Inference The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! @@ -223,4 +228,3 @@ You can then run inference by prompting `a y2k_emb webpage about the movie Mean In ComfyUI we will load a LoRA and a textual embedding at the same time. - *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/). -- From 549d3d0504da1ca41f9a8e7a9f59ef7006cc4628 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 23 Sep 2024 14:13:56 +0300 Subject: [PATCH 39/82] readme --- .../README_flux.md | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index fb5fdc9a7844..56eeb9ef5eae 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -213,6 +213,42 @@ image.save("llama.png") ``` ### Inference - pure textual inversion +In this case, we don't load transformer layers as before, since we only optimize the embeddings + +1. starting with loading the embeddings. +💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder + +```python +import torch +from huggingface_hub import hf_hub_download, upload_file +from diffusers import AutoPipelineForText2Image +from safetensors.torch import load_file + +username = "linoyts" +repo_id = f"{username}/3d-icon-Flux-LoRA" + +pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') + +text_encoders = [pipe.text_encoder, pipe.text_encoder_2] +tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + +embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model") + +state_dict = load_file(embedding_path) +# load embeddings of text_encoder 1 (CLIP ViT-L/14) +pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) +# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti` +pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) +``` +2. let's generate images + +```python +instance_token = "" +prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}" + +image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0] +image.save("llama.png") +``` ### Comfy UI / AUTOMATIC1111 Inference The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! From 983bab845f0cb6b139afba5dbb007889ddb84bed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 23 Sep 2024 14:34:34 +0300 Subject: [PATCH 40/82] readme --- .../README_flux.md | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index 56eeb9ef5eae..f2635488834e 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -137,7 +137,7 @@ accelerate launch train_dreambooth_lora_flux_advanced.py \ --caption_column="prompt" \ --mixed_precision="bf16" \ --resolution=1024 \ - --train_batch_size=3 \ + --train_batch_size=1 \ --repeats=1 \ --report_to="wandb"\ --gradient_accumulation_steps=1 \ @@ -147,7 +147,6 @@ accelerate launch train_dreambooth_lora_flux_advanced.py \ --optimizer="prodigy"\ --train_text_encoder_ti\ --train_text_encoder_ti_frac=0.5\ - --snr_gamma=5.0 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --rank=8 \ @@ -165,9 +164,45 @@ To better track our training experiments, we're using the following flags in the Our experiments were conducted on a single 40GB A100 GPU. ### Example #2: Pivotal tuning with T5 +Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize +the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration: +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` ### Example #3: Textual Inversion + ### Inference - pivotal tuning Once training is done, we can perform inference like so: @@ -250,7 +285,7 @@ image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"sca image.save("llama.png") ``` -### Comfy UI / AUTOMATIC1111 Inference +### Comfy UI / AUTOMATIC1111 Inference - **NEEDS TO BE UPDATED** The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! **AUTOMATIC1111 / SD.Next** \ From aefa48abbdb514ccd18d19a643d567eddb6f9889 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 30 Sep 2024 08:40:43 +0000 Subject: [PATCH 41/82] style --- .../train_dreambooth_lora_flux_advanced.py | 108 ++++++++++-------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 62fecb1ac71a..865a3bd909d5 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -86,7 +86,7 @@ def save_model_card( base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, - enable_t5_ti = False, + enable_t5_ti=False, pure_textual_inversion=False, token_abstraction_dict=None, instance_prompt=None, @@ -106,7 +106,9 @@ def save_model_card( diffusers_imports_pivotal = "" diffusers_example_pivotal = "" if not pure_textual_inversion: - diffusers_load_lora = f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')""" + diffusers_load_lora = ( + f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')""" + ) if train_text_encoder_ti: embeddings_filename = f"{repo_folder}_emb" ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) @@ -378,7 +380,7 @@ def parse_args(input_args=None): default="random", help="the token (or tokens) to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " - "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens" + "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens", ) parser.add_argument( "--class_prompt", @@ -478,7 +480,9 @@ def parse_args(input_args=None): parser.add_argument( "--enable_t5_ti", action="store_true", - help=("Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)"), + help=( + "Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)" + ), ) parser.add_argument( @@ -760,8 +764,10 @@ def parse_args(input_args=None): "--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled." ) if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: - raise ValueError("--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " - "This contradicts with --max_train_steps, please specify different values or set both to 1.") + raise ValueError( + "--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " + "This contradicts with --max_train_steps, please specify different values or set both to 1." + ) if args.enable_t5_ti and not args.train_text_encoder_ti: warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") @@ -812,17 +818,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]): self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) # random initialization of new tokens - embeds = text_encoder.text_model.embeddings.token_embedding if idx==0 else text_encoder.encoder.embed_tokens - std_token_embedding =embeds.weight.data.std() + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + std_token_embedding = embeds.weight.data.std() print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") if args.initializer_token.lower == "random": - hidden_size = text_encoder.text_model.config.hidden_size if idx ==0 else text_encoder.encoder.config.hidden_size + hidden_size = ( + text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size + ) embeds.weight.data[self.train_ids] = ( - torch.randn(len(self.train_ids), hidden_size) - .to(device=self.device) - .to(dtype=self.dtype) + torch.randn(len(self.train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype) * std_token_embedding ) else: @@ -833,12 +841,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]): raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] for token_id in self.train_ids: - embeds.weight.data[token_id] = ( - embeds.weight.data)[initializer_token_id].clone() + embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_id].clone() - self.embeddings_settings[ - f"original_embeddings_{idx}" - ] = embeds.weight.data.clone() + self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding # makes sure we don't update any embedding weights besides the newly added token @@ -857,10 +862,10 @@ def save_embeddings(self, file_path: str): # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} for idx, text_encoder in enumerate(self.text_encoders): - embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens - assert embeds.weight.data.shape[0] == len( - self.tokenizers[idx] - ), "Tokenizers should be the same." + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same." new_token_embeddings = embeds.weight.data[self.train_ids] # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), @@ -881,7 +886,9 @@ def device(self): @torch.no_grad() def retract_embeddings(self): for idx, text_encoder in enumerate(self.text_encoders): - embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] embeds.weight.data[index_no_updates] = ( self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] @@ -1571,7 +1578,6 @@ def main(args): ) text_encoder_one.add_adapter(text_lora_config) - def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model @@ -1683,7 +1689,7 @@ def load_model_hook(models, input_dir): text_lora_parameters_one.append(param) else: param.requires_grad = False - if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well + if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well text_lora_parameters_two = [] for name, param in text_encoder_two.named_parameters(): if "token_embedding" in name: @@ -1716,10 +1722,10 @@ def load_model_hook(models, input_dir): # pure textual inversion - only clip if pure_textual_inversion: params_to_optimize = [ - text_parameters_one_with_lr, - ] + text_parameters_one_with_lr, + ] te_idx = 0 - else: # regular te training or regular pivotal for clip + else: # regular te training or regular pivotal for clip params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, @@ -1736,16 +1742,13 @@ def load_model_hook(models, input_dir): } # pure textual inversion - only clip & t5 if pure_textual_inversion: - params_to_optimize = [ - text_parameters_one_with_lr, - text_parameters_two_with_lr - ] + params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr] te_idx = 0 - else: # regular pivotal tuning of clip & t5 + else: # regular pivotal tuning of clip & t5 params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, - text_parameters_two_with_lr + text_parameters_two_with_lr, ] te_idx = 1 else: @@ -1905,14 +1908,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip ) tokens_two = tokenize_prompt( - tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 + tokenizer_two, + args.instance_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, ) if args.with_prior_preservation: class_tokens_one = tokenize_prompt( - tokenizer_one, args.class_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip + tokenizer_one, + args.class_prompt, + max_sequence_length=77, + add_special_tokens=add_special_tokens_clip, ) class_tokens_two = tokenize_prompt( - tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 + tokenizer_two, + args.class_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, ) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) @@ -2081,17 +2093,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) - else: # textual inversion / pivotal tuning + else: # textual inversion / pivotal tuning text_encoder_one.train() text_encoder_two.train() - for step, batch in enumerate(train_dataloader): if pivoted_te: # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 if args.train_text_encoder_ti and args.enable_t5_ti: - optimizer.param_groups[te_idx+1]["lr"] = 0.0 + optimizer.param_groups[te_idx + 1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: print("PIVOT TRANSFORMER HELOOOO") optimizer.param_groups[0]["lr"] = 0.0 @@ -2112,7 +2123,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokens_two = tokenize_prompt( tokenizer_two, prompts, - max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5 + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, ) if not freeze_text_encoder: @@ -2234,15 +2246,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: if not freeze_text_encoder: if args.train_text_encoder: - params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters())) + params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters()) elif pure_textual_inversion: - params_to_clip = (itertools.chain(text_encoder_one.parameters(), text_encoder_two.parameters())) + params_to_clip = itertools.chain( + text_encoder_one.parameters(), text_encoder_two.parameters() + ) else: - params_to_clip = ( - itertools.chain(transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())) + params_to_clip = itertools.chain( + transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + ) else: - params_to_clip = (itertools.chain(transformer.parameters())) + params_to_clip = itertools.chain(transformer.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -2378,7 +2392,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, train_text_encoder_ti=args.train_text_encoder_ti, - enable_t5_ti = args.enable_t5_ti, + enable_t5_ti=args.enable_t5_ti, pure_textual_inversion=pure_textual_inversion, token_abstraction_dict=train_dataset.token_abstraction_dict, instance_prompt=args.instance_prompt, From ae10674da42e4680fab8b9af1a2229b2330eff9a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 30 Sep 2024 10:46:34 +0200 Subject: [PATCH 42/82] fix import --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 865a3bd909d5..ca0f2fc5f9a7 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -26,7 +26,7 @@ import warnings from contextlib import nullcontext from pathlib import Path -from typing import List, Union +from typing import List, Union, Optional import numpy as np import torch From dcd0e71ae699a1700c5f78582e85bedc7ab6a089 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 30 Sep 2024 08:49:20 +0000 Subject: [PATCH 43/82] style --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ca0f2fc5f9a7..f56f0c92733e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -26,7 +26,7 @@ import warnings from contextlib import nullcontext from pathlib import Path -from typing import List, Union, Optional +from typing import List, Optional, Union import numpy as np import torch From 99b75217bc00e520f455c3875a0510c95232e844 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 30 Sep 2024 10:57:34 +0200 Subject: [PATCH 44/82] add fix from https://github.com/huggingface/diffusers/pull/9419 --- .../train_dreambooth_lora_flux_advanced.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ca0f2fc5f9a7..a29354925421 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -217,13 +217,14 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -2327,6 +2328,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if freeze_text_encoder: del text_encoder_one, text_encoder_two @@ -2383,6 +2385,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, is_final_validation=True, ) From 57fb65b285bc2d4944753a7e55d85e44ae956bd3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 30 Sep 2024 16:27:57 +0200 Subject: [PATCH 45/82] add to readme, change function names --- .../README_flux.md | 51 ++++++++++++++++--- .../train_dreambooth_lora_flux_advanced.py | 10 ++-- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index f2635488834e..6e51be2264bf 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -5,7 +5,7 @@ > 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). > As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗 -[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like flux, stable diffusion given just a few(3~5) images of a subject. LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen* In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: @@ -65,10 +65,10 @@ write_basic_config() When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. -### Pivotal Tuning +### Pivotal Tuning (and more) **Training with text encoder(s)** -Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization +Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported. [pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning - we insert new tokens into the text encoders of the model, instead of reusing existing ones. @@ -80,8 +80,8 @@ Please keep the following points in mind: * Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only. To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`. * When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. -* pure textual inversion -* token initializer +* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results! +* **token initializer** - similar to the original textual inversion work, you can specify a token of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. You can specify a token in `--initializer_token` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_token`. ## Training examples @@ -201,8 +201,44 @@ accelerate launch train_dreambooth_lora_flux_advanced.py \ ``` ### Example #3: Textual Inversion +To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we +can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is +trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train +run. +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" - +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --train_transformer_frac=0\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` ### Inference - pivotal tuning Once training is done, we can perform inference like so: @@ -248,7 +284,8 @@ image.save("llama.png") ``` ### Inference - pure textual inversion -In this case, we don't load transformer layers as before, since we only optimize the embeddings +In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a +`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5) 1. starting with loading the embeddings. 💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index a67a4bde9648..9a96efd44b78 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1134,7 +1134,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F return text_input_ids -def _encode_prompt_with_t5( +def _get_t5_prompt_embeds( text_encoder, tokenizer, max_sequence_length=512, @@ -1175,7 +1175,7 @@ def _encode_prompt_with_t5( return prompt_embeds -def _encode_prompt_with_clip( +def _get_clip_prompt_embeds( text_encoder, tokenizer, prompt: str, @@ -1228,7 +1228,7 @@ def encode_prompt( batch_size = len(prompt) dtype = text_encoders[0].dtype - pooled_prompt_embeds = _encode_prompt_with_clip( + pooled_prompt_embeds = _get_clip_prompt_embeds( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, @@ -1237,7 +1237,7 @@ def encode_prompt( text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None, ) - prompt_embeds = _encode_prompt_with_t5( + prompt_embeds = _get_t5_prompt_embeds( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, @@ -1704,7 +1704,7 @@ def load_model_hook(models, input_dir): # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) - # if --train_text_encoder_ti and train_transformer_frac == 0 where essntially performing textual inversion + # if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion # and not training transformer LoRA layers pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 From 571e49c42dc9ee19868295a2ee975137567291ea Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 1 Oct 2024 15:06:05 +0200 Subject: [PATCH 46/82] te lr changes --- .../train_dreambooth_lora_flux_advanced.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 9a96efd44b78..57710832b35b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1717,7 +1717,7 @@ def load_model_hook(models, input_dir): "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + "lr": args.text_encoder_lr, } if not args.enable_t5_ti: # pure textual inversion - only clip @@ -1739,7 +1739,7 @@ def load_model_hook(models, input_dir): "weight_decay": args.adam_weight_decay_text_encoder if args.adam_weight_decay_text_encoder else args.adam_weight_decay, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + "lr": args.text_encoder_lr, } # pure textual inversion - only clip & t5 if pure_textual_inversion: @@ -1783,7 +1783,6 @@ def load_model_hook(models, input_dir): optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW - optimizer = optimizer_class( params_to_optimize, betas=(args.adam_beta1, args.adam_beta2), @@ -1803,16 +1802,17 @@ def load_model_hook(models, input_dir): logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) - if args.train_text_encoder and args.text_encoder_lr: + if not freeze_text_encoder and args.text_encoder_lr: logger.warning( f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " f"When using prodigy only learning_rate is used as the initial learning rate." ) - # changes the learning rate of text_encoder_parameters_one to be + # changes the learning rate of text_encoder_parameters to be # --learning_rate - params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[te_idx]["lr"] = args.learning_rate + params_to_optimize[-1]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, From 94dbe85b55b2373513d4727a44b042bf4d2d4bd0 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 1 Oct 2024 18:23:24 +0200 Subject: [PATCH 47/82] readme --- examples/advanced_diffusion_training/README_flux.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index 6e51be2264bf..c3dd3522e5ca 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -322,7 +322,7 @@ image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"sca image.save("llama.png") ``` -### Comfy UI / AUTOMATIC1111 Inference - **NEEDS TO BE UPDATED** +### Comfy UI / AUTOMATIC1111 Inference The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! **AUTOMATIC1111 / SD.Next** \ From 5e751d497521f53004ec0ef292742d2b5a16883b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 11:56:40 +0200 Subject: [PATCH 48/82] change concept tokens logic --- .../train_dreambooth_lora_flux_advanced.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 57710832b35b..00cbb2205cb9 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -370,18 +370,19 @@ def parse_args(input_args=None): parser.add_argument( "--num_new_tokens_per_abstraction", type=int, - default=2, + default=None, help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " "tokens - ", ) parser.add_argument( - "--initializer_token", + "--initializer_concept_tokens", type=str, - default="random", + default=None, help="the token (or tokens) to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " - "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens", + "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens. " + "--num_new_tokens_per_abstraction is ignored when initializer_concept_tokens are provided" ) parser.add_argument( "--class_prompt", @@ -772,6 +773,10 @@ def parse_args(input_args=None): if args.enable_t5_ti and not args.train_text_encoder_ti: warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") + if args.train_text_encoder_ti and args.initializer_concept_tokens and args.num_new_tokens_per_abstraction: + warnings.warn("When specifying --initializer_concept_tokens, the number of tokens per abstraction is detrimned " + "by the initializer token. --num_new_tokens_per_abstraction will be ignored") + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -826,7 +831,8 @@ def initialize_new_tokens(self, inserting_toks: List[str]): print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") - if args.initializer_token.lower == "random": + # if initializer_concept_tokens are not provided, token embeddings are initialized randomly + if args.initializer_concept_tokens is None: hidden_size = ( text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size ) @@ -836,13 +842,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]): ) else: # Convert the initializer_token, placeholder_token to ids - token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) - # Check if initializer_token is a single token or a sequence of tokens - if len(token_ids) > 1: - raise ValueError("The initializer token must be a single token.") - initializer_token_id = token_ids[0] - for token_id in self.train_ids: - embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_id].clone() + initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + for idx, token_id in enumerate(self.train_ids): + embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[idx]].clone() self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding @@ -1506,13 +1508,21 @@ def main(args): token_abstraction_list = "".join(args.token_abstraction.split()).split(",") logger.info(f"list of token identifiers: {token_abstraction_list}") + if args.initializer_concept_tokens is None: + num_new_tokens_per_abstraction = 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction + # if args.initializer_concept_tokens is provided, we ignore args.num_new_tokens_per_abstraction + else: + token_ids = tokenizer.encode(args.initializer_concept_tokens, add_special_tokens=False) + num_new_tokens_per_abstraction = len(token_ids) + print(f"initializer_concept_tokens: {args.initializer_concept_tokens}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}") + token_abstraction_dict = {} token_idx = 0 for i, token in enumerate(token_abstraction_list): token_abstraction_dict[token] = [ - f"" for j in range(args.num_new_tokens_per_abstraction) + f"" for j in range(num_new_tokens_per_abstraction) ] - token_idx += args.num_new_tokens_per_abstraction - 1 + token_idx += num_new_tokens_per_abstraction - 1 # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. for token_abs, token_replacement in token_abstraction_dict.items(): From d9ed2b18f25d240831e0a051e2f4407b7466261c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 12:24:40 +0200 Subject: [PATCH 49/82] fix indices --- .../train_dreambooth_lora_flux_advanced.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 00cbb2205cb9..e1eabfa3aa13 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -842,9 +842,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]): ) else: # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) - for idx, token_id in enumerate(self.train_ids): - embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[idx]].clone() + initializer_token_ids = tokenizer.encode(args.initializer_concept_tokens, add_special_tokens=False) + for token_idx, token_id in enumerate(self.train_ids): + embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[token_idx]].clone() self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding @@ -1512,8 +1512,11 @@ def main(args): num_new_tokens_per_abstraction = 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction # if args.initializer_concept_tokens is provided, we ignore args.num_new_tokens_per_abstraction else: - token_ids = tokenizer.encode(args.initializer_concept_tokens, add_special_tokens=False) + token_ids = tokenizer_one.encode(args.initializer_concept_tokens, add_special_tokens=False) num_new_tokens_per_abstraction = len(token_ids) + if args.enable_t5_ti: + token_ids_t5 = tokenizer_two.encode(args.initializer_concept_tokens, add_special_tokens=False) + num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5)) print(f"initializer_concept_tokens: {args.initializer_concept_tokens}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}") token_abstraction_dict = {} From b686d041ceb3b969357e6f1e7595df3763ff913b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 13:38:32 +0200 Subject: [PATCH 50/82] change arg name --- .../train_dreambooth_lora_flux_advanced.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e1eabfa3aa13..7a96b71c807c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -376,13 +376,13 @@ def parse_args(input_args=None): "tokens - ", ) parser.add_argument( - "--initializer_concept_tokens", + "--initializer_concept", type=str, default=None, - help="the token (or tokens) to use to initialize the new inserted tokens when training with " + help="the concept to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " - "Alternatively, you could specify a different token whos value will be used as the starting point for the new inserted tokens. " - "--num_new_tokens_per_abstraction is ignored when initializer_concept_tokens are provided" + "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. " + "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided" ) parser.add_argument( "--class_prompt", @@ -773,8 +773,8 @@ def parse_args(input_args=None): if args.enable_t5_ti and not args.train_text_encoder_ti: warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") - if args.train_text_encoder_ti and args.initializer_concept_tokens and args.num_new_tokens_per_abstraction: - warnings.warn("When specifying --initializer_concept_tokens, the number of tokens per abstraction is detrimned " + if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction: + warnings.warn("When specifying --initializer_concept, the number of tokens per abstraction is detrimned " "by the initializer token. --num_new_tokens_per_abstraction will be ignored") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -831,8 +831,8 @@ def initialize_new_tokens(self, inserting_toks: List[str]): print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") - # if initializer_concept_tokens are not provided, token embeddings are initialized randomly - if args.initializer_concept_tokens is None: + # if initializer_concept are not provided, token embeddings are initialized randomly + if args.initializer_concept is None: hidden_size = ( text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size ) @@ -842,7 +842,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]): ) else: # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = tokenizer.encode(args.initializer_concept_tokens, add_special_tokens=False) + initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) for token_idx, token_id in enumerate(self.train_ids): embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[token_idx]].clone() @@ -1508,16 +1508,16 @@ def main(args): token_abstraction_list = "".join(args.token_abstraction.split()).split(",") logger.info(f"list of token identifiers: {token_abstraction_list}") - if args.initializer_concept_tokens is None: + if args.initializer_concept is None: num_new_tokens_per_abstraction = 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction - # if args.initializer_concept_tokens is provided, we ignore args.num_new_tokens_per_abstraction + # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction else: - token_ids = tokenizer_one.encode(args.initializer_concept_tokens, add_special_tokens=False) + token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False) num_new_tokens_per_abstraction = len(token_ids) if args.enable_t5_ti: - token_ids_t5 = tokenizer_two.encode(args.initializer_concept_tokens, add_special_tokens=False) + token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False) num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5)) - print(f"initializer_concept_tokens: {args.initializer_concept_tokens}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}") + print(f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}") token_abstraction_dict = {} token_idx = 0 From 3881dbb84ca3655b06b4bc73c1aa2ea09fc45d60 Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 2 Oct 2024 11:39:35 +0000 Subject: [PATCH 51/82] style --- .../train_dreambooth_lora_flux_advanced.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 7a96b71c807c..fa703eff1e5b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -382,7 +382,7 @@ def parse_args(input_args=None): help="the concept to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. " - "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided" + "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided", ) parser.add_argument( "--class_prompt", @@ -774,8 +774,10 @@ def parse_args(input_args=None): warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction: - warnings.warn("When specifying --initializer_concept, the number of tokens per abstraction is detrimned " - "by the initializer token. --num_new_tokens_per_abstraction will be ignored") + warnings.warn( + "When specifying --initializer_concept, the number of tokens per abstraction is detrimned " + "by the initializer token. --num_new_tokens_per_abstraction will be ignored" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -1509,7 +1511,9 @@ def main(args): logger.info(f"list of token identifiers: {token_abstraction_list}") if args.initializer_concept is None: - num_new_tokens_per_abstraction = 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction + num_new_tokens_per_abstraction = ( + 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction + ) # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction else: token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False) @@ -1517,14 +1521,14 @@ def main(args): if args.enable_t5_ti: token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False) num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5)) - print(f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}") + print( + f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}" + ) token_abstraction_dict = {} token_idx = 0 for i, token in enumerate(token_abstraction_list): - token_abstraction_dict[token] = [ - f"" for j in range(num_new_tokens_per_abstraction) - ] + token_abstraction_dict[token] = [f"" for j in range(num_new_tokens_per_abstraction)] token_idx += num_new_tokens_per_abstraction - 1 # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. From b3e5caad058a191f9c247241a2b1b402a041f724 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 17:52:08 +0200 Subject: [PATCH 52/82] dummy test --- src/diffusers/pipelines/flux/pipeline_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 1424965a4baa..e380ba13ef48 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -217,8 +217,8 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + # if isinstance(self, TextualInversionLoaderMixin): + # prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) text_inputs = self.tokenizer_2( prompt, From ca668b66d55433fe31b787343e6632ecff5266c3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 18:03:40 +0200 Subject: [PATCH 53/82] revert dummy test --- src/diffusers/pipelines/flux/pipeline_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e380ba13ef48..1424965a4baa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -217,8 +217,8 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - # if isinstance(self, TextualInversionLoaderMixin): - # prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) text_inputs = self.tokenizer_2( prompt, From 7fcdc0d70ccfcd39c18a64a4e847531f41c5a1a9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 2 Oct 2024 19:13:49 +0200 Subject: [PATCH 54/82] reorder pivoting --- .../train_dreambooth_lora_flux_advanced.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index fa703eff1e5b..15a5838a1a11 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1701,6 +1701,7 @@ def load_model_hook(models, input_dir): text_lora_parameters_one = [] # CLIP for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: + print("YES 5") # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 param.data = param.to(dtype=torch.float32) param.requires_grad = True @@ -1708,6 +1709,7 @@ def load_model_hook(models, input_dir): else: param.requires_grad = False if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well + print("NO") text_lora_parameters_two = [] for name, param in text_encoder_two.named_parameters(): if "token_embedding" in name: @@ -1724,6 +1726,7 @@ def load_model_hook(models, input_dir): # if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion # and not training transformer LoRA layers pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 + print("NO 2:", pure_textual_inversion) # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} @@ -1744,6 +1747,7 @@ def load_model_hook(models, input_dir): ] te_idx = 0 else: # regular te training or regular pivotal for clip + print("YES1") params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, @@ -2101,28 +2105,29 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # flag to stop text encoder optimization print("PIVOT TE", epoch) pivoted_te = True - if epoch == num_train_epochs_transformer: - # flag to stop transformer optimization - print("PIVOT TRANSFORMER", epoch) - pivoted_tr = True else: # still optimizing the text encoder - if args.train_text_encoder or not args.enable_t5_ti: + if args.train_text_encoder: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) - else: # textual inversion / pivotal tuning + elif args.train_text_encoder_ti: # textual inversion / pivotal tuning text_encoder_one.train() + if args.enable_t5_ti: text_encoder_two.train() + if epoch == num_train_epochs_transformer: + # flag to stop transformer optimization + print("PIVOT TRANSFORMER", epoch) + pivoted_tr = True + for step, batch in enumerate(train_dataloader): if pivoted_te: # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 - if args.train_text_encoder_ti and args.enable_t5_ti: - optimizer.param_groups[te_idx + 1]["lr"] = 0.0 + optimizer.param_groups[-1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: - print("PIVOT TRANSFORMER HELOOOO") + print("PIVOT TRANSFORMER") optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): From 8cb79713481a74fbd9e8f921adf93fd1be69336b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 4 Oct 2024 15:08:35 +0200 Subject: [PATCH 55/82] add warning in case the token abstraction is not the instance prompt --- .../train_dreambooth_lora_flux_advanced.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 15a5838a1a11..ab633ff80a8c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1533,7 +1533,11 @@ def main(args): # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. for token_abs, token_replacement in token_abstraction_dict.items(): - args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.instance_prompt == new_instance_prompt: + logger.warning("Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " + "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning") + args.instance_prompt = new_instance_prompt if args.with_prior_preservation: args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) if args.validation_prompt: From b75b3e61845e7eb62db8bb67edfe3e0f9837b021 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 4 Oct 2024 18:26:29 +0200 Subject: [PATCH 56/82] experimental - wip - specific block training --- .../train_dreambooth_lora_flux_advanced.py | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ab633ff80a8c..e785c6f406f5 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -569,7 +569,15 @@ def parse_args(input_args=None): default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) - + parser.add_argument( + "--lora_transformer_blocks", + type=str, + default=None, + help=( + "the transformer blocks to tune during training. please specify them in a comma separated string, e.g. `transformer.single_transformer_blocks.7.proj_out,transformer.single_transformer_blocks.20.proj_out` etc." + "NOTE: By default (if not specified) - regular LoRA training is performed. " + ), + ) parser.add_argument( "--text_encoder_lr", type=float, @@ -1582,12 +1590,49 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() + # Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py + def is_belong_to_blocks(key, blocks): + try: + for g in blocks: + if g in key: + return True + return False + except Exception as e: + raise type(e)(f"failed to is_belong_to_block, due to: {e}") + + def get_transformer_lora_target_modules(transformer, target_blocks=None): + try: + blocks = [(".").join(blk.split(".")[1:]) for blk in target_blocks] + + attns = [ + attn_processor_name.rsplit(".", 1)[0] + for attn_processor_name, _ in transformer.attn_processors.items() + if is_belong_to_blocks(attn_processor_name, blocks) + ] + + target_modules = [f"{attn}.{mat}" for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns] + return target_modules + except Exception as e: + raise type(e)( + f"failed to get_target_modules, due to: {e}. " + f"Please check the modules specified in --lora_transformer_blocks are correct" + ) + + + if args.lora_transformer_blocks: + # if training specific transformer blocks + target_blocks_list = "".join(args.lora_transformer_blocks.split()).split(",") + logger.info(f"list of unet blocks to train: {target_blocks_list}") + target_modules = get_transformer_lora_target_modules(unet, target_blocks=target_blocks_list) + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) From 03a6b5b550cdb0151c9c70de4987d5ae8fc410ed Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 8 Oct 2024 17:56:37 +0300 Subject: [PATCH 57/82] fix documentation and token abstraction processing --- .../train_dreambooth_lora_flux_advanced.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e785c6f406f5..c630aa8f1461 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -663,7 +663,7 @@ def parse_args(input_args=None): "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params") parser.add_argument( "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) @@ -1515,7 +1515,7 @@ def main(args): if args.train_text_encoder_ti: # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, # TOK2" -> ["TOK", "TOK2"] etc. - token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + token_abstraction_list = [place_holder.strip() for place_holder in re.split(r',\s*', args.token_abstraction)] logger.info(f"list of token identifiers: {token_abstraction_list}") if args.initializer_concept is None: @@ -1622,8 +1622,8 @@ def get_transformer_lora_target_modules(transformer, target_blocks=None): if args.lora_transformer_blocks: # if training specific transformer blocks target_blocks_list = "".join(args.lora_transformer_blocks.split()).split(",") - logger.info(f"list of unet blocks to train: {target_blocks_list}") - target_modules = get_transformer_lora_target_modules(unet, target_blocks=target_blocks_list) + logger.info(f"list of transformer blocks to train: {target_blocks_list}") + target_modules = get_transformer_lora_target_modules(transformer, target_blocks=target_blocks_list) else: target_modules = ["to_k", "to_q", "to_v", "to_out.0"] @@ -1780,7 +1780,7 @@ def load_model_hook(models, input_dir): # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} if not freeze_text_encoder: - # different learning rate for text encoder and unet + # different learning rate for text encoder and transformer text_parameters_one_with_lr = { "params": text_lora_parameters_one, "weight_decay": args.adam_weight_decay_text_encoder From 749e857379aeabc0ac5d7951219b4373c693f7bf Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 8 Oct 2024 17:58:15 +0300 Subject: [PATCH 58/82] remove transformer block specification feature (for now) --- .../train_dreambooth_lora_flux_advanced.py | 48 +------------------ 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index c630aa8f1461..8d6ac98591ba 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -569,15 +569,6 @@ def parse_args(input_args=None): default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) - parser.add_argument( - "--lora_transformer_blocks", - type=str, - default=None, - help=( - "the transformer blocks to tune during training. please specify them in a comma separated string, e.g. `transformer.single_transformer_blocks.7.proj_out,transformer.single_transformer_blocks.20.proj_out` etc." - "NOTE: By default (if not specified) - regular LoRA training is performed. " - ), - ) parser.add_argument( "--text_encoder_lr", type=float, @@ -1590,49 +1581,12 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() - # Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py - def is_belong_to_blocks(key, blocks): - try: - for g in blocks: - if g in key: - return True - return False - except Exception as e: - raise type(e)(f"failed to is_belong_to_block, due to: {e}") - - def get_transformer_lora_target_modules(transformer, target_blocks=None): - try: - blocks = [(".").join(blk.split(".")[1:]) for blk in target_blocks] - - attns = [ - attn_processor_name.rsplit(".", 1)[0] - for attn_processor_name, _ in transformer.attn_processors.items() - if is_belong_to_blocks(attn_processor_name, blocks) - ] - - target_modules = [f"{attn}.{mat}" for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns] - return target_modules - except Exception as e: - raise type(e)( - f"failed to get_target_modules, due to: {e}. " - f"Please check the modules specified in --lora_transformer_blocks are correct" - ) - - - if args.lora_transformer_blocks: - # if training specific transformer blocks - target_blocks_list = "".join(args.lora_transformer_blocks.split()).split(",") - logger.info(f"list of transformer blocks to train: {target_blocks_list}") - target_modules = get_transformer_lora_target_modules(transformer, target_blocks=target_blocks_list) - else: - target_modules = ["to_k", "to_q", "to_v", "to_out.0"] - # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=target_modules, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) transformer.add_adapter(transformer_lora_config) From c8ddd83cbbddbadfa2003c4f87d7259de4710a08 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 10 Oct 2024 07:17:52 +0000 Subject: [PATCH 59/82] style --- .../train_dreambooth_lora_flux_advanced.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 8d6ac98591ba..ab27d2524478 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -654,7 +654,9 @@ def parse_args(input_args=None): "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params" + ) parser.add_argument( "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) @@ -1506,7 +1508,7 @@ def main(args): if args.train_text_encoder_ti: # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, # TOK2" -> ["TOK", "TOK2"] etc. - token_abstraction_list = [place_holder.strip() for place_holder in re.split(r',\s*', args.token_abstraction)] + token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)] logger.info(f"list of token identifiers: {token_abstraction_list}") if args.initializer_concept is None: @@ -1534,8 +1536,10 @@ def main(args): for token_abs, token_replacement in token_abstraction_dict.items(): new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) if args.instance_prompt == new_instance_prompt: - logger.warning("Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " - "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning") + logger.warning( + "Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " + "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning" + ) args.instance_prompt = new_instance_prompt if args.with_prior_preservation: args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) From 43c2cd55027acb027ed95867e57f4f862f7fb6e0 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 10 Oct 2024 09:06:03 +0000 Subject: [PATCH 60/82] fix copies --- .../flux/pipeline_flux_controlnet_image_to_image.py | 6 ++++++ .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 ++++++ src/diffusers/pipelines/flux/pipeline_flux_img2img.py | 6 ++++++ src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 6 ++++++ 4 files changed, 24 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index deeb9e3f546a..61fa8593eaed 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -247,6 +247,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -291,6 +294,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e763200155f6..44a058440f93 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -257,6 +257,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -301,6 +304,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index bee4f6ce52e7..c34221a96740 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -235,6 +235,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -279,6 +282,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 460336700241..f59ae80ed147 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -239,6 +239,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -283,6 +286,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", From 2ac68983afa633dd6146b429cc394d0a3564f27b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 12:09:30 +0300 Subject: [PATCH 61/82] fix indexing issue when --initializer_concept has different amounts --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ab27d2524478..b9fbaf9d102c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -847,7 +847,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]): # Convert the initializer_token, placeholder_token to ids initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) for token_idx, token_id in enumerate(self.train_ids): - embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[token_idx]].clone() + embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[token_idx % len(initializer_token_ids)]].clone() self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding From d2cd0bf7eb684c313f677f3a32ee7bc0d8e7d377 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 12:13:51 +0300 Subject: [PATCH 62/82] add if TextualInversionLoaderMixin to all flux pipelines --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 8 +++++++- .../flux/pipeline_flux_controlnet_image_to_image.py | 2 +- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux_img2img.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 6c072c482020..01e6bfa79fa1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -27,7 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -234,6 +234,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -277,6 +280,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 61fa8593eaed..ec6a9de6e843 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -13,7 +13,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 44a058440f93..0e5a6d4a3570 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -14,7 +14,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index c34221a96740..0f6e38e66cae 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index f59ae80ed147..202335967032 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler From f1879bf87846223fe6144857198eb282b69bd78d Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 10 Oct 2024 09:14:38 +0000 Subject: [PATCH 63/82] style --- .../train_dreambooth_lora_flux_advanced.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index b9fbaf9d102c..d4712cf126f1 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -847,7 +847,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]): # Convert the initializer_token, placeholder_token to ids initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) for token_idx, token_id in enumerate(self.train_ids): - embeds.weight.data[token_id] = (embeds.weight.data)[initializer_token_ids[token_idx % len(initializer_token_ids)]].clone() + embeds.weight.data[token_id] = (embeds.weight.data)[ + initializer_token_ids[token_idx % len(initializer_token_ids)] + ].clone() self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding From 20762bc4b5b8c7783c5b717d1bf6cdbc13dafea0 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 12:18:31 +0300 Subject: [PATCH 64/82] fix import --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 01e6bfa79fa1..d4431d80066e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -25,9 +25,9 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( From 0e6d31e578fda0d02276f8777585cebf4dc4806c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 16:58:53 +0300 Subject: [PATCH 65/82] fix imports --- .../pipelines/flux/pipeline_flux_controlnet_image_to_image.py | 4 ++-- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index ec6a9de6e843..3b20e39a6c74 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -11,9 +11,9 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 0e5a6d4a3570..91cb5476b40b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -12,9 +12,9 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel, TextualInversionLoaderMixin +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( From 08aafc070ddeecb23149e0216bbac8ddfe096d39 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 19:55:02 +0300 Subject: [PATCH 66/82] address review comments - remove necessary prints & comments, use pin_memory=True, use free_memory utils, unify warning and prints --- .../train_dreambooth_lora_flux_advanced.py | 57 +++++++------------ 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index d4712cf126f1..747b5c32d473 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -61,6 +61,7 @@ cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + free_memory ) from diffusers.utils import ( check_min_version, @@ -229,7 +230,6 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: @@ -250,8 +250,7 @@ def log_validation( ) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() return images @@ -772,10 +771,10 @@ def parse_args(input_args=None): "This contradicts with --max_train_steps, please specify different values or set both to 1." ) if args.enable_t5_ti and not args.train_text_encoder_ti: - warnings.warn("You need not use --enable_t5_ti without --train_text_encoder_ti.") + logger.warning("You need not use --enable_t5_ti without --train_text_encoder_ti.") if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction: - warnings.warn( + logger.warning( "When specifying --initializer_concept, the number of tokens per abstraction is detrimned " "by the initializer token. --num_new_tokens_per_abstraction will be ignored" ) @@ -790,11 +789,10 @@ def parse_args(input_args=None): if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") else: - # logger is not available yet if args.class_data_dir is not None: - warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") if args.class_prompt is not None: - warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + logger.warning("You need not use --class_prompt without --with_prior_preservation.") return args @@ -832,7 +830,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]): ) std_token_embedding = embeds.weight.data.std() - print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") # if initializer_concept are not provided, token embeddings are initialized randomly if args.initializer_concept is None: @@ -860,7 +858,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]): self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates - print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + logger.info(self.embeddings_settings[f"index_no_updates_{idx}"].shape) idx += 1 @@ -1457,8 +1455,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1524,7 +1521,7 @@ def main(args): if args.enable_t5_ti: token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False) num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5)) - print( + logger.info( f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}" ) @@ -1710,7 +1707,6 @@ def load_model_hook(models, input_dir): text_lora_parameters_one = [] # CLIP for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: - print("YES 5") # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 param.data = param.to(dtype=torch.float32) param.requires_grad = True @@ -1718,7 +1714,6 @@ def load_model_hook(models, input_dir): else: param.requires_grad = False if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well - print("NO") text_lora_parameters_two = [] for name, param in text_encoder_two.named_parameters(): if "token_embedding" in name: @@ -1735,7 +1730,6 @@ def load_model_hook(models, input_dir): # if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion # and not training transformer LoRA layers pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 - print("NO 2:", pure_textual_inversion) # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} @@ -1756,7 +1750,6 @@ def load_model_hook(models, input_dir): ] te_idx = 0 else: # regular te training or regular pivotal for clip - print("YES1") params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, @@ -1908,12 +1901,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if freeze_text_encoder and not train_dataset.custom_instance_prompts: - del tokenizers, text_encoders - # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection - del text_encoder_one, text_encoder_two - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + del tokenizers, text_encoders, text_encoder_one, text_encoder_two + free_memory() # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion add_special_tokens_clip = True if args.train_text_encoder_ti else False @@ -1968,15 +1957,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=weight_dtype + accelerator.device, non_blocking=True, dtype=weight_dtype, pin_memory=True ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) if args.validation_prompt is None: del vae - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + free_memory() # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -2061,13 +2048,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): path = dirs[-1] if len(dirs) > 0 else None if path is None: - accelerator.print( + logger.info( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: - accelerator.print(f"Resuming from checkpoint {path}") + logger.info(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) @@ -2112,7 +2099,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: # flag to stop text encoder optimization - print("PIVOT TE", epoch) + logger.info("PIVOT TE", epoch) pivoted_te = True else: # still optimizing the text encoder @@ -2127,7 +2114,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if epoch == num_train_epochs_transformer: # flag to stop transformer optimization - print("PIVOT TRANSFORMER", epoch) + logger.info("PIVOT TRANSFORMER", epoch) pivoted_tr = True for step, batch in enumerate(train_dataloader): @@ -2136,7 +2123,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): optimizer.param_groups[te_idx]["lr"] = 0.0 optimizer.param_groups[-1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: - print("PIVOT TRANSFORMER") + logger.info("PIVOT TRANSFORMER") optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): @@ -2363,12 +2350,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) if freeze_text_encoder: del text_encoder_one, text_encoder_two - torch.cuda.empty_cache() - gc.collect() + free_memory() elif args.train_text_encoder: del text_encoder_two - torch.cuda.empty_cache() - gc.collect() + free_memory() # Save the lora layers accelerator.wait_for_everyone() From c5b24229d4a1883ed38b0c3a0bef2c6c2be09e5a Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 10 Oct 2024 16:56:14 +0000 Subject: [PATCH 67/82] style --- .../train_dreambooth_lora_flux_advanced.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 747b5c32d473..f70426ee566a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -15,7 +15,6 @@ import argparse import copy -import gc import itertools import logging import math @@ -23,7 +22,6 @@ import random import re import shutil -import warnings from contextlib import nullcontext from pathlib import Path from typing import List, Optional, Union @@ -61,7 +59,7 @@ cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, - free_memory + free_memory, ) from diffusers.utils import ( check_min_version, @@ -2048,9 +2046,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): path = dirs[-1] if len(dirs) > 0 else None if path is None: - logger.info( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") args.resume_from_checkpoint = None initial_global_step = 0 else: From 4b117194d299b633daf5ad9e1de846306d9838c9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 10 Oct 2024 20:07:20 +0300 Subject: [PATCH 68/82] logger info fix --- .../train_dreambooth_lora_flux_advanced.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 747b5c32d473..dcbfae5188d2 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2099,7 +2099,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: # flag to stop text encoder optimization - logger.info("PIVOT TE", epoch) + logger.info(f"PIVOT TE {epoch}") pivoted_te = True else: # still optimizing the text encoder @@ -2114,7 +2114,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if epoch == num_train_epochs_transformer: # flag to stop transformer optimization - logger.info("PIVOT TRANSFORMER", epoch) + logger.info(f"PIVOT TRANSFORMER {epoch}") pivoted_tr = True for step, batch in enumerate(train_dataloader): @@ -2123,7 +2123,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): optimizer.param_groups[te_idx]["lr"] = 0.0 optimizer.param_groups[-1]["lr"] = 0.0 elif pivoted_tr and not pure_textual_inversion: - logger.info("PIVOT TRANSFORMER") + logger.info(f"PIVOT TRANSFORMER {epoch}") optimizer.param_groups[0]["lr"] = 0.0 with accelerator.accumulate(transformer): From 6e2cb75648ccef379a93626762806c29c5cfb3dd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 11 Oct 2024 10:54:48 +0300 Subject: [PATCH 69/82] make lora target modules configurable and change the default --- .../train_dreambooth_lora_flux_advanced.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 7bf2bb309cf7..75cf14f19588 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -657,7 +657,12 @@ def parse_args(input_args=None): parser.add_argument( "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) - + parser.add_argument( + "--lora_blocks", + type=str, + default=None, + help=('The transformer modules to apply LoRA training on'), + ) parser.add_argument( "--adam_epsilon", type=float, @@ -1582,12 +1587,17 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() + if args.lora_blocks is not None: + target_modules = [block.strip() for block in args.lora_blocks.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0", + "add_k_proj", "add_q_proj", "add_v_proj", "to_add_out", "ff.net.0.proj","ff.net.2", "ff_context.net.0.proj","ff_context.net.2"] # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) From 717b5adee956153f301834dcd154baaeb277d9b7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 11 Oct 2024 11:57:24 +0300 Subject: [PATCH 70/82] make lora target modules configurable and change the default --- .../train_dreambooth_lora_flux_advanced.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 75cf14f19588..d423f67bfc5c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -661,7 +661,8 @@ def parse_args(input_args=None): "--lora_blocks", type=str, default=None, - help=('The transformer modules to apply LoRA training on'), + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only'), ) parser.add_argument( "--adam_epsilon", From 0fde49a3a9f06c08e0f43d87cf87340858f577f6 Mon Sep 17 00:00:00 2001 From: Linoy Date: Fri, 11 Oct 2024 08:59:47 +0000 Subject: [PATCH 71/82] style --- .../train_dreambooth_lora_flux_advanced.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index d423f67bfc5c..14c92a1b4e0a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -662,7 +662,8 @@ def parse_args(input_args=None): type=str, default=None, help=( - 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only'), + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only' + ), ) parser.add_argument( "--adam_epsilon", @@ -1591,8 +1592,20 @@ def main(args): if args.lora_blocks is not None: target_modules = [block.strip() for block in args.lora_blocks.split(",")] else: - target_modules = ["to_k", "to_q", "to_v", "to_out.0", - "add_k_proj", "add_q_proj", "add_v_proj", "to_add_out", "ff.net.0.proj","ff.net.2", "ff_context.net.0.proj","ff_context.net.2"] + target_modules = [ + "to_k", + "to_q", + "to_v", + "to_out.0", + "add_k_proj", + "add_q_proj", + "add_v_proj", + "to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, From e4fe609e2fd4ee23feab12d95526cbbdb331b70e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 14 Oct 2024 21:51:59 +0300 Subject: [PATCH 72/82] make lora target modules configurable and change the default, add notes to readme --- .../README_flux.md | 17 ++++++++++++- .../train_dreambooth_lora_flux_advanced.py | 25 ++++++++++--------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index c3dd3522e5ca..e755fd8b61e0 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -5,7 +5,7 @@ > 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). > As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗 -[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like flux, stable diffusion given just a few(3~5) images of a subject. +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject. LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen* In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: @@ -65,6 +65,21 @@ write_basic_config() When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +### Target Modules +When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. +More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore +applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string +the exact modules for LoRA training. Here are some examples of target modules you can provide: +- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"` +- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"` +- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"` +> [!NOTE] +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string: +> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k` +> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` +> [!NOTE] +> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. + ### Pivotal Tuning (and more) **Training with text encoder(s)** diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 14c92a1b4e0a..bd50e989d895 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -658,11 +658,12 @@ def parse_args(input_args=None): "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( - "--lora_blocks", + "--lora_layers", type=str, default=None, help=( - 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "q_proj,k_proj,v_proj,out_proj" will result in lora training of attention layers only' + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. ' + 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' ), ) parser.add_argument( @@ -1589,18 +1590,18 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() - if args.lora_blocks is not None: - target_modules = [block.strip() for block in args.lora_blocks.split(",")] + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] else: target_modules = [ - "to_k", - "to_q", - "to_v", - "to_out.0", - "add_k_proj", - "add_q_proj", - "add_v_proj", - "to_add_out", + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", From 3d0955bc8100273479cb6fd7b906ce10abe8b866 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 14 Oct 2024 18:57:24 +0000 Subject: [PATCH 73/82] style --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index bd50e989d895..434902023605 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -662,7 +662,7 @@ def parse_args(input_args=None): type=str, default=None, help=( - 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. ' + "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. " 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' ), ) From cb265adb48fa4d3f006694e5de4a49217da7154c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 18:42:01 +0300 Subject: [PATCH 74/82] add tests --- .../test_dreambooth_lora_flux_advanced.py | 279 ++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..776bce58e16a --- /dev/null +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -0,0 +1,279 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +import tempfile + +import safetensors + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "photo" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py" + + def test_dreambooth_lora_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_text_encoder_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + starts_with_expected_prefix = all( + (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_expected_prefix) + + def test_dreambooth_lora_pivotal_tuning_flux_clip(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. + textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")) + is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_clip) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --enable_t5_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. + textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")) + is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_te) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) From 61426f0ce40caca6ed47ed8b55cb0cdf3c7cb00a Mon Sep 17 00:00:00 2001 From: Linoy Date: Tue, 15 Oct 2024 15:45:09 +0000 Subject: [PATCH 75/82] style --- .../test_dreambooth_lora_flux_advanced.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py index 776bce58e16a..512b5cca39bd 100644 --- a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -134,7 +134,9 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip(self): self.assertTrue(is_lora) # make sure the state_dict has the correct naming in the parameters. - textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")) + textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{tmpdir}_emb.safetensors") + ) is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys()) self.assertTrue(is_clip) @@ -175,7 +177,9 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): self.assertTrue(is_lora) # make sure the state_dict has the correct naming in the parameters. - textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")) + textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{tmpdir}_emb.safetensors") + ) is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) self.assertTrue(is_te) From 03f19f67e554314553a93d503f9dd285f78cbc6c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 19:13:47 +0300 Subject: [PATCH 76/82] fix repo id --- .../train_dreambooth_lora_flux_advanced.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 434902023605..2ed47d178ba5 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1468,9 +1468,11 @@ def main(args): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, + repo_id=model_id, exist_ok=True, ).repo_id @@ -2427,7 +2429,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) save_model_card( - repo_id, + model_id if not args.push_to_hub else repo_id, images=images, base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, From bd2be329a5788ecdb3afc4feebd8c74722c0bd52 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 10:57:01 +0300 Subject: [PATCH 77/82] add updated requirements for advanced flux --- .../advanced_diffusion_training/requirements_flux.txt | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 examples/advanced_diffusion_training/requirements_flux.txt diff --git a/examples/advanced_diffusion_training/requirements_flux.txt b/examples/advanced_diffusion_training/requirements_flux.txt new file mode 100644 index 000000000000..dbc124ff6526 --- /dev/null +++ b/examples/advanced_diffusion_training/requirements_flux.txt @@ -0,0 +1,8 @@ +accelerate>=0.31.0 +torchvision +transformers>=4.41.2 +ftfy +tensorboard +Jinja2 +peft>=0.11.1 +sentencepiece \ No newline at end of file From 69d28b5278a36debae04f4ca6d31996c45e2b6af Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 11:15:00 +0300 Subject: [PATCH 78/82] fix indices of t5 pivotal tuning embeddings --- .../train_dreambooth_lora_flux_advanced.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 2ed47d178ba5..e9584f29ae81 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -810,6 +810,7 @@ def __init__(self, text_encoders, tokenizers): self.tokenizers = tokenizers self.train_ids: Optional[torch.Tensor] = None + self.train_ids_t5: Optional[torch.Tensor] = None self.inserting_toks: Optional[List[str]] = None self.embeddings_settings = {} @@ -828,7 +829,10 @@ def initialize_new_tokens(self, inserting_toks: List[str]): text_encoder.resize_token_embeddings(len(tokenizer)) # Convert the token abstractions to ids - self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + if idx == 0: + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + else: + self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks) # random initialization of new tokens embeds = ( @@ -838,19 +842,20 @@ def initialize_new_tokens(self, inserting_toks: List[str]): logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 # if initializer_concept are not provided, token embeddings are initialized randomly if args.initializer_concept is None: hidden_size = ( text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size ) - embeds.weight.data[self.train_ids] = ( - torch.randn(len(self.train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype) + embeds.weight.data[train_ids] = ( + torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype) * std_token_embedding ) else: # Convert the initializer_token, placeholder_token to ids initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) - for token_idx, token_id in enumerate(self.train_ids): + for token_idx, token_id in enumerate(train_ids): embeds.weight.data[token_id] = (embeds.weight.data)[ initializer_token_ids[token_idx % len(initializer_token_ids)] ].clone() @@ -860,7 +865,7 @@ def initialize_new_tokens(self, inserting_toks: List[str]): # makes sure we don't update any embedding weights besides the newly added token index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) - index_no_updates[self.train_ids] = False + index_no_updates[train_ids] = False self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates @@ -874,11 +879,12 @@ def save_embeddings(self, file_path: str): # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} for idx, text_encoder in enumerate(self.text_encoders): + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 embeds = ( text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens ) assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same." - new_token_embeddings = embeds.weight.data[self.train_ids] + new_token_embeddings = embeds.weight.data[train_ids] # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), # Note: When loading with diffusers, any name can work - simply specify in inference From 31de75262016c952281e1f1f488176d43ae6821a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 11:47:58 +0300 Subject: [PATCH 79/82] fix path in test --- .../test_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py index 512b5cca39bd..f5c3e002da11 100644 --- a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -178,7 +178,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): # make sure the state_dict has the correct naming in the parameters. textual_inversion_state_dict = safetensors.torch.load_file( - os.path.join(tmpdir, f"{tmpdir}_emb.safetensors") + os.path.join(tmpdir, f"{os.path(tmpdir).name}_emb.safetensors") ) is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) self.assertTrue(is_te) From 5dfd685c9a1d39e7900d94c183ecf0af3e573269 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 12:15:34 +0300 Subject: [PATCH 80/82] remove `pin_memory` --- .../train_dreambooth_lora_flux_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e9584f29ae81..a4472cea54ea 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1988,7 +1988,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( - accelerator.device, non_blocking=True, dtype=weight_dtype, pin_memory=True + accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) From 9093a4bb5f2a4b6fed7c177f28a4c654e9965146 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 18:07:16 +0300 Subject: [PATCH 81/82] fix filename of embedding --- .../test_dreambooth_lora_flux_advanced.py | 2 +- .../train_dreambooth_lora_flux_advanced.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py index f5c3e002da11..ede426e105f5 100644 --- a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -178,7 +178,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): # make sure the state_dict has the correct naming in the parameters. textual_inversion_state_dict = safetensors.torch.load_file( - os.path.join(tmpdir, f"{os.path(tmpdir).name}_emb.safetensors") + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") ) is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) self.assertTrue(is_te) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index a4472cea54ea..3db6896228de 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2405,7 +2405,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) if args.train_text_encoder_ti: - embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors" + embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors" embedding_handler.save_embeddings(embeddings_path) # Final inference From f1b08cb095120fb03d790b67a0f2273f0ce060b9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 16 Oct 2024 18:22:47 +0300 Subject: [PATCH 82/82] fix filename of embedding --- .../test_dreambooth_lora_flux_advanced.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py index ede426e105f5..e29c99821303 100644 --- a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -126,7 +126,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip(self): # save_pretrained smoke test self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure embeddings were also saved - self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) # make sure the state_dict has the correct naming in the parameters. lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) @@ -135,7 +135,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip(self): # make sure the state_dict has the correct naming in the parameters. textual_inversion_state_dict = safetensors.torch.load_file( - os.path.join(tmpdir, f"{tmpdir}_emb.safetensors") + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") ) is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys()) self.assertTrue(is_clip) @@ -169,7 +169,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): # save_pretrained smoke test self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) # make sure embeddings were also saved - self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) # make sure the state_dict has the correct naming in the parameters. lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))