In [None]:
import io
import os
import random
import time

import diffusers
import PIL
import PIL.Image
import requests
import torch

# Whether a CUDA device is available. `inference` reads this flag to choose
# the generator device, the fp16/fp32 weights and dtype, and whether the
# pipeline is moved to the GPU.
# To test performance delta (presumably by overriding this flag by hand to
# compare CPU and GPU runs -- confirm).
use_cuda: bool = torch.cuda.is_available()

def inference(prompt, n_images, guidance, steps, width, height, seed, img, mask, neg_prompt):
    """Inpaint ``img`` under ``mask`` with Stable Diffusion 2 and return the images.

    Supports only 512x512 images: ``img`` and ``mask`` are padded to a square
    and then resized to 512x512 before being passed to the pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        n_images: Forwarded as ``num_images_per_prompt``.
        guidance: Forwarded as ``guidance_scale``.
        steps: Forwarded as ``num_inference_steps``.
        width: Accepted but currently not forwarded to the pipeline.
        height: Accepted but currently not forwarded to the pipeline.
        seed: Seed for the torch generator. A falsy value (``0``/``None``)
            leaves the generator unset.
        img: Source image; must provide ``size`` and ``resize``.
        mask: Mask image, handled exactly like ``img``.
        neg_prompt: Forwarded as ``negative_prompt``.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # Build a seeded generator on the device the pipeline will run on.
    generator = None
    if seed:
        generator = torch.Generator("cuda") if use_cuda else torch.Generator()
        generator.manual_seed(seed)

    # Half precision on the GPU, full precision otherwise.
    # NOTE(review): confirm that the model repository actually publishes an
    # "fp32" revision; this branch is only taken when CUDA is unavailable.
    if use_cuda:
        weights_revision, weights_dtype = "fp16", torch.float16
    else:
        weights_revision, weights_dtype = "fp32", torch.float32
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        revision=weights_revision,
        torch_dtype=weights_dtype,
    )
    scheduler_config = pipe.scheduler.config
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(scheduler_config)
    if use_cuda:
        pipe.to("cuda")
        pipe.enable_attention_slicing()

    # Pad to a square first so the fixed-size resize does not distort the
    # aspect ratio of the content.
    target_size = (512, 512)
    model_image = square_padding(img).resize(target_size)
    model_mask = square_padding(mask).resize(target_size)

    # `width` / `height` are intentionally not passed on (see docstring).
    output = pipe(
        prompt,
        image=model_image,
        mask_image=model_mask,
        num_images_per_prompt=n_images,
        negative_prompt=neg_prompt,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    return output.images

def square_padding(img):
    """Pad an image with black borders so that it becomes square.

    The content is centered on a new RGB canvas whose side equals the longer
    edge of ``img``. An image that is already square is returned as-is (the
    same object, without copying or mode conversion).

    Args:
        img: A PIL image (must expose ``size`` and be accepted by
            ``Image.paste``).

    Returns:
        ``img`` itself when it is square, otherwise a new square RGB image.
    """
    width, height = img.size
    if width == height:
        return img
    side = max(width, height)
    # Use the PIL.Image module explicitly: a bare `Image` name is not in
    # scope here and would raise NameError for every non-square input.
    # The fill colour is a plain RGB triple to match the canvas mode.
    canvas = PIL.Image.new("RGB", (side, side), (0, 0, 0))
    canvas.paste(img, ((side - width) // 2, (side - height) // 2))
    return canvas

def run(prompt, image, mask):
    """Inpaint ``image`` under ``mask`` with fixed settings; return the first result.

    Prints the seed in use and the wall-clock time taken by ``inference``.

    Args:
        prompt: Text prompt describing the content to paint.
        image: Source image handed to ``inference``.
        mask: Mask image handed to ``inference``.

    Returns:
        The first element of the gallery returned by ``inference``.
    """
    # Fixed for reproducibility; use random.randint(0, 2147483647) for a
    # random seed instead.
    seed = 10
    print(f"Seed: {seed}")

    started = time.time()
    gallery = inference(
        prompt=prompt,
        n_images=1,      # number of images
        guidance=7.5,    # max = 15
        steps=25,        # [2, 100]
        width=768,       # [64, 1024] step=8
        height=768,      # [64, 1024] step=8
        seed=seed,
        img=image,
        mask=mask,
        neg_prompt="",
    )
    elapsed = time.time() - started
    print("Took %.1fs" % elapsed)
    return gallery[0]

def load(url, timeout=30):
    """Download an image over HTTP and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled connection would block forever.

    Returns:
        The image opened by ``PIL.Image.open`` from the in-memory response
        body.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures directly instead of feeding an error page to the
    # image decoder, which would fail with an unrelated-looking error.
    response.raise_for_status()
    return PIL.Image.open(io.BytesIO(response.content))

# Example input image and its inpainting mask from the latent-diffusion repo.
img = load("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png")
mask = load("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png")

out = run("a dangerous bench", img, mask)
# Saving does not create missing directories, so make sure "out/" exists
# before writing the result.
os.makedirs("out", exist_ok=True)
out.save("out/bench.png")
# Bare expression: displays the image as the notebook cell's result.
out