## 1). Install requirements

In [None]:
!pip install torch==1.12.1 numpy opencv-python diffusers==0.14.0 daam==0.0.14 transformers==4.27.4 gradio==3.32.0

## 2). Run the code

You can either use the app inline in the notebook or open the public share link for a full-screen view.

In [None]:
# %%
import numpy as np
import torch
import cv2
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
from daam import trace, set_seed

# %%
# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained inpainting pipeline, then move it to the chosen device.
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe = pipe.to(device)

# %%
# Mutable module-level state shared by the Gradio callbacks.
# daam_heat_map[0]: global heat map stored by diffusion() when "update attention map" is set
#                   (stays the integer placeholder 0 until then).
# daam_heat_map[1]: last binary word mask produced by get_daam_map().
daam_heat_map = [0, 0]
# prompts[0]: prompt that belongs to the heat map stored in daam_heat_map[0].
prompts = [""]

# %%
def diffusion(img, mask, prompt, neg_prompt, steps, guidance, update_attention_map, text_to_image):
    """Run the inpainting pipeline and optionally cache its DAAM heat map.

    Args:
        img: Value of the input image component; read as a mapping with an
            "image" entry. Ignored when ``text_to_image`` is set.
        mask: Mask array; its first channel is handed to the pipeline as
            ``mask_image``. Ignored when ``text_to_image`` is set.
        prompt: Text prompt passed to the pipeline.
        neg_prompt: Negative prompt passed to the pipeline.
        steps: Value forwarded as ``num_inference_steps``.
        guidance: Value forwarded as ``guidance_scale``.
        update_attention_map: When truthy, store the new global heat map in
            ``daam_heat_map[0]`` and the prompt in ``prompts[0]``.
        text_to_image: When truthy, replace the inputs with an all-ones
            512x512 image and mask.

    Returns:
        A tuple ``(image, prompts[0])``. ``image`` is ``None`` when inpainting
        is requested but the input image or the mask is missing.
    """
    if text_to_image:
        # Placeholder inputs: presumably the all-ones mask marks every pixel
        # for regeneration -- confirm against the pipeline's mask handling.
        mask = np.ones(shape=(512, 512), dtype="uint8")
        image = np.ones(shape=(512, 512, 3), dtype="uint8")
    else:  # inpainting
        # Both inputs are needed; an empty mask component arrives as None and
        # would otherwise fail on the channel indexing below.
        if img is None or mask is None:
            return None, prompts[0]
        mask = mask[..., 0]
        image = img["image"]

    # Diffusion, traced so the attention heat map can be computed.
    with trace(pipe) as tc:
        image = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            negative_prompt=neg_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
        ).images[0]
        # DAAM attention weights
        heat_map = tc.compute_global_heat_map()

    if update_attention_map:
        # Keep the heat map and its prompt together so later word lookups
        # refer to the prompt that produced the map.
        daam_heat_map[0] = heat_map
        prompts[0] = prompt
    return image, prompts[0]

# %%
def update_mask(img, invert, blur, load_attention_map, attention):
    """Build the single-channel mask from the sketch or the attention image.

    Args:
        img: Input image value; its "mask" entry is used when
            ``load_attention_map`` is falsy.
        invert: When truthy, the mask is replaced by ``255 - mask``.
        blur: Gaussian sigma; a falsy value (0) skips blurring.
        load_attention_map: When truthy, take the mask from ``attention``
            instead of ``img``.
        attention: Attention image used when ``load_attention_map`` is truthy.

    Returns:
        The first channel of the selected source after optional inversion and
        blurring, or ``None`` when the selected source is ``None``.
    """
    # Pick the source channel, bailing out early when that source is empty.
    if load_attention_map:
        if attention is None:
            return None
        channel = attention[..., 0]
    else:
        if img is None:
            return None
        channel = img["mask"][..., 0]

    # Shared post-processing for both sources.
    if invert:
        channel = 255 - channel
    if blur:
        channel = cv2.GaussianBlur(channel, (0, 0), blur)
    return channel

# %%
def get_daam_map(word):
    """Build a binary 512x512 attention mask for one word of the stored prompt.

    Args:
        word: Word whose heat map is requested from the stored global heat map.

    Returns:
        A ``uint8`` array of 0/255 values, also stored in ``daam_heat_map[1]``,
        or ``None`` when no global heat map has been stored yet.
    """
    global_heat_map = daam_heat_map[0]
    # The slot holds an integer placeholder until a diffusion run stores a
    # heat map; there is nothing to visualize before that.
    if not hasattr(global_heat_map, "compute_word_heat_map"):
        return None

    word_heat_map = global_heat_map.compute_word_heat_map(word)
    word_segment = word_heat_map.heatmap.cpu().numpy()

    # Normalize to [0, 1]; an all-zero map stays zero instead of becoming NaN.
    peak = word_segment.max()
    if peak > 0:
        word_segment = word_segment / peak
    else:
        word_segment = np.zeros_like(word_segment)

    word_segment = (word_segment * 255.0).astype("uint8")
    word_segment = cv2.resize(word_segment, (512, 512))
    # Binarize at mid-gray so the result can be used directly as a mask.
    word_segment = (word_segment > 127).astype("uint8") * 255
    daam_heat_map[1] = word_segment
    return word_segment

# %%
def load_attention(attention):
    """Return the attention image unchanged; ``None`` stays ``None``."""
    return attention

# %%
# Gradio UI: an "Inpainting" tab (inputs, mask preview, output) and a
# "Daam attention maps" tab (per-word attention mask), followed by the event wiring.
with gr.Blocks() as demo:

    with gr.Tab(label="Inpainting"):
        with gr.Row():
            # Left column: source image (with sketch tool) and generation settings.
            with gr.Column():
                input_image = gr.Image(label="Input image", shape=(512, 512), tool="sketch", source="upload")
                text_to_image = gr.Checkbox(label="Text to image")
                prompt = gr.Textbox(label="Prompt", value="<write what you want>, high resolution, realistic, high quality, high definition, high details, natural colors")
                neg_prompt = gr.Textbox(label="Negative prompt", value="bad details, ugly, low resolution, low quality, unatural colors, bad anatomy, blurry, pixelated, obscure, poor lighting, unclear")
                steps = gr.Slider(label="Inference steps", minimum=20, maximum=300, step=1, value=30)
                guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, step=0.5, value=7.5)
            # Middle column: mask preview and the options fed to update_mask / diffusion.
            with gr.Column():
                mask_image = gr.Image(label="Input mask", shape=(512, 512))
                invert = gr.Checkbox(label="Invert mask")
                blur = gr.Slider(label="Blur strengh", minimum=0, maximum=30, step=1, value=10)
                load_attention_map = gr.Checkbox(label="Load attention map")
                update_attention_map = gr.Checkbox(label="Update attention map")
            # Result of the diffusion run.
            output_image = gr.Image(label="Output image", shape=(512, 512))

        submit = gr.Button(value="Run")

    with gr.Tab(label="Daam attention maps"):
        with gr.Row():
            with gr.Column():
                attention = gr.Image(label="Attention mask", shape=(512, 512))
            with gr.Column():
                # Receives the second value returned by diffusion (the stored prompt).
                label = gr.Label(label="Your prompt")
                word = gr.Textbox(label="Word to visualize")

        get_map = gr.Button(value="Run")

    # Update mask: recompute the mask preview when the image is edited, the
    # invert box is toggled, or the blur slider is released.
    input_image.edit(fn=update_mask, inputs=[input_image, invert, blur, load_attention_map, attention], outputs=[mask_image])
    invert.select(fn=update_mask, inputs=[input_image, invert, blur, load_attention_map, attention], outputs=[mask_image])
    blur.release(fn=update_mask, inputs=[input_image, invert, blur, load_attention_map, attention], outputs=[mask_image])
    # Run diffusion: outputs are the generated image and the stored prompt.
    submit.click(fn=diffusion, inputs=[input_image, mask_image, prompt, neg_prompt, steps, guidance, update_attention_map, text_to_image], outputs=[output_image, label])
    # Attention maps: build the word mask, and copy it into the mask preview on request.
    get_map.click(fn=get_daam_map, inputs=[word], outputs=[attention])
    # NOTE(review): load_attention returns the attention image as-is, so invert/blur
    # are not applied here, unlike the update_mask handlers -- confirm this is intended.
    load_attention_map.select(fn=load_attention, inputs=[attention], outputs=[mask_image])

# share=True also requests a public link in addition to the inline notebook view.
demo.launch(share=True)


You can shut down the server with this command:

In [None]:
# Shut down the server started by demo.launch().
demo.close()