# stablediffusion-infinity

https://github.com/lkwq007/stablediffusion-infinity

Outpainting with Stable Diffusion on an infinite canvas

## import libs

In [None]:
# Copy the PyPatchMatch C sources (recursively), Makefile, travis.sh and Python wrapper
# into the current directory. `-n` leaves any file that already exists here untouched.
# Presumably patch_match.py expects csrc/ and the Makefile next to it — TODO confirm.
!cp -nrf PyPatchMatch/csrc .
!cp -nf PyPatchMatch/Makefile .
!cp -nf PyPatchMatch/travis.sh .
!cp -nf PyPatchMatch/patch_match.py .

In [None]:
import time

from ipycanvas import Canvas, hold_canvas, MultiCanvas
import ipywidgets as widgets

from utils import *
from canvas import InfCanvas

In [None]:
import diffusers
class MySDSC(diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker):
    """Safety checker subclass whose ``forward`` never reports a flagged image.

    The parent check still runs; only the per-image flags it returns are
    replaced.
    """

    def forward(self, clip_input, images):
        """Run the parent check and return its images with every flag set to False."""
        checked_images, flags = super().forward(clip_input, images)
        cleared_flags = [False for _ in flags]
        return checked_images, cleared_flags


# Rebind the library's class name to the subclass defined above.
diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker = MySDSC

In [None]:
import numpy as np
import torch
from torch import autocast
from importlib import reload
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline

## Set up StableDiffusionInpaintPipeline

In [None]:
# Text-to-image pipeline: half-precision ("fp16" revision) weights, moved to the GPU.
text2img = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")

# The inpainting pipeline is assembled from the very same component objects
# that text2img holds, rather than from a second from_pretrained() call.
inpaint = StableDiffusionInpaintPipeline(
    **{
        component: getattr(text2img, component)
        for component in (
            "vae",
            "text_encoder",
            "tokenizer",
            "unet",
            "scheduler",
            "safety_checker",
            "feature_extractor",
        )
    }
).to("cuda")

## Set up UI controls

In [None]:
# Value passed as `strength` to the inpainting pipeline (0.00 to 1.00).
strength_slider = widgets.FloatSlider(
    description="Strength:",
    value=0.75,
    min=0,
    max=1.0,
    step=0.01,
    readout=True,
    readout_format=".2f",
    orientation="horizontal",
    continuous_update=False,
    disabled=False,
)

# Number of inference steps, bounded to 1..1000.
step_input = widgets.BoundedIntText(
    description="Steps:",
    value=50,
    min=1,
    max=1000,
    step=1,
    layout=widgets.Layout(width="180px"),
    disabled=False,
)

# Guidance scale, entered as a free-form float.
guidance_input = widgets.FloatText(
    description="Guidance:",
    value=7.5,
    step=0.1,
    layout=widgets.Layout(width="180px"),
    disabled=False,
)

# When ticked, the selection is processed at 512x512 instead of its own size.
resize_check = widgets.Checkbox(
    description="Resize SD input to 512x512",
    value=True,
    indent=False,
    layout=widgets.Layout(width="180px"),
    disabled=False,
)


## Set up InfCanvas

In [None]:
# Create the shared canvas object used by every handler below.
# Presumably the positional arguments are the canvas width and height in pixels and
# selection_size is the side of the square selection box — TODO confirm against InfCanvas.
base = InfCanvas(1024, 1024, selection_size=256)
# Build the canvas' own widgets and hook up its mouse handling before any
# callbacks are registered on base.run_button / base.export_button.
base.setup_widgets()
base.setup_mouse()


def run_outpaint(btn):
    """Generate content for the current selection and stage it on the canvas.

    The last channel of the selection buffer is used as a mask. If it has any
    non-zero pixel, the selection is pre-filled with the method chosen in
    ``base.fill_button`` and sent to the inpainting pipeline together with the
    inverted mask. Otherwise the text-to-image pipeline generates the whole
    selection from the prompt alone.

    Both pipelines receive the step count and guidance scale from the UI
    controls, so the "Steps" and "Guidance" inputs apply to either path.

    Args:
        btn: The clicked button (unused; required by the ``on_click`` callback
            signature).
    """
    with base.output:
        base.output.clear_output()
        base.read_selection_from_buffer()
        img = base.sel_buffer[:, :, 0:3]
        mask = base.sel_buffer[:, :, -1]
        process_size = 512 if resize_check.value else base.selection_size
        # Sampler settings shared by both pipelines; previously only the
        # inpainting call honoured these controls.
        sampler_kwargs = {
            "num_inference_steps": step_input.value,
            "guidance_scale": guidance_input.value,
        }
        if mask.sum() > 0:
            img, mask = functbl[base.fill_button.value](img, mask)
            init_image = Image.fromarray(img)
            # Invert the mask before handing it to the pipeline.
            mask = 255 - mask
            # mask=skimage.measure.block_reduce(mask,(8,8),np.max)
            # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
            mask_image = Image.fromarray(mask)
            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
            with autocast("cuda"):
                images = inpaint(
                    prompt=base.text_input.value,
                    init_image=init_image.resize(
                        (process_size, process_size), resample=Image.Resampling.LANCZOS
                    ),
                    mask_image=mask_image.resize((process_size, process_size)),
                    strength=strength_slider.value,
                    **sampler_kwargs,
                )["sample"]
        else:
            with autocast("cuda"):
                images = text2img(
                    prompt=base.text_input.value,
                    height=process_size,
                    width=process_size,
                    **sampler_kwargs,
                )["sample"]
        # Scale the first result back to the selection size and mark every
        # pixel of the selection as fully opaque.
        out = base.sel_buffer.copy()
        out[:, :, 0:3] = np.array(
            images[0].resize(
                (base.selection_size, base.selection_size),
                resample=Image.Resampling.LANCZOS,
            )
        )
        out[:, :, -1] = 255
        base.fill_selection(out)
        with hold_canvas():
            base.draw_selection_box()


def export_button_clicked(btn):
    """Save the exported canvas to a timestamped PNG and show it in the output area.

    Args:
        btn: The clicked button (unused).
    """
    with base.output:
        base.output.clear_output()
        exported = Image.fromarray(base.export())
        stamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"outpaint_{stamp}.png"
        exported.save(filename)
        print(f"Canvas saved to {filename}")
        display(exported)


# Wire the canvas' own run and export buttons to the handlers defined above.
base.run_button.on_click(run_outpaint)
base.export_button.on_click(export_button_clicked)

def undo_button_clicked(btn):
    """Drop uncommitted selection content and restore the backed-up selection buffer.

    Does nothing when the selection is not marked dirty.

    Args:
        btn: The clicked button (unused).
    """
    with base.output:
        if not base.sel_dirty:
            return
        base.canvas[2].clear()
        base.sel_buffer = base.sel_buffer_bak.copy()
        base.sel_dirty = False
        # base.draw_selection_box()

def commit_button_clicked(btn):
    """Write the selection to the buffer when it is marked dirty.

    Args:
        btn: The clicked button (unused).
    """
    if not base.sel_dirty:
        return
    base.write_selection_to_buffer()

# Icon-only, 60px-wide buttons acting on the current selection.
retry_button = widgets.Button(
    icon="refresh",
    tooltip="Retry",
    description="",
    layout=widgets.Layout(width="60px"),
    disabled=False,
)
undo_button = widgets.Button(
    icon="undo",
    tooltip="Undo",
    description="",
    layout=widgets.Layout(width="60px"),
    disabled=False,
)
commit_button = widgets.Button(
    icon="check",
    tooltip="Commit",
    description="",
    layout=widgets.Layout(width="60px"),
    disabled=False,
)
control_label = widgets.Label("Commit/Retry/Undo")

# Retry reuses the same handler as the canvas' run button.
retry_button.on_click(run_outpaint)
undo_button.on_click(undo_button_clicked)
commit_button.on_click(commit_button_clicked)

## upload an image?

In [None]:
import io
from PIL import ImageOps
# Controls for optionally seeding the canvas with an uploaded image.
uploader_label = widgets.Label(
    "[Optional] Upload an image? (will be resized to fit into canvas)"
)
uploader = widgets.FileUpload(multiple=False, accept="image/*")
# Messages printed by the upload handler are captured here.
uploader_output = widgets.Output()
upload_button = widgets.Button(description="Confirm")
# Scale factor applied to the uploaded image before it is placed on the canvas.
resize_input = widgets.FloatText(description="Resize", value=0.5)
def start_func(btn):
    """Place the most recently uploaded image at the centre of the canvas buffer.

    The image is scaled by the factor in ``resize_input`` (each side kept at
    one pixel or more), then shrunk with ``ImageOps.contain`` if it exceeds
    the canvas size minus 100 pixels in either dimension. The canvas buffer is
    cleared and the image's RGB and alpha channels are written into it.

    Prints a message instead of raising when nothing was uploaded or the
    resize factor is not positive.

    Args:
        btn: The clicked button (unused).
    """
    with uploader_output:
        uploader_output.clear_output()
        uploaded = uploader.value
        if not uploaded:
            print("No image uploaded")
            return
        if isinstance(uploaded, dict):
            # ipywidgets 7: mapping of filename -> {"content": bytes, ...}.
            name = list(uploaded.keys())[-1]
            content = uploaded[name]["content"]
        else:
            # ipywidgets 8: tuple of {"name": ..., "content": memoryview, ...}.
            latest = uploaded[-1]
            name = latest["name"]
            content = latest["content"]
        scale = resize_input.value
        # `not scale > 0` also rejects NaN, which compares false either way.
        if not scale > 0:
            print("Resize factor must be greater than 0")
            return
        pil = Image.open(io.BytesIO(content))
        print(f"Will use {name} as the base image for outpainting")
        # Clamp to one pixel so a tiny factor cannot request a zero-sized image.
        scaled_size = (
            max(1, int(scale * pil.size[0])),
            max(1, int(scale * pil.size[1])),
        )
        pil = pil.resize(scaled_size)  # XXX
        max_w = base.width - 100
        max_h = base.height - 100
        if pil.size[0] > max_w or pil.size[1] > max_h:
            pil = ImageOps.contain(pil, (max_w, max_h))
        base.buffer_dirty = True
        w, h = pil.size
        arr = np.array(pil.convert("RGBA"))
        # Top-left corner that centres the image on the canvas.
        yo = (base.height - h) // 2
        xo = (base.width - w) // 2
        base.buffer *= 0
        base.buffer[yo:yo + h, xo:xo + w, 0:3] = arr[:, :, 0:3]
        base.buffer[yo:yo + h, xo:xo + w, -1] = arr[:, :, -1]
        base.draw_buffer()

# Clicking "Confirm" loads the uploaded image onto the canvas.
upload_button.on_click(start_func)

## have fun here

In [None]:
# disable upload button
# upload_button.disabled=False
# Show the optional upload controls first: the label, then a row with the
# uploader, confirm button and resize factor, then the upload output area.
lst=[uploader_label,widgets.HBox([uploader,upload_button, resize_input]),uploader_output]
for item in lst:
    display(item)
# base.display() is treated as a list of displayable items here; the two control
# rows are inserted just before its last element, in this order.
display_lst=base.display()
display_lst.insert(-1,widgets.HBox([control_label,commit_button,retry_button,undo_button]))
display_lst.insert(-1,widgets.HBox([resize_check,step_input,guidance_input,strength_slider]))
for item in display_lst:
    display(item)