In [None]:
# https://colab.research.google.com/github/minimaxir/sdxl-experiments/blob/main/sdxl_image_generation.ipynb#scrollTo=oV4TMRcqMskx

import time

import compel
import diffusers
import torch
from IPython.display import display

# Backend selection, in order of preference: CUDA, then MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# float16 on either accelerator backend, float32 when running on the CPU.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
print(f"Using {DEVICE}")

def get_generator(seed):
    """Return a ``torch.Generator`` seeded with ``seed``.

    The generator is created on ``DEVICE`` when that is "cuda" or "mps";
    otherwise a default-constructed generator is used.
    """
    on_accelerator = DEVICE in ("cuda", "mps")
    generator = torch.Generator(DEVICE) if on_accelerator else torch.Generator()
    # manual_seed() hands back the generator itself, so it can be returned directly.
    return generator.manual_seed(seed)

# Hugging Face repo IDs used to assemble the base pipeline.
VAE_REPO = "madebyollin/sdxl-vae-fp16-fix"
BASE_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
WRONG_LORA_REPO = "minimaxir/sdxl-wrong-lora"

# Standalone VAE checkpoint; it is handed to the base pipeline below via `vae=`.
vae = diffusers.AutoencoderKL.from_pretrained(VAE_REPO, torch_dtype=DTYPE)

# Base pipeline, loaded from the "fp16" safetensors variant with the VAE above.
base = diffusers.DiffusionPipeline.from_pretrained(
    BASE_REPO,
    torch_dtype=DTYPE,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
base.load_lora_weights(WRONG_LORA_REPO)
# Disabled alternative: LCM-LoRA weights together with the LCM scheduler.
# base.load_lora_weights("latent-consistency/lcm-lora-sdxl")
# base.scheduler = diffusers.LCMScheduler.from_config(base.scheduler.config)
base.to(DEVICE)

# The refiner reuses the base pipeline's second text encoder and VAE objects
# instead of being given its own copies.
_shared_components = {
    "text_encoder_2": base.text_encoder_2,
    "vae": base.vae,
}
refiner = diffusers.DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=DTYPE,
    variant="fp16",
    use_safetensors=True,
    **_shared_components,
)
refiner.to(DEVICE)

# Both prompt encoders request the same kind of embeddings from Compel.
_EMBEDDINGS_TYPE = compel.ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED

# Base pipeline: two tokenizer/text-encoder pairs; only the second one is asked
# for a pooled output.
compel_base = compel.Compel(
    tokenizer=[base.tokenizer, base.tokenizer_2],
    text_encoder=[base.text_encoder, base.text_encoder_2],
    returned_embeddings_type=_EMBEDDINGS_TYPE,
    requires_pooled=[False, True],
)

# Refiner pipeline: only its second tokenizer/text encoder is used, with pooling.
compel_refiner = compel.Compel(
    tokenizer=refiner.tokenizer_2,
    text_encoder=refiner.text_encoder_2,
    returned_embeddings_type=_EMBEDDINGS_TYPE,
    requires_pooled=True,
)

def _embed_prompt_pair(encoder, prompt, negative_prompt):
    """Encode ``prompt`` and, if given, ``negative_prompt`` with a Compel instance.

    Returns ``(conditioning, pooled, conditioning_neg, pooled_neg)``; the two
    negative entries are ``None`` when ``negative_prompt`` is ``None``.
    """
    conditioning, pooled = encoder(prompt)
    if negative_prompt is None:
        return conditioning, pooled, None, None
    conditioning_neg, pooled_neg = encoder(negative_prompt)
    return conditioning, pooled, conditioning_neg, pooled_neg


def gen_image(prompt, seed, cfg=13, high_noise_frac=0.8, negative_prompt="wrong"):
    """Generate one image by running the base pipeline and then the refiner.

    The base pipeline is run up to ``high_noise_frac`` and returns latents; the
    refiner picks up from the same fraction and produces the final output.

    Args:
        prompt: Text prompt, encoded with Compel for each pipeline.
        seed: Seed for the random generator; a fresh generator with this seed
            is created for each of the two stages.
        cfg: Value passed as ``guidance_scale`` to both pipelines.
        high_noise_frac: Passed as ``denoising_end`` to the base pipeline and
            as ``denoising_start`` to the refiner.
        negative_prompt: Negative text prompt. Defaults to ``"wrong"`` (the
            value this function previously hard-coded). Pass ``None`` to send
            no negative embeddings to the pipelines.

    Returns:
        The first entry of the refiner's ``.images`` output (presumably a PIL
        image, assuming the pipeline's default output type -- TODO confirm).
    """
    conditioning, pooled, conditioning_neg, pooled_neg = _embed_prompt_pair(
        compel_base, prompt, negative_prompt)
    latents = base(
        #num_inference_steps=4,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=conditioning_neg,
        negative_pooled_prompt_embeds=pooled_neg,
        guidance_scale=cfg,
        generator=get_generator(seed),
        denoising_end=high_noise_frac,
        # Keep the result as latents so the refiner can continue from them.
        output_type="latent",
        # NOTE(review): in diffusers this "scale" is the LoRA strength -- confirm
        # against the installed version.
        cross_attention_kwargs={"scale": 1.}).images

    # The refiner has its own Compel instance, so the prompts are encoded again.
    conditioning, pooled, conditioning_neg, pooled_neg = _embed_prompt_pair(
        compel_refiner, prompt, negative_prompt)
    return refiner(
        #num_inference_steps=4,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=conditioning_neg,
        negative_pooled_prompt_embeds=pooled_neg,
        guidance_scale=cfg,
        generator=get_generator(seed),
        denoising_start=high_noise_frac,
        image=latents).images[0]

In [None]:
prompt = "fireplace, warm cozy book shelf candle snow mountain window sofa rustic"
seed = 30

# Disabled single-pipeline experiment, kept for reference. It relies on a `pipe`
# object that is not defined in the cells shown here.
# start = time.time()
# img = pipe(prompt=prompt, num_inference_steps=4,
#     guidance_scale=1, generator=get_generator(seed)).images[0]
# print(f"Inference in {time.time()-start:.1f}s")
# display(img)

# Generate and show images one at a time; each iteration offsets the seed by i.
for i in range(1):
    # perf_counter() is a monotonic, high-resolution clock, so the measured
    # duration cannot be skewed by system clock adjustments (unlike time.time()).
    start = time.perf_counter()
    image = gen_image(prompt, seed+i)
    print(f"{i} Inference in {time.perf_counter()-start:.1f}s")
    display(image)
    #image.save("out/img.png")