In [1]:
import torch
import os
import logging
import numpy as np
import pandas as pd

from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
from transformers import CLIPTextModelWithProjection, T5EncoderModel

In [2]:
dataset_path = 'genshin_dataset'
images_path = os.path.join(dataset_path, 'dataset1')
prompt_embbedings_path = os.path.join(dataset_path, 'dataset1_prompt_embbedings_sd3')
train_df = pd.read_csv(os.path.join(dataset_path, 'dataset1.csv'))
# Output directory for the per-prompt embedding files saved by the embedding cell.
# exist_ok=True keeps Restart & Run All working whether or not it already exists.
os.makedirs(prompt_embbedings_path, exist_ok=True)
train_df.head()

Unnamed: 0,character,im_path,orig_path,description
0,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with light blonde hair tied in a...
1,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red..."
2,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A woman with light blonde hair tied in a ponyt...
3,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red..."
4,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with blonde hair tied in a high ...


In [3]:
## Duplicate every row, appending " Genshin style." to the description ##
# Each image then has both its plain caption and a style-tagged variant.
# NOTE: not idempotent. Re-running this cell doubles the already-doubled frame
# (and tags descriptions twice); re-run the loading cell first.
n_rows_before = len(train_df)
print(n_rows_before)
styled_df = train_df.assign(description=train_df['description'] + ' Genshin style.')
train_df = pd.concat([train_df, styled_df], ignore_index=True)
assert len(train_df) == 2 * n_rows_before
print(len(train_df))

5570
11140


In [4]:
# Distinct descriptions in first-appearance order; each one is encoded exactly once below.
prompts = pd.unique(train_df['description'])
n_prompts = len(prompts)
print('Number of prompts: ', n_prompts)

Number of prompts:  1102


In [5]:
pretrained_model = "stabilityai/stable-diffusion-3-medium-diffusers"
weight_dtype = torch.float16

# Load the tokenizers: two CLIP tokenizers then the T5 tokenizer, the order encode_prompt expects
# (it treats entries [:2] as CLIP and [-1] as T5).
tokenizer_one = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
tokenizer_two = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer_2")
tokenizer_three = T5TokenizerFast.from_pretrained(pretrained_model, subfolder="tokenizer_3")
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]

# Load the text encoders directly in `weight_dtype` so the large T5 encoder is not first
# materialized as a float32 copy before being cast down.
text_encoder_one = CLIPTextModelWithProjection.from_pretrained(
    pretrained_model, subfolder="text_encoder", torch_dtype=weight_dtype
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    pretrained_model, subfolder="text_encoder_2", torch_dtype=weight_dtype
)
text_encoder_three = T5EncoderModel.from_pretrained(
    pretrained_model, subfolder="text_encoder_3", torch_dtype=weight_dtype
)
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]

# Inference only: freeze every encoder and move it to the GPU in weight_dtype.
for encoder in text_encoders:
    encoder.requires_grad_(False)
    encoder.to("cuda", dtype=weight_dtype)

max_sequence_length = 77  # T5 token length; the CLIP encoders are always padded/truncated to 77 tokens.

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

#### Functions to compute prompt embeddings (adapted from the Hugging Face diffusers codebase)

In [6]:
def _encode_prompt_with_clip(text_encoder, tokenizer, prompt: str, device=None,  text_input_ids=None, num_images_per_prompt: int = 1,):
    """Encode prompt(s) with one CLIP text encoder.

    Args:
        text_encoder: CLIP text model called with ``output_hidden_states=True``. Its first
            output is used as the pooled embedding; its penultimate hidden state is used as
            the per-token embedding.
        tokenizer: matching CLIP tokenizer. May be None when ``text_input_ids`` is given.
        prompt: a single prompt string or a list of prompts.
        device: device the token ids and returned embeddings are moved to.
        text_input_ids: pre-tokenized ids, used only when ``tokenizer`` is None.
        num_images_per_prompt: how many copies of each prompt's embeddings to emit.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` with shapes
        ``(batch * num_images_per_prompt, seq_len, hidden)`` and
        ``(batch * num_images_per_prompt, pooled_dim)``; copies of the same prompt are adjacent.

    Raises:
        ValueError: if both ``tokenizer`` and ``text_input_ids`` are None.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        # Pad/truncate every prompt to exactly 77 tokens.
        text_inputs = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt",)
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = encoder_output[0]
    # Per-token embeddings come from the penultimate hidden layer, not the last one.
    prompt_embeds = encoder_output.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    batch_size, seq_len, _ = prompt_embeds.shape
    # Duplicate embeddings for each generation per prompt (repeat + view is MPS friendly).
    # The pooled embeddings are duplicated the same way so both outputs share one batch size.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    return prompt_embeds, pooled_prompt_embeds

def _encode_prompt_with_t5(text_encoder, tokenizer, max_sequence_length, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None):
    """Encode prompt(s) with the T5 text encoder.

    Args:
        text_encoder: T5 encoder; the first element of its output is used as the per-token embedding.
        tokenizer: matching T5 tokenizer. May be None when ``text_input_ids`` is given.
        max_sequence_length: every prompt is padded/truncated to this many tokens.
        prompt: a single prompt string or a list of prompts (may be None if ``text_input_ids`` is given).
        num_images_per_prompt: how many copies of each prompt's embeddings to emit.
        device: device the token ids and returned embeddings are moved to.
        text_input_ids: pre-tokenized ids, used only when ``tokenizer`` is None.

    Returns:
        ``prompt_embeds`` of shape ``(batch * num_images_per_prompt, seq_len, hidden)``;
        copies of the same prompt are adjacent.

    Raises:
        ValueError: if both ``tokenizer`` and ``text_input_ids`` are None.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        text_inputs = tokenizer(prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt",)
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    # No attention mask is passed, so padded positions are encoded as well.
    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Batch size comes from the encoded ids, so this also works when `prompt` is None.
    batch_size, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings for each generation per prompt (repeat + view is MPS friendly).
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    return prompt_embeds

def encode_prompt(text_encoders, tokenizers, prompt,  max_sequence_length,  device=None,  num_images_per_prompt: int = 1,
                  text_input_ids_list=None,):
    """Build the joint CLIP + T5 prompt conditioning.

    The first two entries of ``tokenizers`` / ``text_encoders`` are used as CLIP
    encoders and the last entry as the T5 encoder.

    Per-token CLIP embeddings are concatenated along the feature axis and zero-padded
    on the right up to the T5 hidden size. They are then placed in front of the T5
    tokens along the sequence axis. Pooled CLIP embeddings are concatenated along the
    feature axis.

    Args:
        text_encoders: [clip_1, clip_2, t5] text encoders.
        tokenizers: tokenizers in the same order.
        prompt: a prompt string or a list of prompts.
        max_sequence_length: token length used for the T5 encoder.
        device: target device; when None each encoder's own ``.device`` is used.
        num_images_per_prompt: copies of each prompt's embeddings to emit.
        text_input_ids_list: optional pre-tokenized ids, indexed like ``tokenizers``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    def ids_at(position):
        # Pre-tokenized ids for one encoder, or None when no list was supplied.
        return text_input_ids_list[position] if text_input_ids_list else None

    token_parts = []
    pooled_parts = []
    for position, (clip_tokenizer, clip_encoder) in enumerate(zip(tokenizers[:2], text_encoders[:2])):
        token_embeds, pooled_embeds = _encode_prompt_with_clip(
            text_encoder=clip_encoder,
            tokenizer=clip_tokenizer,
            prompt=prompts,
            device=clip_encoder.device if device is None else device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=ids_at(position),
        )
        token_parts.append(token_embeds)
        pooled_parts.append(pooled_embeds)

    clip_embeds = torch.cat(token_parts, dim=-1)
    pooled_embeds_all = torch.cat(pooled_parts, dim=-1)

    t5_encoder = text_encoders[-1]
    t5_embeds = _encode_prompt_with_t5(
        t5_encoder,
        tokenizers[-1],
        max_sequence_length,
        prompt=prompts,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=ids_at(-1),
        device=t5_encoder.device if device is None else device,
    )

    # Right-pad the CLIP feature axis with zeros so it matches the T5 hidden size.
    feature_gap = t5_embeds.shape[-1] - clip_embeds.shape[-1]
    clip_embeds = torch.nn.functional.pad(clip_embeds, (0, feature_gap))
    joint_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)

    return joint_embeds, pooled_embeds_all

#### Compute and save embeddings for each unique prompt

In [7]:
## Generate embeddings for each unique prompt, save them, and record their file paths in train_df.
# Files are named by the prompt's position in `prompts`: emb{idx}.pt / pooled_emb{idx}.pt.
emb_paths = {}
pooled_emb_paths = {}
with torch.no_grad():
    for idx, prompt in enumerate(prompts):
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompt, max_sequence_length
        )
        # Move to CPU and drop the leading batch dimension (one prompt is encoded per call).
        prompt_embeds = torch.squeeze(prompt_embeds.to("cpu"), 0)
        pooled_prompt_embeds = torch.squeeze(pooled_prompt_embeds.to("cpu"), 0)

        emb_path = os.path.join(prompt_embbedings_path, f"emb{idx}.pt")
        pooled_emb_path = os.path.join(prompt_embbedings_path, f"pooled_emb{idx}.pt")
        torch.save(prompt_embeds, emb_path)
        torch.save(pooled_prompt_embeds, pooled_emb_path)

        emb_paths[prompt] = emb_path
        pooled_emb_paths[prompt] = pooled_emb_path

# One vectorized lookup per column instead of a full boolean scan of train_df for every prompt.
train_df['embeddings'] = train_df['description'].map(emb_paths)
train_df['pooled_embeddings'] = train_df['description'].map(pooled_emb_paths)
assert train_df[['embeddings', 'pooled_embeddings']].notna().all().all(), "some rows have no embedding path"
train_df.to_csv(os.path.join(dataset_path, 'dataset1_sd3_emb2.csv'), index=False)

In [8]:
# Final frame: each row carries the paths of its description's saved embedding files
# ('embeddings', 'pooled_embeddings'); the same frame was written to dataset1_sd3_emb2.csv.
train_df

Unnamed: 0,character,im_path,orig_path,description,embeddings,pooled_embeddings
0,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with light blonde hair tied in a...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
1,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red...",genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
2,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A woman with light blonde hair tied in a ponyt...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
3,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red...",genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
4,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with blonde hair tied in a high ...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
...,...,...,...,...,...,...
11135,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A mysterious female character with short blue ...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
11136,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A blue-haired character with large purple eyes...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
11137,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,"A girl with pastel blue hair in soft waves, re...",genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
11138,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A young woman with ethereal blue hair and lumi...,genshin_dataset/dataset1_prompt_embbedings_sd3...,genshin_dataset/dataset1_prompt_embbedings_sd3...
