Skip to content

text_input_ids_list not part of encode_prompt definition #9773

@Vatsal-Malaviya

Description

@Vatsal-Malaviya

Description
The training code passes text_input_ids_list as an argument to encode_prompt, but text_input_ids_list is not a parameter in the definition of encode_prompt.

Call

prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],

Definition for encode_prompt

def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    num_images_per_prompt: int = 1,
):
    """Tokenize ``prompt`` and encode it with a single text encoder.

    Args:
        text_encoder: Callable encoder; it is invoked with the token ids and
            ``output_hidden_states=True`` and must expose a ``dtype`` attribute.
        tokenizer: Callable tokenizer; its result must expose ``input_ids``.
        prompt: A single string or a sequence of strings.
        device: Device the token ids and embeddings are moved to.
        num_images_per_prompt: How many times each prompt's embeddings are
            duplicated along the batch dimension.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple. ``prompt_embeds`` is
        the second-to-last hidden state, duplicated per image;
        ``pooled_prompt_embeds`` is element 0 of the encoder output, returned
        without duplication.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    num_prompts = len(prompts)

    # Always pad/truncate to a fixed length of 77 tokens.
    tokenized = tokenizer(
        prompts,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids

    encoder_output = text_encoder(token_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = encoder_output[0]

    # The second-to-last hidden state is used as the per-token embedding.
    hidden = encoder_output.hidden_states[-2]
    hidden = hidden.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = hidden.shape
    # Duplicate the embeddings once per generated image (repeat + view is the
    # mps-friendly way to do this).
    duplicated = hidden.repeat(1, num_images_per_prompt, 1)
    duplicated = duplicated.view(num_prompts * num_images_per_prompt, seq_len, -1)

    return duplicated, pooled_prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
):
    """Encode ``prompt`` with the first two encoders plus the last (T5) encoder.

    The first two tokenizer/encoder pairs go through
    ``_encode_prompt_with_clip``; the last pair goes through
    ``_encode_prompt_with_t5``. The two CLIP embeddings are concatenated on the
    feature axis, zero-padded up to the T5 feature width, and then concatenated
    with the T5 embedding along the second-to-last axis.

    Args:
        text_encoders: Sequence of encoders; the first two and the last are used.
        tokenizers: Sequence of tokenizers, aligned with ``text_encoders``.
        prompt: A single string or a sequence of strings.
        max_sequence_length: Forwarded to ``_encode_prompt_with_t5``.
        device: Target device; when ``None`` each encoder's own ``device`` is used.
        num_images_per_prompt: Forwarded to both helper encoders.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # Run every (tokenizer, encoder) pair among the first two through the CLIP helper.
    clip_outputs = [
        _encode_prompt_with_clip(
            text_encoder=encoder,
            tokenizer=tok,
            prompt=prompts,
            device=encoder.device if device is None else device,
            num_images_per_prompt=num_images_per_prompt,
        )
        for tok, encoder in zip(tokenizers[:2], text_encoders[:2])
    ]
    clip_embeds = torch.cat([embeds for embeds, _ in clip_outputs], dim=-1)
    pooled_prompt_embeds = torch.cat([pooled for _, pooled in clip_outputs], dim=-1)

    t5_encoder = text_encoders[-1]
    t5_embeds = _encode_prompt_with_t5(
        t5_encoder,
        tokenizers[-1],
        max_sequence_length,
        prompt=prompts,
        num_images_per_prompt=num_images_per_prompt,
        device=t5_encoder.device if device is None else device,
    )

    # Right-pad the CLIP features so their last dimension matches the T5 one.
    feature_gap = t5_embeds.shape[-1] - clip_embeds.shape[-1]
    clip_embeds = torch.nn.functional.pad(clip_embeds, (0, feature_gap))

    prompt_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)
    return prompt_embeds, pooled_prompt_embeds

Referencing train_dreambooth_sd3_lora.py, I think the definition of encode_prompt might need to be updated as follows:

def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    """Encode a prompt, or pre-tokenized ids, with a single text encoder.

    Args:
        text_encoder: Callable encoder; it is invoked with the token ids and
            ``output_hidden_states=True`` and must expose a ``dtype`` attribute.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``, or
            ``None`` when ``text_input_ids`` is supplied instead.
        prompt: A single string or a sequence of strings. Only used when a
            tokenizer is given, so it may be ``None`` for pre-tokenized input.
        device: Device the token ids and embeddings are moved to.
        text_input_ids: Pre-computed token ids; required when ``tokenizer`` is
            ``None`` and ignored otherwise.
        num_images_per_prompt: How many times each prompt's embeddings are
            duplicated along the batch dimension.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple. ``prompt_embeds`` is
        the second-to-last hidden state, duplicated per image;
        ``pooled_prompt_embeds`` is element 0 of the encoder output, returned
        without duplication.

    Raises:
        ValueError: If neither ``tokenizer`` nor ``text_input_ids`` is given.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    # Take the batch size from the ids rather than from len(prompt): on the
    # pre-tokenized path the prompt may be None, and len(None) would raise.
    batch_size = text_input_ids.shape[0]

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]

    # The second-to-last hidden state is used as the per-token embedding.
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    """Encode a prompt, or pre-tokenized ids, with two CLIP encoders and T5.

    The first two tokenizer/encoder pairs go through
    ``_encode_prompt_with_clip``; the last pair goes through
    ``_encode_prompt_with_t5``. The two CLIP embeddings are concatenated on the
    feature axis, zero-padded up to the T5 feature width, and then concatenated
    with the T5 embedding along the second-to-last axis.

    Args:
        text_encoders: Sequence of encoders; the first two and the last are used.
        tokenizers: Sequence of tokenizers aligned with ``text_encoders``, or
            ``None`` when the input is supplied via ``text_input_ids_list``.
        prompt: A single string or a sequence of strings.
        max_sequence_length: Forwarded to ``_encode_prompt_with_t5``.
        device: Target device; when ``None`` each encoder's own ``device`` is used.
        num_images_per_prompt: Forwarded to both helper encoders.
        text_input_ids_list: Optional pre-computed token ids, one entry per
            encoder; entry ``i`` goes to CLIP encoder ``i`` and the last entry
            goes to the T5 encoder.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple.

    Raises:
        ValueError: If both ``tokenizers`` and ``text_input_ids_list`` are missing.
    """
    if tokenizers is None:
        if not text_input_ids_list:
            raise ValueError("text_input_ids_list must be provided when tokenizers is not specified")
        # Callers that pre-tokenize pass tokenizers=None; substitute one
        # placeholder per encoder so the slicing/indexing below keeps working.
        tokenizers = [None] * len(text_encoders)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=device if device is not None else text_encoders[-1].device,
    )

    # Right-pad the CLIP features so their last dimension matches the T5 one.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions