In [1]:
import gradio as gr
from typing import Dict
from src import SimpleStable
import gc
import torch

# Shared pipeline, built once when the cell runs. `its_loading` rebinds this
# global when another model is requested and `generate` passes it to
# `SimpleStable.gradio_main`. The model name matches the initial value of the
# model dropdown and of the "currently loaded model" state in the UI below.
pipe = SimpleStable.setup_pipe("Stable Diffusion 1.5")

def its_loading(loaded_model_name: str, chosen_model_name: str):
    """Make the shared pipeline match ``chosen_model_name``.

    Args:
        loaded_model_name: Name recorded for the pipeline currently held in
            the module-level ``pipe``.
        chosen_model_name: Name the user selected.

    Returns:
        A ``(status_message, model_name)`` pair, where ``model_name`` is the
        name the caller should now record as loaded.
    """
    global pipe

    already_loaded = pipe is not None and loaded_model_name == chosen_model_name
    if already_loaded:
        return "Model already loaded", loaded_model_name

    # Drop our reference to the old pipeline *before* building the new one,
    # then run the garbage collector and ask PyTorch to release its cached
    # CUDA memory, instead of holding both pipelines at the same time.
    pipe = None
    gc.collect()
    torch.cuda.empty_cache()
    pipe = SimpleStable.setup_pipe(chosen_model_name)

    return f"{chosen_model_name} loaded", chosen_model_name


def is_custom_resolution(resolution: str):
    """Return True when ``resolution`` is the custom-size dropdown label."""
    custom_label = "Custom (Select this and put width and height below)"
    return resolution == custom_label

def generate(prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength, model_name):
    """Collect the UI values into an options dict and run the shared pipeline.

    Args:
        prompt: Text prompt.
        negative: Negative prompt; ``None`` is sent as an empty string.
        number_of_images: How many images to request.
        resolution: A key of ``SimpleStable.res_dict``, or the custom-size
            label (see ``is_custom_resolution``).
        custom_width: Width used only when the custom-size label is selected.
        custom_height: Height used only when the custom-size label is selected.
        steps: Step count forwarded as ``"steps"``.
        sampler: Sampler name forwarded as ``"sampler"``.
        seed: Seed forwarded as ``"seed"``.
        scale: Guidance scale forwarded as ``"scale"``.
        additional_options: Labels ticked in the "Tiling" / "SD Upscale"
            checkbox group (may be empty or ``None``).
        upscale_strength: Currently not forwarded (see the review note below).
        model_name: Name of the model forwarded as ``"model_name"``.

    Returns:
        Whatever ``SimpleStable.gradio_main`` returns for these options.
    """
    if is_custom_resolution(resolution):
        width, height = custom_width, custom_height
    else:
        width, height = SimpleStable.res_dict[resolution]

    # The checkbox group yields the list of ticked labels; treat a missing
    # value as "nothing ticked" so the membership test below cannot fail.
    selected_options = additional_options or []

    # NOTE(review): "SD Upscale" and `upscale_strength` are still not passed
    # on. The option key `gradio_main` expects for the strength is not
    # visible from here -- confirm it before switching "upscale" on.
    return SimpleStable.gradio_main({
        "model_name": model_name,
        "prompt": prompt,
        "negative": negative if negative is not None else "",
        "init_img": None,
        "number_of_images": number_of_images,
        # Round both dimensions down to a multiple of 64.
        "H": height - height % 64,
        "W": width - width % 64,
        "steps": steps,
        "sampler": sampler,
        "scale": scale,
        "eta": 0.0,
        # Previously hard-coded to False, which made the checkbox a no-op.
        "tiling": "Tiling" in selected_options,
        "upscale": False,
        "seed": seed,
        "add_keyword": False,
    }, pipe)


def generate_options():
    """Create the generation controls inside the active Gradio layout.

    Must be called within a ``gr.Blocks`` context, since the components are
    laid out with ``gr.Row`` / ``gr.Column`` / ``gr.Accordion``.

    Returns:
        The created components as a tuple, in this order: prompt, negative,
        number_of_images, resolution, custom_width, custom_height, steps,
        sampler, seed, scale, additional_options, upscale_strength.
    """
    prompt = gr.Textbox(placeholder="Describe a prompt here", label="Prompt")
    negative = gr.Textbox(placeholder="Negative prompt", label="Negative")
    with gr.Row():
        with gr.Column(scale=1):
            number_of_images = gr.Number(value=1, precision=0, label="Number of Images")
        with gr.Column(scale=2):
            resolution = gr.Dropdown(
                choices=list(SimpleStable.res_dict.keys()),
                label="Image Resolution",
                value="Square 512x512 (default, good for most models)",
            )

    with gr.Accordion("Advanced Settings"):
        with gr.Row():
            custom_width = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Width (if Custom is selected)", interactive=True)
            custom_height = gr.Slider(minimum=512, maximum=1152, value=512, step=64, label="Height (if Custom is selected)", interactive=True)
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Step Count", interactive=True)
            sampler = gr.Dropdown(choices=SimpleStable.sampler_list, label="Sampler", value="Euler a")
        with gr.Row():
            seed = gr.Number(value=-1, precision=0, label="Seed")
            scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.5, label="Guidance Scale", interactive=True)
        with gr.Row():
            additional_options = gr.CheckboxGroup(["Tiling", "SD Upscale"], interactive=True)
            # This slider used to be labelled "Guidance Scale" as well, which
            # made it indistinguishable from the real guidance slider above.
            upscale_strength = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.05, label="Upscale Strength", interactive=True)

    return (
        prompt,
        negative,
        number_of_images,
        resolution,
        custom_width,
        custom_height,
        steps,
        sampler,
        seed,
        scale,
        additional_options,
        upscale_strength,
    )

# Assemble the page: model picker on top, generation controls below, then the
# generate button and the output gallery.
# NOTE(review): "/src/gradio.css" is an absolute path -- confirm the
# stylesheet really lives there and not at the project-relative "src/gradio.css".
with gr.Blocks(css="/src/gradio.css") as main:
    # Name of the model held in the module-level `pipe`; updated by
    # `its_loading` through its second output.
    current_loaded_model_name = gr.State("Stable Diffusion 1.5")
    with gr.Row():
        model_name = gr.Dropdown(choices=list(SimpleStable.model_dict.keys()), value="Stable Diffusion 1.5", show_label=False)
        model_submit = gr.Button(value="Load Model", interactive=True)
    loading_status = gr.Markdown("")

    prompt, negative, number_of_images, resolution, custom_width, custom_height, steps, sampler, seed, scale, additional_options, upscale_strength = generate_options()

    # A Button's caption is its `value` (as for "Load Model" above); passing
    # it as `label` did not set the text shown on the button.
    button = gr.Button(value="Generate")
    image_output = gr.Gallery(interactive=False)

    model_submit.click(
        its_loading,
        inputs=[current_loaded_model_name, model_name],
        outputs=[loading_status, current_loaded_model_name],
    )
    # Generation is given the *loaded* model name (the state), not the
    # dropdown selection, which may not have been loaded yet.
    button.click(
        generate,
        inputs=[
            prompt,
            negative,
            number_of_images,
            resolution,
            custom_width,
            custom_height,
            steps,
            sampler,
            seed,
            scale,
            additional_options,
            upscale_strength,
            current_loaded_model_name,
        ],
        outputs=[image_output],
    )

main.queue()
main.launch(debug=True)

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


Using the seed 4026901179
a cat


  0%|          | 0/20 [00:00<?, ?it/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

Using the seed 2596814745
nvinkpunk, a cat


  0%|          | 0/20 [00:00<?, ?it/s]