From 1e357220e5532db424caffd2b712b541bcb9db44 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 26 Jun 2024 17:42:53 +0300 Subject: [PATCH 1/3] add clip_skip --- .../train_dreambooth_lora_sdxl_advanced.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 64fd0a6986ed..94e731ae2250 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -499,6 +499,13 @@ def parse_args(input_args=None): default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that " + "the output of the pre-final layer will be used for computing the prompt embeddings.", + ) parser.add_argument( "--text_encoder_lr", @@ -1166,7 +1173,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False): # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None): prompt_embeds_list = [] for i, text_encoder in enumerate(text_encoders): @@ -1184,7 +1191,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
@@ -1752,9 +1763,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
     tokenizers = [tokenizer_one, tokenizer_two]
     text_encoders = [text_encoder_one, text_encoder_two]
 
-    def compute_text_embeddings(prompt, text_encoders, tokenizers):
+    def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip=clip_skip)
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
@@ -1764,7 +1775,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # the redundant encoding.
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
-            args.instance_prompt, text_encoders, tokenizers
+            args.instance_prompt, text_encoders, tokenizers, args.clip_skip
         )
 
         # Handle class prompt for prior-preservation.
@@ -1962,7 +1973,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if train_dataset.custom_instance_prompts: if freeze_text_encoder: prompt_embeds, unet_add_text_embeds = compute_text_embeddings( - prompts, text_encoders, tokenizers + prompts, text_encoders, tokenizers, args.clip_skip ) else: @@ -2058,6 +2069,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two], + clip_skip=args.clip_skip ) unet_added_conditions.update( {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} From 528911f7bf3bcf62c3f6c4041d6ec1cbfe9449e1 Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 26 Jun 2024 15:29:45 +0000 Subject: [PATCH 2/3] style --- .../train_dreambooth_lora_sdxl_advanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 94e731ae2250..4f041085a087 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -504,7 +504,7 @@ def parse_args(input_args=None): type=int, default=None, help="Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that " - "the output of the pre-final layer will be used for computing the prompt embeddings.", + "the output of the pre-final layer will be used for computing the prompt embeddings.", ) parser.add_argument( @@ -2069,7 +2069,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two], - clip_skip=args.clip_skip + clip_skip=args.clip_skip, ) unet_added_conditions.update( {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} From 67ce4d2246844bdafc6366f230c95ee33ffc9adf Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 1 Jul 2024 18:14:50 +0300 Subject: [PATCH 3/3] smol fix --- .../train_dreambooth_lora_sdxl_advanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index e55e1b7e974b..f29b2e0b5225 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1261,10 +1261,10 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, c # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] else: # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds)