## Prompt embedding
1. Load all three text encoders (two CLIP models and one T5 model).
2. Generate the prompt embeddings and save them to disk.

In [1]:
# --- Run configuration -------------------------------------------------------
# Set to True when embedding the positive prompt, False when embedding the
# negative prompt (the last cell branches on this flag).
is_pos_prompt = True

# Text to embed, given as a list of prompt strings.
prompt = ["photo of a car parked in an empty parking lot on a rainy night"]

# Token budget passed to the T5 tokenizer as `max_length`.
max_sequence_length = 256

# How many copies of each prompt's embeddings to produce.
num_images_per_prompt = 1
batch_size = 1

In [2]:
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from transformers import BitsAndBytesConfig
# bitsandbytes 8-bit quantization settings; passed to the T5 encoder when it
# is loaded in the next cell.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: `num_images_per_prompt` is intentionally NOT re-assigned here. It is
# set once in the configuration cell at the top; a second assignment in this
# cell silently reset any value chosen there back to 1.

In [3]:
# Hub repository whose subfolders hold the individual tokenizers and text
# encoders loaded below.
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

# Tokenizers: two CLIP tokenizers and one T5 tokenizer.
tokenizer = CLIPTokenizer.from_pretrained(
    model_id, subfolder="tokenizer", use_safetensors=True
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
    model_id, subfolder="tokenizer_2", use_safetensors=True
)
tokenizer_3 = T5TokenizerFast.from_pretrained(
    model_id, subfolder="tokenizer_3", use_safetensors=True
)

# Text encoders: the two CLIP encoders are loaded with default settings, the
# T5 encoder is loaded with the 8-bit quantization config defined earlier.
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    model_id, subfolder="text_encoder"
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    model_id, subfolder="text_encoder_2"
)
text_encoder_3 = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_3",
    quantization_config=quantization_config,
)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def _get_clip_prompt_embeds(
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
    tokenizer: CLIPTokenizer = None,
    tokenizer_2: CLIPTokenizer = None,
    text_encoder: CLIPTextModelWithProjection = None,
    text_encoder_2: CLIPTextModelWithProjection = None,
):
    """Encode ``prompt`` with one of the two CLIP text encoders.

    Args:
        prompt: A prompt string or a list of prompt strings. A single string
            is wrapped into a one-element list.
        num_images_per_prompt: Number of copies of each prompt's embeddings
            to place in the output.
        device: Device the token ids are moved to before the forward pass and
            the hidden-state embeddings are moved to afterwards.
        clip_skip: When ``None`` the second-to-last hidden state is used;
            otherwise ``hidden_states[-(clip_skip + 2)]`` is used.
        clip_model_index: ``0`` selects ``tokenizer`` / ``text_encoder``,
            ``1`` selects ``tokenizer_2`` / ``text_encoder_2``.
        tokenizer: First CLIP tokenizer.
        tokenizer_2: Second CLIP tokenizer.
        text_encoder: First CLIP text encoder.
        text_encoder_2: Second CLIP text encoder.

    Returns:
        A tuple ``(prompt_embeds, pooled_prompt_embeds)`` shaped
        ``(batch * num_images_per_prompt, seq_len, dim)`` and
        ``(batch * num_images_per_prompt, pooled_dim)``. Row ``i`` of both
        tensors belongs to the same prompt.
    """
    # Imported locally so the function works without a module-level logger.
    import logging

    logger = logging.getLogger(__name__)

    clip_tokenizers = [tokenizer, tokenizer_2]
    clip_text_encoders = [text_encoder, text_encoder_2]

    tokenizer = clip_tokenizers[clip_model_index]
    text_encoder = clip_text_encoders[clip_model_index]

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation, only to report what was cut off.
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {tokenizer.model_max_length} tokens: {removed_text}"
        )
    encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    # First output of the projection model is taken as the pooled embedding.
    pooled_prompt_embeds = encoder_output[0]

    if clip_skip is None:
        prompt_embeds = encoder_output.hidden_states[-2]
    else:
        prompt_embeds = encoder_output.hidden_states[-(clip_skip + 2)]

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # Duplicate text embeddings for each generation per prompt, using an
    # mps-friendly method. Resulting row order: [p0, p0, ..., p1, p1, ...].
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # The pooled tensor is 2-D, so tile along its last axis and reshape. This
    # yields the same [p0, p0, ..., p1, p1, ...] row order as above. Using a
    # three-argument repeat here would instead interleave prompts
    # ([p0, p1, p0, p1, ...]) and misalign the two outputs whenever both the
    # batch size and num_images_per_prompt exceed 1.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds

def _get_t5_prompt_embeds(
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 256,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    tokenizer_3: T5TokenizerFast = None,
    text_encoder_3: T5EncoderModel = None
):
    """Encode ``prompt`` with the T5 text encoder.

    Args:
        prompt: A prompt string or a list of prompt strings. A single string
            is wrapped into a one-element list.
        num_images_per_prompt: Number of copies of each prompt's embeddings
            to place in the output.
        max_sequence_length: Length the token ids are padded/truncated to.
        device: Device the token ids are moved to before the forward pass and
            the embeddings are moved to afterwards.
        dtype: dtype of the returned embeddings. When ``None`` the encoder's
            own dtype (``text_encoder_3.dtype``) is used.
        tokenizer_3: T5 tokenizer.
        text_encoder_3: T5 encoder model.

    Returns:
        Embeddings shaped ``(batch * num_images_per_prompt, seq_len, dim)``.
    """
    # Imported locally so the function works without a module-level logger.
    import logging

    logger = logging.getLogger(__name__)

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer_3(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation, only to report what was cut off.
    untruncated_ids = tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Slice relative to this function's own limit; the CLIP tokenizer's
        # limit is unrelated to the T5 truncation point.
        removed_text = tokenizer_3.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder_3(text_input_ids.to(device))[0]

    # Honor an explicitly requested dtype; otherwise keep the encoder's dtype.
    if dtype is None:
        dtype = text_encoder_3.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings for each generation per prompt, using an
    # mps-friendly method.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds

In [5]:
# The encoders are only used for inference here. Running them under
# torch.no_grad() avoids building an autograd graph, so the embeddings that
# get saved later are plain tensors that do not require grad.
with torch.no_grad():
    # First CLIP encoder (clip_model_index=0); the CLIP calls run on the CPU.
    prompt_embed, pooled_prompt_embed = _get_clip_prompt_embeds(
        prompt=prompt,
        device="cpu",
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=0,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2
    )
    # Second CLIP encoder (clip_model_index=1).
    prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds(
        prompt=prompt,
        device="cpu",
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=1,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2
    )

    # T5 encoder, run on `device`.
    t5_prompt_embed = _get_t5_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        tokenizer_3=tokenizer_3,
        text_encoder_3=text_encoder_3
    )

# Join the two CLIP sequence embeddings along the feature axis.
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

# Zero-pad the CLIP features on the right up to the T5 feature width so both
# can be concatenated along the sequence axis below.
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)

# Bring the T5 embeddings to the CPU, where the CLIP embeddings already live.
t5_prompt_embed = t5_prompt_embed.to("cpu")

# Final tensors: CLIP tokens followed by T5 tokens along the sequence axis,
# and the two pooled CLIP vectors joined along the feature axis.
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

In [6]:
if is_pos_prompt:
    # The positive run needs "neg_emb.pt", which the negative run (else
    # branch) writes — so run the notebook once with is_pos_prompt = False
    # first.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all this file holds, and avoids executing arbitrary pickled
    # code from the file.
    negative_prompt_embeds, negative_pooled_prompt_embeds = torch.load(
        "neg_emb.pt", weights_only=True
    )
    # Stack along the batch axis with the negative embeddings first.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    torch.save((prompt_embeds, pooled_prompt_embeds), "final_emb.pt")
else:
    # Negative run: store the embeddings for the later positive run.
    torch.save((prompt_embeds, pooled_prompt_embeds), "neg_emb.pt")

  negative_prompt_embeds, negative_pooled_prompt_embeds = torch.load("neg_emb.pt")
