In [None]:
# Load IPython's autoreload extension so edited modules are picked up without
# restarting the kernel.
%load_ext autoreload
# Mode 2 of autoreload: re-import modules automatically before each cell runs.
%autoreload 2

In [None]:
from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline
import torch

# Device the pipeline is moved to.
device = 'cuda'

# Load the SDXL base checkpoint from its fp16 safetensors variant, keep the
# weights in float16, and move the whole pipeline onto `device`.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to(device)

# A single Compel instance wrapping both tokenizer/text-encoder pairs of the
# pipeline. The `requires_pooled` list lines up with the encoder list: only the
# second encoder is asked for a pooled output.
compel = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)


In [None]:
# Upweight "ball": the "++" suffix attached to the word in the prompt is the
# marker for that.
prompt = "a cat playing with a ball++ in the forest"
# Empty negative prompt; it is still encoded so explicit negative embeddings
# are available for the pipeline call.
negative_prompt = ""
# Encode both prompts in one call. Row 0 of each returned tensor corresponds to
# `prompt`, row 1 to `negative_prompt` (that is how the generation cell slices them).
conditioning, pooled = compel([prompt, negative_prompt])
# Quick sanity check of the two tensors' shapes.
print(conditioning.shape, pooled.shape)


In [None]:
# Generate an image from the pre-computed embeddings.
# Row 0 of each tensor is the prompt, row 1 the negative prompt; slicing
# (instead of integer indexing) keeps the leading batch dimension.
image = pipeline(
    prompt_embeds=conditioning[:1],
    pooled_prompt_embeds=pooled[:1],
    negative_prompt_embeds=conditioning[1:2],
    negative_pooled_prompt_embeds=pooled[1:2],
    num_inference_steps=24,
    width=768,
    height=768,
).images[0]
image

## Long prompts

In [None]:

# Rebuild Compel with prompt truncation switched off, so the long negative
# prompt used in this cell is not cut short.
compel = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,
)

# The positive prompt stays short; only the negative prompt below is long.
prompt = "a cat playing with a ball++ in the forest"
negative_prompt = "a long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long negative prompt"
# Encode the short prompt and the long negative prompt in a single call, then
# print the resulting tensor shapes.
conditioning, pooled = compel([prompt, negative_prompt])
print(conditioning.shape, pooled.shape)

# Row 0 -> prompt, row 1 -> negative prompt; slices keep the batch dimension.
image = pipeline(
    prompt_embeds=conditioning[:1],
    pooled_prompt_embeds=pooled[:1],
    negative_prompt_embeds=conditioning[1:2],
    negative_pooled_prompt_embeds=pooled[1:2],
    num_inference_steps=24,
    width=768,
    height=768,
).images[0]
image

## Sequential CPU offload

In [None]:
# Compel is given an explicit device in this cell instead of leaving it unset.
# NOTE(review): presumably required because, once sequential CPU offload is
# active, the text encoders no longer report the GPU as their device - confirm
# against the Compel documentation.
compel = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    device="cuda",
)

# Switch the pipeline to sequential CPU offload before encoding and generating.
pipeline.enable_sequential_cpu_offload()

prompt = "a cat playing with a ball++ in the forest"
negative_prompt = "deformed, ugly"

# Encode both prompts together and print the resulting tensor shapes.
conditioning, pooled = compel([prompt, negative_prompt])
print(conditioning.shape, pooled.shape)

# Row 0 -> prompt, row 1 -> negative prompt; slices keep the batch dimension.
image = pipeline(
    prompt_embeds=conditioning[:1],
    pooled_prompt_embeds=pooled[:1],
    negative_prompt_embeds=conditioning[1:2],
    negative_pooled_prompt_embeds=pooled[1:2],
    num_inference_steps=24,
    width=768,
    height=768,
).images[0]
image

## Different prompts for different encoders

In [None]:
# One Compel instance per text encoder, so each encoder can receive its own prompt.
#
# `device` is passed explicitly, mirroring the sequential-CPU-offload cell:
# offload was enabled on the shared `pipeline` object there and is still active
# when this cell runs in notebook order.
compel1 = Compel(
    tokenizer=pipeline.tokenizer,
    text_encoder=pipeline.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=False,
    device=device,
)

# Only the second encoder is asked for a pooled output, matching the
# `requires_pooled=[False, True]` setting used by the combined instances.
compel2 = Compel(
    tokenizer=pipeline.tokenizer_2,
    text_encoder=pipeline.text_encoder_2,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
    device=device,
)

# The two prompts were previously never defined, which made this cell fail with
# a NameError on a fresh kernel. Edit them freely: `prompt1` goes to the first
# text encoder, `prompt2` to the second.
prompt1 = "a cat playing with a ball++ in the forest"
prompt2 = "a cat playing with a ball++ in the forest, oil painting, thick brush strokes"

conditioning1 = compel1(prompt1)
conditioning2, pooled = compel2(prompt2)
# Join the two encoders' embeddings along the last axis to form the single
# `prompt_embeds` tensor handed to the pipeline.
conditioning = torch.cat((conditioning1, conditioning2), dim=-1)

image = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, num_inference_steps=30).images[0]
# Display the result, as the other generation cells do.
image