In [None]:
import gradio as gr
from typing import Dict
from src import SimpleStable

def error_str(error, title="Error"):
    """Render *error* as a small Markdown section headed by *title*.

    Returns an empty string when *error* is falsy. Otherwise returns a
    ``#### title`` heading followed by the message on the next line; the
    message line keeps its 12-space indent.
    """
    if not error:
        return ""
    return f"#### {title}\n            {error}"

# Preset output sizes offered in the "Image Resolution" dropdown, each as
# [width, height]. The "Custom" entry carries no size of its own: when it is
# selected, gradio_main takes the width/height sliders instead.
res_dict = {
    "Custom (Select this and put width and height below)": "",
    "Square 512x512 (default, good for most models)": [512, 512],
    "Landscape 768x512": [768, 512],
    "Portrait 512x768": [512, 768],
    "Square 768x768 (good for 768 models)": [768, 768],
    "Landscape 1152x768 (does not work on free colab)": [1152, 768],
    "Portrait 768x1152 (does not work on free colab)": [768, 1152],
}

# Name of the model to load next; the "Model" dropdown updates it via
# change_model_choice.
model_opt = "Stable Diffusion 1.5"

def change_model_choice(model_name: str) -> None:
    """Record *model_name* as the model to build on the next load.

    This only stores the choice in the module-level ``model_opt``. It does
    not load anything; ``load_another_model`` reads ``model_opt`` when it runs.
    """
    global model_opt
    chosen = model_name
    model_opt = chosen

def load_another_model():
    """Build a pipeline for ``model_opt`` and make it the active one.

    Rebinds the module-level ``pipe`` that ``gradio_main`` generates with,
    and also returns the new pipeline. If ``SimpleStable.setup_pipe`` raises,
    ``pipe`` is left as it was.
    """
    global pipe
    fresh_pipe = SimpleStable.setup_pipe(model_opt)
    pipe = fresh_pipe
    return fresh_pipe


# Load the default model once, at start-up.
pipe = load_another_model()

def gradio_main(prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, image=None, init_strength=None):
    """Run one generation request coming from the txt2img tab.

    Translates the widget values into the option dict passed to
    ``SimpleStable.gradio_main`` and runs it on the currently loaded ``pipe``.

    Args:
        prompt: Positive prompt text.
        negative: Negative prompt text; ``None`` is sent as ``""``.
        number_of_images: Number of images to request.
        resolution: A key of ``res_dict``. The "Custom" entry means
            *custom_width* / *custom_height* are used instead.
        custom_width: Width used when *resolution* is the "Custom" entry.
        custom_height: Height used when *resolution* is the "Custom" entry.
        steps: Sampling step count.
        sampler: Sampler name.
        seed: Seed entered in the UI. ``None`` (cleared field) falls back to
            42, the value that was previously always used.
        scale: Guidance scale.
        additional_options: Checked entries of the "Tiling" / "SD Upscale"
            checkbox group; mapped onto the ``tiling`` / ``upscale`` options.
        upscale_strength: Not forwarded yet. NOTE(review): confirm which opt
            key SimpleStable reads for the upscale strength, if any.
        image: Optional init image, passed through as ``init_img``.
        init_strength: Currently unused.

    Returns:
        ``(result, None)``: the value returned by ``SimpleStable.gradio_main``
        and ``None`` for the log textbox.

    Exceptions from ``SimpleStable.gradio_main`` propagate to Gradio.
    """
    if is_custom_resolution(resolution):
        width, height = custom_width, custom_height
    else:
        width, height = res_dict[resolution]
    # Widget values may come back as floats; make the arithmetic below produce
    # plain ints.
    width, height = int(width), int(height)
    options = additional_options or []

    opt = {
        "model_name": model_opt,
        "prompt": prompt,
        "negative": negative if negative is not None else "",
        "init_img": image,
        "number_of_images": number_of_images,
        # Round down to a multiple of 64.
        "H": height - height % 64,
        "W": width - width % 64,
        "steps": steps,
        "sampler": sampler,
        "scale": scale,
        "eta": 0.0,
        # Previously hard-coded to False, which made the checkboxes do nothing.
        "tiling": "Tiling" in options,
        "upscale": "SD Upscale" in options,
        # Previously hard-coded to 42, which ignored the Seed field.
        "seed": seed if seed is not None else 42,
        "add_keyword": False,
    }

    result = SimpleStable.gradio_main(opt, pipe)
    return result, None



def is_custom_resolution(resolution: str):
    """Tell whether *resolution* is the dropdown entry that defers to the
    custom width/height sliders."""
    custom_choice = "Custom (Select this and put width and height below)"
    matches = resolution == custom_choice
    return matches


# NOTE(review): not referenced anywhere else in the visible code; candidate
# for removal.
is_custom = False

# Build the UI: a model picker on top, then txt2img / img2img tabs.
with gr.Blocks(css="/src/gradio.css") as main:
    # Picking a model only records it (change_model_choice); the pipeline is
    # rebuilt when "Load Model" is clicked (load_another_model).
    with gr.Row():
        # Default to model_opt so the dropdown matches the pipeline loaded at start-up.
        model_name = gr.Dropdown(choices=list(SimpleStable.model_dict.keys()), label="Model", value=model_opt)
        model_submit = gr.Button(value="Load Model")

    gr.Markdown("Simple Stable")
    with gr.Tab("txt2img"):
        with gr.Row():
            with gr.Column(scale=3):
                prompt = gr.Textbox(placeholder="Describe a prompt here", label="Prompt")
                negative = gr.Textbox(placeholder="Negative prompt", label="Negative")
                with gr.Row():
                    with gr.Column(scale=1):
                        number_of_images = gr.Number(value=1, precision=0, label="Number of Images")
                    with gr.Column(scale=2):
                        resolution = gr.Dropdown(choices=list(res_dict.keys()), label="Image Resolution", value="Square 512x512 (default, good for most models)")

                with gr.Accordion("Advanced Settings"):
                    with gr.Row():
                        # Only used when the "Custom" resolution entry is selected.
                        custom_width = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Width (if Custom is selected)", interactive=True)
                        custom_height = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Height (if Custom is selected)", interactive=True)
                    with gr.Row():
                        steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Step Count", interactive=True)
                        sampler = gr.Dropdown(choices=SimpleStable.sampler_list, label="Sampler", value="Euler a")
                    with gr.Row():
                        seed = gr.Number(value=1, precision=0, label="Seed")
                        scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.5, label="Guidance Scale", interactive=True)
                    with gr.Row():
                        additional_options = gr.CheckboxGroup(["Tiling", "SD Upscale"], interactive=True)
                        # Paired with the "SD Upscale" checkbox. It was previously
                        # mislabeled "Guidance Scale", a duplicate of the slider above.
                        upscale_strength = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.05, label="Upscale Strength", interactive=True)

                submit_button = gr.Button(value="Generate!", variant="primary")
            with gr.Column(scale=2):
                logger = gr.Textbox(label="Log", interactive=False)
                image_output = gr.Image(interactive=False)
    with gr.Tab("img2img"):
        # NOTE(review): image_input is not wired to any event below, so
        # gradio_main's `image` argument is always None from this UI.
        image_input = gr.Image()

    model_name.change(change_model_choice, model_name, [], queue=False)
    model_submit.click(load_another_model, inputs=[], outputs=[], queue=True)
    # gradio_main returns (result, log_value), matching [image_output, logger].
    submit_button.click(gradio_main, inputs=[prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength], outputs=[image_output, logger])

# Enable request queuing (the "Load Model" click above is registered with
# queue=True), then start the app with debug=True.
main.queue()
main.launch(debug=True)