In [1]:
#install
#@markdown If the code to the left is too large or annoying, double click on this text to hide it<br/>See above cell for model descriptions and settings explanation <br />
# Bootstrap step: if torch and the project's `src.SimpleStable` module import
# cleanly, nothing is installed. On ImportError the dependencies are installed
# with pip, the project repository is cloned, and an output directory is made.
# The `!` and `%` lines are IPython/Colab magics, so this only runs in a notebook.
try:
  import torch
  from src import SimpleStable
except ImportError as e:
  print("Installing required libraries...")
  # stdout of pip is discarded (`> /dev/null`); the extra index URL points at
  # the PyTorch wheel index (cu116 — presumably the CUDA 11.6 builds).
  !pip3 install torch torchvision torchaudio diffusers transformers accelerate scipy pillow tqdm requests huggingface_hub ipywidgets lark --extra-index-url https://download.pytorch.org/whl/cu116 > /dev/null
  %cd /content/
  # NOTE(review): the clone should land in /content/simplest-stable while the
  # working directory stays /content/, yet later code imports `src` and opens
  # "src/gradio.css" relative to the working directory — confirm whether a
  # `%cd simplest-stable` is missing here.
  !git clone https://github.com/cadaeix/simplest-stable.git > /dev/null
  # NOTE(review): outputs_path is only assigned on this install path; it is
  # undefined when the imports above succeed.
  outputs_path = "/content/images/"
  !mkdir -p $outputs_path
  #print(f"Outputs will be saved to {outputs_path}.")

import gradio as gr
from typing import Dict
from src import SimpleStable
import gc
import torch

print("Loading models and files...")

# Stylesheet text handed to gr.Blocks(css=...); kept newline-terminated.
with open("src/gradio.css") as stylesheet:
    css = stylesheet.read() + "\n"

# Module-level pipeline shared by load_model() and generate(); the default
# model is set up eagerly so the UI is usable as soon as it launches.
pipe = SimpleStable.setup_pipe("Stable Diffusion 1.5")

def load_model(loaded_model_name: str, chosen_model_name: str):
    """Replace the global pipeline with the chosen model unless it is already loaded.

    Args:
        loaded_model_name: Name of the model the UI believes is currently loaded.
        chosen_model_name: Name of the model selected by the user.

    Returns:
        A ``(status_message, model_name)`` pair, where ``model_name`` is the
        name of the model that is loaded once the call finishes.
    """
    global pipe

    already_loaded = pipe is not None and loaded_model_name == chosen_model_name
    if already_loaded:
        return "Model already loaded", loaded_model_name

    # Drop the reference to the old pipeline, then run the garbage collector
    # and empty the CUDA cache before the replacement pipeline is built.
    pipe = None
    gc.collect()
    torch.cuda.empty_cache()
    pipe = SimpleStable.setup_pipe(chosen_model_name)

    return f"{chosen_model_name} loaded", chosen_model_name

def is_custom_resolution(resolution: str):
    """Tell whether the resolution dropdown is set to its "Custom" entry."""
    custom_entry = "Custom (Select this and put width and height below)"
    return resolution == custom_entry

def generate(mode, prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, input_image, img2img_strength, inpaint_image, inpaint_strength, model_name):
    """Collect the UI values into an options dict and run SimpleStable.gradio_main.

    Args:
        mode: One of "txt2img", "img2img" or "inpainting".
        prompt: Prompt text.
        negative: Negative prompt text; ``None`` is sent as an empty string.
        number_of_images: Number of images requested.
        resolution: Key of ``SimpleStable.res_dict``, or the "Custom" entry,
            in which case ``custom_width``/``custom_height`` are used.
        custom_width, custom_height: Dimensions used for a custom resolution.
        steps, sampler, seed, scale: Forwarded unchanged.
        additional_options: Collection of selected option labels ("Tiling",
            "SD Upscale", "Don't insert model keyword"); ``None`` is treated
            as no options selected.
        upscale_strength: Forwarded only when "SD Upscale" is selected,
            otherwise ``None`` is sent.
        input_image, img2img_strength: Source image and strength for img2img.
        inpaint_image: Mapping with "image" and "mask" entries for inpainting.
        inpaint_strength: Strength for inpainting.
        model_name: Name of the model, forwarded as "model_name".

    Returns:
        Whatever ``SimpleStable.gradio_main`` returns for the request.

    Raises:
        ValueError: If ``mode`` is not a known mode, or if inpainting is
            requested without an inpaint image.
    """
    # Pick the image inputs for the mode. Unknown modes are rejected here
    # instead of failing later on unbound local variables.
    if mode == "txt2img":
        init_img, mask_image, strength = None, None, None
    elif mode == "img2img":
        init_img, mask_image, strength = input_image, None, img2img_strength
    elif mode == "inpainting":
        if inpaint_image is None:
            raise ValueError("Inpainting requires an uploaded image with a painted mask")
        init_img = inpaint_image["image"]
        mask_image = inpaint_image["mask"]
        strength = inpaint_strength
    else:
        raise ValueError(f"Unknown generation mode: {mode!r}")

    selected_options = additional_options if additional_options is not None else []
    use_upscale = "SD Upscale" in selected_options

    if is_custom_resolution(resolution):
        width, height = custom_width, custom_height
    else:
        width, height = SimpleStable.res_dict[resolution]

    # `pipe` is the module-level pipeline; it is only read here.
    images = SimpleStable.gradio_main({
        "model_name": model_name,
        "prompt": prompt,
        "negative": negative if negative is not None else "",
        "init_img": init_img,
        "mask_image": mask_image,
        "strength": strength,
        "number_of_images": number_of_images,
        # Both dimensions are rounded down to a multiple of 64.
        "H": height - height % 64,
        "W": width - width % 64,
        "steps": steps,
        "sampler": sampler,
        "scale": scale,
        "eta": 0.0,
        "tiling": "Tiling" in selected_options,
        "upscale": use_upscale,
        "upscale_strength": upscale_strength if use_upscale else None,
        "detail_scale": 10,
        "seed": seed,
        "add_keyword": "Don't insert model keyword" not in selected_options,
    }, pipe)

    return images


def generate_options():
    """Build the shared generation-settings widgets inside the current Gradio layout.

    Must be called within an active ``gr.Blocks`` context, since the
    components are created as a side effect of construction.

    Returns:
        A tuple of the created components, in this order: number_of_images,
        resolution, custom_width, custom_height, steps, sampler, seed, scale,
        additional_options, upscale_strength.
    """
    with gr.Row():
        with gr.Column(scale=1):
            number_of_images = gr.Number(value=1, precision=0, label="Number of Images")
        with gr.Column(scale=2):
            resolution = gr.Dropdown(
                choices=list(SimpleStable.res_dict.keys()),
                label="Image Resolution",
                value="Square 512x512 (default, good for most models)",
            )

    with gr.Accordion("Advanced Settings"):
        with gr.Row():
            custom_width = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Width (if Custom is selected)", interactive=True)
            custom_height = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Height (if Custom is selected)", interactive=True)
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Step Count", interactive=True)
            sampler = gr.Dropdown(choices=SimpleStable.sampler_list, label="Sampler", value="Euler a")
        with gr.Row():
            seed = gr.Number(value=-1, precision=0, label="Seed")
            scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.5, label="Guidance Scale", interactive=True)
        with gr.Row():
            additional_options = gr.CheckboxGroup(["Tiling", "SD Upscale", "Don't insert model keyword"], interactive=True)
            # This slider used to be labelled "Guidance Scale" as well, which
            # made it indistinguishable from the real guidance slider above.
            upscale_strength = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.05, label="SD Upscale Strength (if SD Upscale is selected)", interactive=True)

    return number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength

def show_state(input: str):
    """Build the UI updates needed to switch to the mode named by ``input``.

    Args:
        input: "txt2img", "img2img" or "inpainting"; any other value raises
            ``KeyError``.

    Returns:
        A dict keyed by component: the new value for ``current_mode``,
        visibility updates for the img2img and inpainting inputs, and variant
        updates that make the selected mode's button "primary".
    """
    # (txt2img active, img2img active, inpainting active) for each mode.
    flags_by_mode = {
        "txt2img": (True, False, False),
        "img2img": (False, True, False),
        "inpainting": (False, False, True),
    }
    txt2img_on, img2img_on, inpaint_on = flags_by_mode[input]

    def button_variant(active):
        return "primary" if active else "secondary"

    updates = {current_mode: input}
    updates[input_image] = gr.update(visible=img2img_on)
    updates[img2img_strength] = gr.update(visible=img2img_on)
    updates[inpaint_image] = gr.update(visible=inpaint_on)
    updates[inpaint_strength] = gr.update(visible=inpaint_on)
    updates[txt2img_show] = gr.update(variant=button_variant(txt2img_on))
    updates[img2img_show] = gr.update(variant=button_variant(img2img_on))
    updates[inpaint_show] = gr.update(variant=button_variant(inpaint_on))
    return updates


with gr.Blocks(css=css) as main:
    # State components: name of the loaded model and the active generation mode.
    current_loaded_model_name = gr.State("Stable Diffusion 1.5")
    current_mode = gr.State("txt2img")

    # Model picker row with its status line underneath.
    with gr.Row():
        model_name = gr.Dropdown(
            choices=list(SimpleStable.model_dict.keys()),
            value="Stable Diffusion 1.5",
            show_label=False,
        )
        model_submit = gr.Button(value="Load Model", interactive=True)
    loading_status = gr.Markdown("")

    with gr.Row():
        # Left column: mode buttons, prompts, mode-specific inputs, settings.
        with gr.Column(scale=3):
            with gr.Row():
                txt2img_show = gr.Button(value="txt2img", variant="primary")
                img2img_show = gr.Button(value="img2img")
                inpaint_show = gr.Button(value="inpainting")

            prompt = gr.Textbox(placeholder="Describe a prompt here", label="Prompt")
            negative = gr.Textbox(placeholder="Negative prompt", label="Negative")

            # Mode-specific inputs start hidden; show_state toggles them.
            with gr.Blocks():
                input_image = gr.Image(
                    value=None,
                    source="upload",
                    interactive=True,
                    type="pil",
                    visible=False,
                    elem_id="img2img_input",
                )
                img2img_strength = gr.Slider(
                    minimum=0.1,
                    maximum=1,
                    value=0.75,
                    step=0.05,
                    label="img2img strength",
                    interactive=True,
                    visible=False,
                    elem_id="img2img_strength",
                )

                inpaint_image = gr.Image(
                    value=None,
                    source="upload",
                    interactive=True,
                    type="pil",
                    visible=False,
                    tool="sketch",
                    elem_id="inpaint_input",
                )
                inpaint_strength = gr.Slider(
                    minimum=0.1,
                    maximum=1,
                    value=0.75,
                    step=0.05,
                    label="inpaint strength",
                    interactive=True,
                    visible=False,
                    elem_id="inpaint_strength",
                )

            (
                number_of_images,
                resolution,
                custom_width,
                custom_height,
                steps,
                sampler,
                seed,
                scale,
                additional_options,
                upscale_strength,
            ) = generate_options()
        # Right column: generate button and result gallery.
        with gr.Column(scale=2):
            button = gr.Button(value="Generate", variant="primary")
            image_output = gr.Gallery(interactive=False)

    model_submit.click(
        load_model,
        inputs=[current_loaded_model_name, model_name],
        outputs=[loading_status, current_loaded_model_name],
    )

    # Every mode button feeds itself to show_state (presumably delivering its
    # label text, which matches the mode names) and updates the same outputs.
    mode_outputs = [
        current_mode,
        input_image,
        img2img_strength,
        inpaint_image,
        inpaint_strength,
        txt2img_show,
        img2img_show,
        inpaint_show,
    ]
    for mode_button in (txt2img_show, img2img_show, inpaint_show):
        mode_button.click(show_state, inputs=[mode_button], outputs=list(mode_outputs))

    # Input order must match generate()'s positional parameters; the loaded
    # model's name is what arrives as its model_name argument.
    generate_inputs = [
        current_mode,
        prompt,
        negative,
        number_of_images,
        resolution,
        custom_width,
        custom_height,
        steps,
        sampler,
        seed,
        scale,
        additional_options,
        upscale_strength,
        input_image,
        img2img_strength,
        inpaint_image,
        inpaint_strength,
        current_loaded_model_name,
    ]
    button.click(generate, inputs=generate_inputs, outputs=[image_output])


# Turn on the app's request queue before starting it.
main.queue()
# Start serving the UI.
# NOTE(review): per the Gradio docs, debug=True keeps the cell blocking so
# errors are printed in the Colab output — confirm for the installed version.
main.launch(debug=True)