In [1]:
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
import torch
from PIL import Image
import os
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "stabilityai/stable-diffusion-2-base"

def image_grid(imgs, rows, cols):
    """Tile PIL images into a single ``rows`` x ``cols`` RGB grid image.

    Images are placed in row-major order: ``imgs[0]`` at the top-left, then
    left to right, wrapping to the next row every ``cols`` images.

    Args:
        imgs: Sequence of PIL images. Every cell is sized from ``imgs[0]``,
            so all images are assumed to share that size (not checked here).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)`` where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    assert len(imgs) == rows * cols, (
        f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
    )
    assert len(imgs) > 0, "image_grid needs at least one image"

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Row-major placement: column = i % cols, row = i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

In [2]:
# Load the tokenizer and text encoder on their own (rather than through the
# pipeline) so the learned textual-inversion tokens can be registered on them
# before they are handed to StableDiffusionPipeline further down.
# NOTE(review): newer transformers releases rename `use_auth_token` to `token`.
tokenizer = CLIPTokenizer.from_pretrained(
    model_id, subfolder="tokenizer", use_auth_token=True
)
text_encoder = CLIPTextModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    use_auth_token=True,
)

# Base vocabulary size, read before any custom tokens are added below.
vocab_size = tokenizer.vocab_size

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [3]:
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a textual-inversion embedding as a new token in a CLIP model.

    Loads a ``{token_string: embedding_tensor}`` dict from
    ``learned_embeds_path``. It adds the token to ``tokenizer`` and grows
    ``text_encoder``'s input-embedding table to the new vocabulary size. It then
    writes the loaded embedding into the new token's row.

    Args:
        learned_embeds_path: Path to a ``torch.save``-d dict; only its first
            key/value pair is used.
        text_encoder: Model exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings()`` (e.g. ``CLIPTextModel``).
        tokenizer: Tokenizer supporting ``add_tokens``,
            ``convert_tokens_to_ids`` and ``len()`` (e.g. ``CLIPTokenizer``).
        token: Token string to register; defaults to the key stored in the
            file.

    Returns:
        The token string that was registered.

    Raises:
        ValueError: if the file contains no embeddings, or if ``tokenizer``
            already contains the token.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load embedding files from trusted sources.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(f"No embeddings found in {learned_embeds_path}")

    # The file maps the trained token string to its embedding; use the first entry.
    trained_token = next(iter(loaded_learned_embeds))
    embeds = loaded_learned_embeds[trained_token]

    # Match the embedding table's dtype and device. Tensor.to() is not
    # in-place, so its result must be reassigned.
    weight = text_encoder.get_input_embeddings().weight
    embeds = embeds.to(dtype=weight.dtype, device=weight.device)

    # Add the token to the tokenizer; refuse to overwrite an existing token.
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")

    # Grow the embedding table so the new token id has a row.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Re-fetch the embedding table after resizing and write the new row.
    # NOTE(review): assumes a single-vector embedding ((dim,) or (1, dim));
    # multi-vector embeddings would not fit in one row -- confirm per file.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

    print(trained_token)
    return token


In [4]:
# Textual-inversion embedding files to register with the tokenizer and text
# encoder; each load prints the token it adds.
# Re-running this cell raises ValueError because the tokens are already present
# in `tokenizer` -- restart the kernel (or reload the tokenizer) first.
embed_paths = ["ugly_sonic_sd_2_0.bin", "wrong_sd_2_0.bin"]

for embed in embed_paths:
    load_learned_embed_in_clip(
        embed,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )

<ugly-sonic>
<wrong>


In [5]:
# Build the pipeline with the Euler-ancestral scheduler and the tokenizer /
# text encoder that now carry the learned tokens.
# NOTE(review): the text encoder was loaded above without torch_dtype, so it
# presumably stays float32 while the rest of the pipeline loads as fp16; the
# generation cell relies on torch.autocast to reconcile this -- verify.
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
    model_id, subfolder="scheduler"
)
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    revision="fp16",
    torch_dtype=torch.float16,
    safety_checker=None,  # disables the safety filter; diffusers warns about this
).to("cuda")

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [None]:
# Generation settings. The negative prompt uses the learned <wrong> token
# registered earlier; set it to None to disable.
prompt = "Twitter brand logo combusting into flames, 3d rendering"
negative_prompt = "in the style of <wrong>"
num_samples = 4  # images per pipeline call; may have to set to 2 in Colab to avoid OOM.
num_rows = 4  # number of pipeline calls; each becomes one row of the final grid
seed = None  # set to an int for reproducible samples

# A single generator shared across calls: each row differs, but the whole run
# is reproducible for a given seed. None keeps the pipeline's default sampling.
generator = None if seed is None else torch.Generator("cuda").manual_seed(seed)

all_images = []
for _ in range(num_rows):
    with torch.autocast("cuda"):
        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=50,
            guidance_scale=7.5,
            generator=generator,
        ).images
    all_images.extend(images)
    # Preview this batch as a single row at half resolution.
    image_large = image_grid(images, 1, num_samples)
    display(image_large.resize((image_large.width // 2, image_large.height // 2)))

# num_rows rows of num_samples columns, so each grid row is one pipeline batch.
grid = image_grid(all_images, num_rows, num_samples)
grid.save("grid.png")

In [25]:
# Save every generated image to outputs/<prompt>/01.png, 02.png, ...
# "/" in the prompt is replaced so it is not treated as a path separator, and
# the folder name is truncated to 128 characters.
prompt_folder = prompt.replace("/", "_")[:128]
output_dir = os.path.join("outputs", prompt_folder)
os.makedirs(output_dir, exist_ok=True)  # also creates "outputs" if missing

for i, image in enumerate(all_images, start=1):
    image.save(os.path.join(output_dir, f"{i:02d}.png"))