In [None]:
import gc
import random
import time

import diffusers
import PIL
import PIL.Image
import torch

# Run on the GPU whenever CUDA is available. The original note says this flag
# exists "to test performance delta" -- presumably you force it to False to
# time the CPU path instead.
use_cuda: bool = torch.cuda.is_available()

def inference(prompt, n_images, guidance, steps, width, height, seed, img, strength, neg_prompt):
    generator = None
    if use_cuda:
        generator = torch.Generator("cuda").manual_seed(seed) if seed else None
    elif seed:      
        generator = torch.Generator()
        generator.manual_seed(seed)
    # There's no fp32 revision.
    # https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main
    # https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/fp16/model_index.json
    pipe = diffusers.StableDiffusionImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        revision="fp16",
        torch_dtype=torch.float16 if use_cuda else torch.float32)
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="scheduler")
    if use_cuda:
        pipe.to("cuda")
        pipe.enable_attention_slicing()
    ratio = min(height / img.height, width / img.width)
    img = img.resize(
        (int(img.width * ratio), int(img.height * ratio)), PIL.Image.Resampling.LANCZOS)
    return pipe(
        prompt,
        num_images_per_prompt=n_images,
        negative_prompt=neg_prompt,
        image=img,
        num_inference_steps=steps,
        strength=strength,
        guidance_scale=guidance,
        # width=width,
        # height=height,
        generator=generator).images

def run(prompt, image, *, n_images=1, neg_prompt="", guidance=7.5, steps=25,
        width=768, height=768, seed=11, strength=0.60):
    """Generate one img2img result for *prompt* from *image*, printing seed and timing.

    All tuning knobs are keyword-only. Their defaults are the values this
    cell previously hard-coded, so ``run(prompt, image)`` behaves as before.

    Args:
        prompt: Text prompt.
        image: Input PIL image.
        n_images: Number of images to generate; only the first is returned.
        neg_prompt: Negative prompt.
        guidance: Classifier-free guidance scale (author's note: max 15).
        steps: Inference steps (author's note: range [2, 100]).
        width, height: Fit box for the input image (author's note:
            [64, 1024], step 8).
        seed: RNG seed. Pass None to draw a random seed, which is printed
            so the run can be reproduced.
        strength: img2img strength (author's note: range [0, 1]).

    Returns:
        The first generated image.
    """
    if seed is None:
        seed = random.randint(0, 2147483647)
    print(f"Seed: {seed}")
    # perf_counter is monotonic; time.time() can jump with wall-clock changes.
    start = time.perf_counter()
    gallery = inference(prompt, n_images, guidance, steps, width, height, seed, image, strength, neg_prompt)
    print(f"Took {time.perf_counter() - start:.1f}s")
    return gallery[0]

def getimg():
    name = "squicat3.png"
    name = "PXL_20221117_233124192.PORTRAIT.jpg"
    img = PIL.Image.open("out/" + name)
    size = img.size
    # Max is 1024x768 or 768x1024?
    while size[0] > 1024 or size[1] > 1024: # or size[0] * size[1] > 786432:
        size = (size[0]//2, size[1]//2)
    if size != img.size:
        print("Resized from", img.size, "to", size)
        img = img.resize(size, PIL.Image.Resampling.LANCZOS)
    return img

# Earlier prompt, kept for reference: "a killer robot"
p = "comic book, marvel, superflat, dc comics, graphic novel"
img = run(p, getimg())
img.save("out/comic_book.png")
# Full garbage-collection pass to reclaim memory left over from generation.
gc.collect()
img  # Last expression of the cell: the notebook displays the image.