In [None]:
#install
#@markdown If the code to the left is too large or annoying, double click on this text to hide it\
#@markdown Click on the triangle play button to the left to start setting up Simple Stable\
#@markdown The interface will pop up below\
#@markdown You can click on the link after "Running on public URL" to open up the interface in a new tab or window
#@markdown 
#@markdown The time estimates are all inaccurate\
#@markdown Loading a new model for the first time in a session will take a few minutes to download, but then switching between downloaded models will be faster
#@markdown

import os

# Import the heavy dependencies. If that fails (fresh runtime), install them,
# clone the project into /content/ and continue with the cell below.
# NOTE: `from src import SimpleStable` resolves relative to the current working
# directory, and this runs before the `%cd` further down. So this branch also
# fires whenever the notebook is not already inside /content/simplest-stable/.
try:
  import torch
  from src import SimpleStable
except ImportError as e:  # `e` is not used
  print("Installing required libraries...")
  # PyTorch wheels come from the CUDA 11.6 extra index; pip output is discarded.
  !pip3 install torch torchvision torchaudio diffusers transformers accelerate scipy pillow tqdm requests huggingface_hub ipywidgets lark gradio omegaconf --extra-index-url https://download.pytorch.org/whl/cu116 > /dev/null
  if os.name == "nt": #if windows
    # Prebuilt xformers wheel tagged cp310 / win_amd64.
    # NOTE(review): the `%cd /content/` lines below are Colab-style paths and
    # presumably do not work on Windows — confirm this branch is ever reached.
    !pip install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
  else:
    #!pip install triton==2.0.0.dev20221202 xformers==0.0.16rc424
    # NOTE(review): this wheel is tagged cp38 / linux_x86_64, so pip will refuse
    # it on a runtime whose Python is not 3.8 — confirm against the target image.
    !pip install https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl
  # Clone the project next to the default Colab working dir and enter it.
  # git's stdout is discarded; its stderr (e.g. "already exists") is not.
  %cd /content/
  !git clone https://github.com/cadaeix/simplest-stable.git > /dev/null
  %cd /content/simplest-stable/
  
  
%cd /content/simplest-stable/
from datetime import datetime
from IPython.display import display, clear_output
import random
import gradio as gr
import gradio.routes
from src import SimpleStable
import gc
import torch

class LoadJavaScript():
    """Inject the contents of ``src/gradio.js`` into every page Gradio renders.

    Gradio builds its HTML pages via ``gradio.routes.templates.TemplateResponse``.
    This class replaces that callable with :meth:`template_response`, which
    calls the original and splices a ``<script>`` tag in front of ``</head>``.
    """

    # Attribute name used to remember the *un-patched* TemplateResponse on the
    # templates object itself. Re-running the notebook cell builds a new
    # LoadJavaScript while the gradio module stays imported. Without this, the
    # new instance would wrap the previous wrapper and inject the script twice.
    _ORIGINAL_ATTR = "_simple_stable_original_template_response"

    def __init__(self):
        templates = gradio.routes.templates

        original = getattr(templates, self._ORIGINAL_ATTR, None)
        if original is None:
            original = templates.TemplateResponse
            setattr(templates, self._ORIGINAL_ATTR, original)
        self.original_template = original

        with open("src/gradio.js", "r", encoding="utf8") as jsfile:
            self.javascript = f'<script>{jsfile.read()}</script>'

        templates.TemplateResponse = self.template_response

    def template_response(self, *args, **kwargs):
        """Render via the original template, then insert the script before ``</head>``.

        The response body is bytes, so the replacement is done on bytes.
        ``init_headers()`` is called afterwards so the headers are rebuilt for
        the modified body.
        """
        response = self.original_template(*args, **kwargs)
        response.body = response.body.replace(
            b"</head>", f"{self.javascript}\n</head>".encode("utf-8")
        )
        response.init_headers()
        return response

def return_selected_image_from_gallery(i):
    """Return the ``"name"`` entry of the selected gallery item.

    Returns None when nothing is selected (any falsy value).
    """
    if not i:
        return None
    return i["name"]

outputs_path = "images/"
# makedirs(exist_ok=True) replaces the exists()/mkdir() pairs: it is a single
# call, it tolerates the folder already existing, and it creates any missing
# parents.
os.makedirs(outputs_path, exist_ok=True)

# One sub-folder per calendar day. The name is computed once when the cell
# runs, so a session that runs past midnight keeps using the first day's folder.
session_folder = os.path.join(outputs_path, datetime.now().strftime("%Y_%m_%d"))
os.makedirs(session_folder, exist_ok=True)

print("Loading models and files...")

# Custom stylesheet handed to gr.Blocks(css=...). Read with an explicit
# encoding, matching how src/gradio.js is read, instead of the platform default.
with open("src/gradio.css", encoding="utf8") as file:
    css = file.read() + "\n"

# Load the default model up front so the first Generate click does not wait on it.
pipe = SimpleStable.setup_pipe("Stable Diffusion 1.5")

clear_output(wait=False)

def load_model(loaded_model_name: str, chosen_model_name: str, progress=gr.Progress(track_tqdm=True)):
    """Switch the global ``pipe`` to ``chosen_model_name`` unless it is already loaded.

    Returns a (status message, name of the model now loaded) pair. The
    ``gr.Progress(track_tqdm=True)`` default is how Gradio is asked to surface
    tqdm progress during the load; the parameter is not used in the body.
    """
    global pipe

    if pipe is not None and loaded_model_name == chosen_model_name:
        return "Model already loaded", loaded_model_name

    # Drop the old pipeline and reclaim its memory (host and CUDA cache)
    # before building the new one, so the two never coexist.
    pipe = None
    gc.collect()
    torch.cuda.empty_cache()
    pipe = SimpleStable.setup_pipe(chosen_model_name)

    return f"{chosen_model_name} loaded", chosen_model_name

def is_custom_resolution(resolution: str):
    """Tell whether the resolution dropdown is set to its "Custom" entry."""
    custom_label = "Custom (Select this and put width and height below)"
    return resolution == custom_label

def generate(mode, prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, input_image, img2img_strength, inpaint_image, inpaint_strength, model_name, progress=gr.Progress(track_tqdm=True)):
    """Run one generation request from the UI and return the results.

    Arguments arrive positionally from the Generate button's ``inputs`` list.
    ``mode`` is "txt2img", "img2img" or "inpainting" and selects which image
    inputs are forwarded. A negative (or empty) ``seed`` requests a random
    one. Generation is delegated to ``SimpleStable.gradio_main`` with the
    module-level ``pipe``, and ``session_folder`` is passed as the output folder.

    Returns:
        (images, used_seed, marker): what gradio_main returned, the seed
        actually used, and a "<prompt> + <seed>" string for a hidden component.

    Raises:
        ValueError: for an unknown ``mode``, or inpainting with no image uploaded.
    """
    if mode == "txt2img":
        init_img, mask_image, strength = None, None, None
    elif mode == "img2img":
        # NOTE(review): input_image is None when nothing was uploaded; how
        # gradio_main handles that is not visible here.
        init_img, mask_image, strength = input_image, None, img2img_strength
    elif mode == "inpainting":
        # Sketch-tool image: a {"image": ..., "mask": ...} dict, or None when
        # empty. Fail with a readable message rather than a subscript TypeError.
        if inpaint_image is None:
            raise ValueError("Upload an image to inpaint first.")
        init_img = inpaint_image["image"]
        mask_image = inpaint_image["mask"]
        strength = inpaint_strength
    else:
        # Without this, an unexpected mode fails later as UnboundLocalError.
        raise ValueError(f"Unknown generation mode: {mode!r}")

    negative = negative if negative is not None else ""
    # random.randint includes its upper bound, so use 2**32 - 1 to keep random
    # seeds inside the unsigned 32-bit range. An empty seed box (None) is also
    # treated as "random".
    used_seed = random.randint(0, 2**32 - 1) if seed is None or seed < 0 else seed
    width, height = [custom_width, custom_height] if is_custom_resolution(resolution) else SimpleStable.res_dict[resolution]

    if "Insert standard Danbooru model quality prompt" in additional_options:
        prompt = "masterpiece, best quality, " + prompt
        standard_negative = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username"
        # Assign to `negative` itself (it was previously a misspelled name, so
        # this was silently dropped). `negative` is already a str here, so test
        # for emptiness to avoid a dangling ", ".
        negative = standard_negative + ", " + negative if negative else standard_negative

    images = SimpleStable.gradio_main({
        "model_name": model_name,
        "prompt": prompt,
        "negative": negative,
        "init_img": init_img,
        "mask_image": mask_image,
        "strength": strength,
        "number_of_images": number_of_images,
        # Round dimensions down to a multiple of 64.
        "H" : height - height % 64,
        "W" : width - width % 64,
        "steps": steps,
        "sampler": sampler,
        "scale": scale,
        "eta" : 0.0,
        "tiling" : "Tiling" in additional_options,
        "upscale": "SD Upscale" in additional_options,
        "upscale_strength": upscale_strength if "SD Upscale" in additional_options else None,
        "detail_scale" : 10,
        "seed": used_seed,
        "add_keyword": "Don't insert model keyword" not in additional_options,
        "outputs_folder": session_folder
    }, pipe)

    # Use the seed actually used, so the marker differs between random-seed
    # runs of the same prompt. With `seed` it would stay "<prompt> + -1".
    return images, used_seed, f"{prompt} + {used_seed}"

def generate_options():
    """Build the generation-settings controls inside the current Gradio layout.

    Creates a number-of-images box and a resolution dropdown, then an
    "Advanced Settings" accordion. The accordion holds the custom width/height,
    sampler, seed, step count, guidance scale, extra-option checkboxes and
    upscale strength. Component creation order defines the on-page layout,
    so these statements should not be reordered.

    Returns:
        A 12-tuple of components, unpacked positionally by the caller:
        number_of_images, resolution, custom_width, custom_height, steps,
        sampler, seed, scale, additional_options, upscale_strength,
        reuse_seed_button, random_seed_button.
    """
    with gr.Row():
        with gr.Column(scale=1):
            number_of_images = gr.Number(value=1, precision=0, label="Number of Images")
        with gr.Column(scale=2):
            resolution = gr.Dropdown(choices = list(SimpleStable.res_dict.keys()), label="Image Resolution", value="Square 512x512 (default, good for most models)")

    with gr.Accordion(label="Advanced Settings", elem_id="adv_settings"):
        with gr.Row():
            # Left column: width, sampler, seed controls, extra options.
            with gr.Column():
                # Only used when the resolution dropdown is on its "Custom" entry.
                custom_width = gr.Slider(minimum = 512, maximum = 1152, value= 512, step = 64, label="Width (if Custom is selected)", interactive = True, elem_id="custom_width")
                sampler = gr.Dropdown(choices = SimpleStable.sampler_list, label="Sampler", value="Euler a", elem_id="sampler_choice")
                with gr.Row():
                    with gr.Column(elem_id="seed_col"):
                        # -1 asks for a fresh random seed at generation time.
                        seed = gr.Number(value=-1, precision=0, label="Seed", interactive=True)
                    with gr.Column(elem_id="seed_button_col"):
                        reuse_seed_button = gr.Button(value="Last Seed", elem_id="reuse_seed")
                        random_seed_button = gr.Button(value="Random Seed", elem_id="random_seed")
                # These labels are matched by string elsewhere; keep them in sync.
                additional_options = gr.CheckboxGroup(["Tiling", "SD Upscale", "Don't insert model keyword", "Insert standard Danbooru model quality prompt"], interactive=True, label="Additional Settings")
            # Right column: height, step count, guidance scale, upscale strength.
            with gr.Column():
                custom_height = gr.Slider(minimum = 512, maximum = 1152, value= 512, step = 64, label="Height (if Custom is selected)", interactive = True, elem_id="custom_height")
                steps = gr.Slider(minimum = 1, maximum = 100, value= 20, step = 1, label="Step Count", interactive = True, elem_id="step_count")
                scale = gr.Slider(minimum = 1, maximum = 20, value= 7, step = 0.5, label="Guidance Scale", interactive = True, elem_id="guidance_scale")
                upscale_strength = gr.Slider(minimum = 0.1, maximum = 1, value=0.2, step = 0.05, label="Upscale Strength", interactive = True, elem_id="upscale_strength")
    
    return number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, reuse_seed_button, random_seed_button

def has_image(input):
    """Toggle the "img2img/inpaint Selected Image" buttons based on the gallery value.

    Both buttons are shown when ``input`` (the gallery's value) is not None,
    and hidden otherwise. Returns a {component: update} mapping for Gradio.
    """
    buttons_visible = input is not None
    return {
        edit_button: gr.update(visible=buttons_visible)
        for edit_button in (to_img2img_button, to_inpaint_button)
    }

# Which of (txt2img, img2img, inpainting) is active for each mode name.
_MODE_FLAGS = {
    "txt2img": (True, False, False),
    "img2img": (False, True, False),
    "inpainting": (False, False, True),
}


def _button_variant(selected):
    """Gradio button variant: highlighted for the active mode."""
    return "primary" if selected else "secondary"


def _mode_updates(input, clear_inpaint):
    """Build the component updates that switch the UI into mode ``input``.

    Shows only the image/strength inputs belonging to that mode and highlights
    its mode button. When ``clear_inpaint`` is true the inpaint canvas value
    is also reset to None. An unknown mode raises KeyError.
    """
    is_txt2img, is_img2img, is_inpaint = _MODE_FLAGS[input]
    inpaint_extra = {"value": None} if clear_inpaint else {}

    return {
        current_mode: input,
        input_image: gr.update(visible = is_img2img),
        img2img_strength: gr.update(visible = is_img2img),
        inpaint_image: gr.update(**inpaint_extra, visible = is_inpaint),
        inpaint_strength: gr.update(visible = is_inpaint),
        txt2img_show: gr.update(variant = _button_variant(is_txt2img)),
        img2img_show: gr.update(variant = _button_variant(is_img2img)),
        inpaint_show: gr.update(variant = _button_variant(is_inpaint)),
        }


def show_state(input: str):
    """Switch the UI to mode ``input`` ("txt2img", "img2img" or "inpainting")."""
    return _mode_updates(input, clear_inpaint=False)


def show_state_and_clear_inpaint(input: str):
    """Like :func:`show_state`, but also empties the inpaint canvas."""
    return _mode_updates(input, clear_inpaint=True)

#run
# Patch Gradio's page template so src/gradio.js is injected, then build the UI.
load_javascript = LoadJavaScript()
with gr.Blocks(css=css, title="Simple Stable") as main:
    current_loaded_model_name = gr.State("Stable Diffusion 1.5")
    current_mode = gr.State("txt2img")
    last_used_seed = gr.State(-1)

    with gr.Row(elem_id="model_row"):
        model_name = gr.Dropdown(choices = list(SimpleStable.model_dict.keys()), value = "Stable Diffusion 1.5", show_label=False, elem_id="model_choice")
        model_submit = gr.Button(value="Load Model", interactive=True, elem_id="model_submit")
    loading_status = gr.Markdown("", elem_id="model_status")

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                txt2img_show = gr.Button(value="txt2img", variant="primary")
                img2img_show = gr.Button(value="img2img")
                inpaint_show = gr.Button(value="inpainting")

            prompt = gr.Textbox(placeholder = "Describe a prompt here", label = "Prompt")
            negative = gr.Textbox(placeholder = "Negative prompt", label = "Negative")

            input_image = gr.Image(value=None, source="upload", interactive=True, type="pil", visible=False, elem_id="img2img_input")
            inpaint_image = gr.Image(value=None, source="upload", interactive=True, type="pil", visible=False, tool="sketch", elem_id="inpaint_input")
            img2img_strength = gr.Slider(minimum = 0.1, maximum = 1, value=0.75, step = 0.05, label="img2img strength", interactive = True, visible=False, elem_id="img2img_strength")
            inpaint_strength = gr.Slider(minimum = 0.1, maximum = 1, value=0.75, step = 0.05, label="inpaint strength", interactive = True, visible=False, elem_id="inpaint_strength")

            number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, reuse_seed_button, random_seed_button = generate_options()

        with gr.Column(scale=2):
            with gr.Box(elem_id="output_box") as output_box:
                with gr.Row(elem_id="generate_row"):
                    button = gr.Button(value="Generate", variant="primary", elem_id="generate_button")
                    #interrupt_button = gr.Button(value="Interrupt", variant="secondary", elem_id="generate_button")
                image_output = gr.Gallery(interactive = False, elem_id="output_gallery")
                with gr.Row(elem_id="edit_row"):
                    to_img2img_button = gr.Button(value="img2img Selected Image", variant="secondary", elem_id="to_img2img_button", visible=False)
                    to_inpaint_button = gr.Button(value="inpaint Selected Image", variant="secondary", elem_id="to_inpaint_button", visible=False)

    # Invisible marker that generate() writes on each run. Its change event is
    # used below to react once the gallery has been filled.
    hidden_state = gr.Markdown("", visible=False)

    # Every component a mode switch may update (the keys show_state returns).
    mode_outputs = [current_mode, input_image, img2img_strength, inpaint_image, inpaint_strength, txt2img_show, img2img_show, inpaint_show]

    model_submit.click(load_model, inputs=[current_loaded_model_name, model_name], outputs=[loading_status, current_loaded_model_name])

    reuse_seed_button.click(lambda x: x, inputs=[last_used_seed], outputs=[seed])
    random_seed_button.click(lambda : -1, inputs=[], outputs=[seed])

    # Each mode button passes its own label, which doubles as the mode name.
    txt2img_show.click(show_state, inputs=[txt2img_show], outputs = mode_outputs)
    img2img_show.click(show_state, inputs=[img2img_show], outputs = mode_outputs)
    inpaint_show.click(show_state, inputs=[inpaint_show], outputs = mode_outputs)

    #interrupt_button.click(None, cancels=[image_output])
    # "Send to" buttons: switch mode, then copy the selected gallery image.
    # The image is picked client-side by findSelectedImageFromGallery
    # (presumably defined in src/gradio.js).
    # NOTE(review): the inpaint clear and the image copy both write inpaint_image
    # from separate events. If the clear lands last, the canvas ends up empty;
    # consider chaining them.
    to_img2img_button.click(show_state, inputs=[img2img_show], outputs = mode_outputs)
    to_inpaint_button.click(show_state_and_clear_inpaint, inputs=[inpaint_show], outputs = mode_outputs)
    to_img2img_button.click(return_selected_image_from_gallery, inputs=[image_output], outputs =[input_image], _js="findSelectedImageFromGallery")
    to_inpaint_button.click(return_selected_image_from_gallery, inputs=[image_output], outputs =[inpaint_image], _js="findSelectedImageFromGallery")
    button.click(generate, inputs=[current_mode, prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, input_image, img2img_strength, inpaint_image, inpaint_strength, current_loaded_model_name], outputs=[image_output, last_used_seed, hidden_state])
    # has_image must see the gallery *after* generate fills it. Registered on
    # button.click, it ran alongside generate and read the previous (initially
    # empty) gallery, so the edit buttons stayed hidden after the first run.
    # generate updates hidden_state on completion, so hook its change event.
    hidden_state.change(has_image, inputs=[image_output], outputs=[to_img2img_button, to_inpaint_button])
    #hidden_state.change(lambda: "", inputs=[], outputs=[hidden_state], _js="setupOutputGallery")

main.queue()
main.launch(debug=True)