[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/latent-consistency-model-colab/blob/main/sdxl_lcm_draw_colab.ipynb)

In [None]:
!pip install -q diffusers transformers accelerate peft gradio==3.50.2
!pip install -q git+https://github.com/tencent-ailab/IP-Adapter einops

!mkdir /content/models
!wget https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.safetensors?download=true -O /content/models/ip-adapter_sdxl.safetensors

!mkdir /content/image_encoder
!wget https://huggingface.co/h94/IP-Adapter/raw/main/sdxl_models/image_encoder/config.json -O /content/image_encoder/config.json
!wget https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors?download=true -O /content/image_encoder/model.safetensors

from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
import torch

# Every checkpoint below is loaded as the half-precision "fp16" variant.
_fp16_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}

# LCM-distilled SDXL UNet, swapped into the stock SDXL base pipeline.
unet = UNet2DConditionModel.from_pretrained("latent-consistency/lcm-sdxl", **_fp16_kwargs)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    **_fp16_kwargs,
)

# The LCM UNet is paired with an LCMScheduler built from the base scheduler's config.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

from ip_adapter import IPAdapterXL

# Wrap the pipeline with IP-Adapter so the sketch image conditions generation;
# the encoder and adapter weights are the files downloaded in the setup lines.
ip_model = IPAdapterXL(
    pipe,
    "/content/image_encoder",
    "/content/models/ip-adapter_sdxl.safetensors",
    "cuda",
)

import gradio as gr

def generate(prompt, input_image, *, num_inference_steps=4, seed=420, output_size=(768, 768)):
    """Render one image from the canvas sketch, steered by the text prompt.

    Args:
        prompt: Text from the prompt box. It is forwarded to IP-Adapter; an
            empty string is sent as ``None``, which matches the previous
            behaviour of not passing a prompt at all.
        input_image: PIL image from the sketch canvas. It may be ``None``
            (e.g. when the canvas is cleared); then nothing is generated.
        num_inference_steps: Sampler steps (kept small for the LCM UNet).
        seed: Fixed seed, so identical inputs give identical output.
        output_size: ``(width, height)`` the result is resized to.

    Returns:
        The generated PIL image resized to ``output_size``, or ``None`` when
        there is no input image.
    """
    # Gradio fires .change with None for an empty canvas; IP-Adapter cannot
    # encode a missing image, so clear the output instead of raising.
    if input_image is None:
        return None

    # NOTE(review): the prompt was previously dropped, so editing the textbox
    # re-rendered the same image. With IP-Adapter's default image scale the
    # text may still have limited influence — consider exposing `scale`.
    image = ip_model.generate(
        pil_image=input_image,
        prompt=prompt or None,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )[0]
    return image.resize(output_size)

# Realtime sketch-to-image UI: a prompt box above a sketch canvas and the output.
# Plain string title: the original f-string had no placeholders.
with gr.Blocks(title="Realtime Latent Consistency Model") as demo:
    # Prompt row: a single textbox pre-filled with a default prompt.
    with gr.Row():
        with gr.Column(scale=23):
            textbox = gr.Textbox(show_label=False, value="a close-up picture of a fluffy cat")

    # Side by side: the colour-sketch canvas (input) and the generated image (output).
    with gr.Row(variant="default"):
        input_image = gr.Image(
            show_label=False,
            type="pil",
            tool="color-sketch",
            source="canvas",
            height=742,
            width=742,
            brush_radius=10.0,
        )
        output_image = gr.Image(
            show_label=False,
            type="pil",
            interactive=False,
            height=742,
            width=742,
            elem_id="output_image",
        )

    # "Realtime": regenerate whenever either the prompt or the sketch changes.
    # Both handlers read the same inputs and write the same output image.
    textbox.change(fn=generate, inputs=[textbox, input_image], outputs=[output_image], show_progress=False)
    input_image.change(fn=generate, inputs=[textbox, input_image], outputs=[output_image], show_progress=False)

# Start the app: not embedded inline in the notebook, with share=True and
# debug=True (see Gradio's Blocks.launch docs for their exact effect in Colab).
demo.launch(
    share=True,
    inline=False,
    debug=True,
)