Description
The code passes `text_input_ids_list` as an argument to `encode_prompt`, but `text_input_ids_list` is not an accepted parameter per the function's definition in this script.
Call site
`diffusers/examples/dreambooth/train_dreambooth_sd3.py`, lines 1597 to 1601 at 298ab6e:
```python
prompt_embeds, pooled_prompt_embeds = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
    tokenizers=None,
    prompt=None,
    text_input_ids_list=[tokens_one, tokens_two, tokens_three],
```
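Since Python raises a `TypeError` for any unexpected keyword argument, this call fails as soon as it is reached. A minimal, self-contained sketch of the failure mode (the signature is abbreviated from the definition quoted below, and every argument value is a placeholder, not the script's real objects):

```python
# Minimal reproduction sketch; placeholder values only.
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt,
    max_sequence_length,
    device=None,
    num_images_per_prompt=1,
):
    ...


encode_prompt(
    text_encoders=[None, None, None],
    tokenizers=None,
    prompt=None,
    max_sequence_length=77,
    text_input_ids_list=[None, None, None],  # not accepted by the signature above
)
# TypeError: encode_prompt() got an unexpected keyword argument 'text_input_ids_list'
```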
Definition of `encode_prompt`
`diffusers/examples/dreambooth/train_dreambooth_sd3.py`, lines 900 to 976 at 298ab6e:
```python
def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
    )

    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
```
Referencing `train_dreambooth_lora_sd3.py`: I think the definition of `encode_prompt` in `train_dreambooth_sd3.py` might need to be updated to match it (a sketch of that change follows the listing below).
`diffusers/examples/dreambooth/train_dreambooth_lora_sd3.py`, lines 952 to 1037 at 298ab6e:
```python
def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=device if device is not None else text_encoders[-1].device,
    )

    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
```
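For reference, here is a minimal sketch of what that update to `train_dreambooth_sd3.py` could look like, assuming the LoRA definitions quoted above are the intended behavior. This is an untested fragment meant to be read against the listings above, not a standalone program; elided bodies (`...`) would be unchanged, and it assumes `_encode_prompt_with_clip` and `_encode_prompt_with_t5` are also updated to accept `text_input_ids`, as in the LoRA script. Note the quoted call site would likely need further changes on top of this: it passes `tokenizers=None` and `prompt=None`, which would still trip over `tokenizers[:2]` and `len(prompt)` in the copied code unless, for example, an indexable list of `None` tokenizers and a real prompt batch are passed.

```python
# Hypothetical patch sketch (untested): mirror train_dreambooth_lora_sd3.py.
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,  # added: one tensor of pre-tokenized ids per encoder
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            # added: forward pre-tokenized ids so tokenizer=None still works
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        ...

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        # added: same fallback for the T5 encoder (assumes the helper accepts
        # text_input_ids, as its counterpart in the LoRA script does)
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=device if device is not None else text_encoders[-1].device,
    )
    ...
```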