<a href="https://colab.research.google.com/github/karaage0703/stable-diffusion-colab-tools/blob/main/001_stable_diffusion_gui_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion GUI Basic
A basic, easy-to-use GUI tool for Stable Diffusion.

The source code is in this GitHub repository:
[stable-diffusion-colab-tools](https://github.com/karaage0703/stable-diffusion-colab-tools)

In [None]:
#@title **Hugging Face Login**
#@markdown　You need access token of Hugging Face.
# Install dependencies. Only diffusers is pinned (0.3.0); the next cell relies
# on that version's API (e.g. reading images from the pipeline output's
# ['sample'] key).
# NOTE(review): transformers, scipy, ftfy and gradio are unpinned, while the UI
# code uses gradio's `.style()` and `gr.Box` APIs, which later gradio releases
# may no longer provide -- consider pinning known-good versions. TODO confirm.
!pip -qq install diffusers==0.3.0

!pip -qq install transformers scipy ftfy gradio
# ipywidgets is held to the 7.x series, presumably for compatibility with
# Colab's custom widget manager enabled below -- confirm.
!pip -qq install "ipywidgets>=7,<8"

# Allow third-party Jupyter widgets to render in Colab (presumably needed for
# the notebook_login widget below).
from google.colab import output
output.enable_custom_widget_manager()

# Ask for a Hugging Face access token; the model-loading cell passes
# use_auth_token=True so the token stored here is used for downloads.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
#@title **Launch App**
#@markdown　Execute and click URL ex: `Running on public URL: https://xxxx.gradio.app` 

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler


#@markdown　Select model 

device = "cuda"
model_id = "CompVis/stable-diffusion-v1-4" #@param ["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion", "naclbit/trinart_stable_diffusion_v2"] {allow-input: true}

# Per-model loading options. The form above accepts free-text input, so any
# model id not listed here must still produce a `pipe`; otherwise infer()
# would fail later with a NameError.
if model_id == "CompVis/stable-diffusion-v1-4":
    pipe_kwargs = {"revision": "fp16"}
elif model_id == "hakurei/waifu-diffusion":
    pipe_kwargs = {
        "revision": "fp16",
        # DDIM scheduler with a scaled-linear beta schedule, presumably as
        # recommended on the waifu-diffusion model card -- confirm.
        "scheduler": DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        ),
    }
elif model_id == "naclbit/trinart_stable_diffusion_v2":
    pipe_kwargs = {"revision": "diffusers-60k"}
else:
    # Custom model id typed into the form: load its default revision and let
    # torch_dtype cast the weights to half precision.
    pipe_kwargs = {}

# All models are loaded in half precision and moved to the GPU. For full
# precision (more GPU memory), drop torch_dtype and any fp16 revision.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_auth_token=True,
    **pipe_kwargs,
).to(device)


def infer(prompt, num_images, num_inference_steps, guidance_scale_value, width_images, height_images, seed_number):
    """Generate images for *prompt* with the module-level Stable Diffusion ``pipe``.

    This is the handler for the Gradio UI. Numeric widget values may arrive
    as floats, so they are converted to ``int`` here.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_images: Number of images to generate. Values below 1 produce an
            empty list instead of raising.
        num_inference_steps: Number of denoising steps per image.
        guidance_scale_value: Guidance scale forwarded to the pipeline.
        width_images: Output width in pixels. The UI slider keeps it a
            multiple of 64.
        height_images: Output height in pixels.
        seed_number: Base seed. If negative, a fresh random seed is drawn for
            every image. Otherwise image ``i`` uses ``seed_number + i``, so a
            run is reproducible and the images still differ from each other.

    Returns:
        List of generated images, one per requested image, in order.
    """
    width = int(width_images)
    height = int(height_images)
    num_images = int(num_images)
    num_inference_steps = int(num_inference_steps)
    seed_number = int(seed_number)

    generator = torch.Generator(device=device)
    # Initial noise is sampled at 1/8 of the pixel resolution, with as many
    # channels as the UNet expects.
    latent_shape = (1, pipe.unet.in_channels, height // 8, width // 8)

    images = []
    for index in range(num_images):
        # Reseed before each sample so every image's starting noise depends
        # only on its own (printed) seed and can be reproduced later.
        seed = generator.seed() if seed_number < 0 else seed_number + index
        print('seed=' + str(seed))
        generator.manual_seed(seed)

        latents = torch.randn(latent_shape, generator=generator, device=device)

        with torch.autocast(device):
            result = pipe(
                [prompt],
                width=width,
                height=height,
                guidance_scale=guidance_scale_value,
                num_inference_steps=num_inference_steps,
                latents=latents,
            )
        images.append(result['sample'][0])

    return images


from IPython.display import clear_output

block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")

with block as demo:
    gr.Markdown("<h1><center>Stable Diffusion Tool</center></h1>")
    gr.Markdown(
        'Stable Diffusion useful web tool'
    )
    with gr.Group():
        with gr.Box():
            gr.Markdown(
                'Enter prompt and Run!!'
            )
            # Prompt box and Run button share one row; the border/rounded
            # settings make them look like a single joined control.
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label='Enter prompt', show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        num_images = gr.Number(label='Number of images', value=1)

        seed_number = gr.Number(label='Seed(-1 is random)', value=-1)

        num_inference_steps = gr.Slider(
            label='Number of inference steps', minimum=1, maximum=200, value=50
        )

        guidance_scale_value = gr.Slider(
            label='Guidance scale', minimum=1, maximum=20, value=7.5, step=0.1
        )

        # A step of 64 keeps both dimensions divisible by 8, so the
        # `// 8` latent sizing in infer() is exact.
        width_images = gr.Slider(
            label='Width of images', minimum=64, maximum=640, value=512, step=64
        )

        height_images = gr.Slider(
            label='Height of images', minimum=64, maximum=640, value=512, step=64
        )

        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[2], height="auto"
        )

        # Order must match infer()'s positional parameters.
        infer_inputs = [
            text,
            num_images,
            num_inference_steps,
            guidance_scale_value,
            width_images,
            height_images,
            seed_number,
        ]
        btn.click(infer, inputs=infer_inputs, outputs=gallery)
        # Pressing Enter in the single-line prompt box also starts generation.
        text.submit(infer, inputs=infer_inputs, outputs=gallery)

    gr.Markdown(
        """___
   <p style='text-align: center'>
   Created by CompVis and Stability AI
   <br/>
   </p>"""
    )

# Clear this cell's earlier output (e.g. model-loading logs) before launching.
clear_output()
demo.launch(debug=True)

## References
Special thanks to the following resources:
- https://note.com/npaka/n/ndd549d2ce556
- http://cedro3.com/ai/image2image/
- https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb
- https://huggingface.co/hakurei/waifu-diffusion
- https://huggingface.co/naclbit/trinart_stable_diffusion_v2