# Stable Diffusion using Gradio

## Overview

Stable Diffusionのtext-to-image(text2img)を利用した画像生成のサンプルコードです。実行環境としては、Google ColabのGPUを想定しています。
また、Gradioを利用してUIを提供しています。
実行にはHugging Faceのアカウントを作成し、access tokenを発行する必要があります。発行したaccess tokenは、`notebook_login` を実行すると表示される入力欄に入力してください。

## Install packages

In [None]:
def _install_packages() -> None:
    # Install stable diffusion
    !pip install --quiet --no-cache \
        diffusers==0.3.0 \
        ftfy \
        scipy \
        transformers
    # Install fugginface hub
    !pip install --quiet --no-cache huggingface_hub
    # Install ui
    !pip install --quiet --no-cache gradio


_install_packages()

## Import packages

In [None]:
from __future__ import annotations

from getpass import getpass
from pathlib import Path
import functools

from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
from huggingface_hub import notebook_login
from IPython.display import display
from IPython.display import Image as displayImage
from PIL.Image import Image

## Setup

### Device

In [None]:
# Torch device used for the pipeline, the latent-noise generator and autocast.
# "cuda" targets the GPU runtime this notebook is written for; the
# alternative is "cpu".
DEVICE: str = "cuda"

### Log in to Hugging Face

In [None]:
# Prompt for a Hugging Face access token inside the notebook. The model
# download below passes use_auth_token=True, which relies on this login.
notebook_login()

## Prepare pipeline

In [None]:
def _create_pipeline(device: str) -> StableDiffusionPipeline:
    """Load the Stable Diffusion v1.4 text-to-image pipeline onto *device*.

    Weights come from the ``fp16`` revision of
    ``CompVis/stable-diffusion-v1-4`` on the Hugging Face Hub and are loaded
    as ``torch.float16``. The download authenticates with the token saved
    by ``notebook_login`` (``use_auth_token=True``).

    Args:
        device: Torch device name to move the pipeline to, e.g. "cuda".

    Returns:
        The loaded pipeline, placed on ``device``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        # ``torch_dtype`` (not ``torch_device``) is the from_pretrained
        # keyword that makes the weights load in half precision.
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    pipe.to(device)

    return pipe


PIPE = _create_pipeline(device=DEVICE)

## UI

### Functions

In [None]:
def _infer(
    pipe: StableDiffusionPipeline,
    device: str,  # device info: cuda or cpu.
    prompt: str,
    num_images: float,  # number of generated images.
    height: float,  # image height.
    width: float,  # image width.
    guidance_scale: float,
    num_inference_steps: float,
    seed: float,  # random seed for latent space.
) -> list[Image]:
    """Generate ``num_images`` images for ``prompt``, one pipeline call each.

    Numeric arguments arrive as floats from the Gradio widgets and are
    truncated to ``int`` where integers are needed (``guidance_scale`` is
    forwarded unchanged). All initial latents are drawn up front from a
    single generator seeded with ``seed``, so identical settings reproduce
    identical latents. Each image is then generated from its own latent
    slice.

    Args:
        pipe: Loaded Stable Diffusion pipeline.
        device: Torch device name used for the generator, the latents and
            ``torch.autocast``.
        prompt: Text prompt.
        num_images: Number of images to generate; must be at least 1.
        height: Image height in pixels; the latent height is ``height // 8``.
        width: Image width in pixels; the latent width is ``width // 8``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the latent-noise generator.

    Returns:
        The first image of each pipeline call, in latent order.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        # Fail with a readable message instead of an opaque torch error
        # (negative sizes) or a silently empty gallery (zero).
        raise ValueError(f"num_images must be at least 1, got {num_images}")
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # Initialize the latent space: one noise tensor per requested image.
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (num_images, pipe.unet.in_channels, height // 8, width // 8),
        generator=generator,
        device=device,
    )

    # Generate images one at a time, each from its own latent (batch of 1).
    images: list[Image] = []
    for latent in latents:
        with torch.autocast(device):
            output = pipe(
                [prompt],
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                latents=latent.unsqueeze(0),
            )
        # `.images` is the documented output field; `["sample"]` is a
        # deprecated alias.
        images.append(output.images[0])

    return images

In [None]:
def gradio_block(pipe: StableDiffusionPipeline, device: str) -> None:
    """Build the Gradio text-to-image UI and launch it.

    The UI contains a prompt box, a Run button, a collapsed "See details"
    accordion with generation settings, and an image gallery. Clicking Run
    calls ``_infer`` with ``pipe`` and ``device`` already bound and shows
    the returned images in the gallery.

    Args:
        pipe: Loaded Stable Diffusion pipeline passed through to ``_infer``.
        device: Torch device name passed through to ``_infer``.
    """
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Enter prompt", lines=7)
        run_button = gr.Button("Run")
        with gr.Accordion("See details", open=False):
            number_of_images = gr.Number(label="Enter number of images", value=1)
            image_height = gr.Slider(
                label="Enter height of images",
                minimum=64,
                maximum=1024,
                value=512,
                step=8,
            )
            image_width = gr.Slider(
                label="Enter width of images",
                minimum=64,
                maximum=1024,
                value=512,
                step=8,
            )
            number_of_inference_step = gr.Slider(
                label="Number of inference step",
                minimum=1,
                maximum=200,
                value=50,
                step=1,
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                value=7.5,
                step=0.1,
            )
            seed = gr.Number(label="random seed", value=42)
        gallery = gr.Gallery(label="Generated images", show_label=False)
        # Gradio is installed unpinned and `.style()` is not available in
        # later releases, so apply the 2-column grid only where it exists.
        if hasattr(gallery, "style"):
            gallery = gallery.style(grid=2)
        run_button.click(
            functools.partial(_infer, pipe, device),
            # Order must match _infer's parameters after pipe and device.
            inputs=[
                prompt,
                number_of_images,
                image_height,
                image_width,
                guidance_scale,
                number_of_inference_step,
                seed,
            ],
            outputs=gallery,
        )

    # Launch after the `with` block has exited so the layout is finalized.
    # On Colab, share=True is required (other settings are not accepted there).
    demo.launch(debug=False, share=True)



### View

In [None]:
# Build and launch the UI using the pipeline and device configured above.
gradio_block(pipe=PIPE, device=DEVICE)