# Stable Diffusion Demo
## _with preview images_

Stable Diffusion is a state-of-the-art text-to-image model that generates images from text.

For faster generation and forthcoming API access you can try [DreamStudio Beta](http://beta.dreamstudio.ai/).

In [1]:
import torch
from datasets import load_dataset
from PIL import Image  
import re

from preview_decoder import ApproximateDecoder, jpeg_bytes
from generator_pipeline import StableDiffusionGeneratorPipeline, PipelineIntermediateState


# Checkpoint to load and the device inference runs on.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# When running locally, either run `huggingface-cli login` first or paste your
# User Access Token (https://huggingface.co/settings/tokens) into the
# use_auth_token argument below.
pipe = StableDiffusionGeneratorPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)
pipe.enable_attention_slicing()
torch.backends.cudnn.benchmark = True

In [2]:
# The original demo read its word list from a Hugging Face dataset. When running
# locally you won't have access to it, so that part is left commented out and a
# public list is downloaded instead.
#word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
#word_list = word_list_dataset["train"]['text']
WORD_LIST = 'https://raw.githubusercontent.com/coffee-and-fun/google-profanity-words/main/data/list.txt'
import requests

# Bound the wait and fail loudly on an HTTP error: otherwise a hung connection
# blocks the notebook forever, and an error page would silently become the
# "word list" used by the prompt filter.
_word_list_response = requests.get(WORD_LIST, timeout=30)
_word_list_response.raise_for_status()

# One entry per non-blank line. splitlines() + strip() also drops any "\r" or
# padding that would otherwise end up inside the filter patterns.
word_list = [
    word
    for word in (line.strip() for line in _word_list_response.text.splitlines())
    if word
]

def infer(prompt, samples, steps, scale, seed):
    """Check *prompt* against ``word_list`` and stream results from the pipeline.

    This is a generator, so neither the check nor the generation runs until
    the caller starts iterating.

    Args:
        prompt: Text prompt; it is repeated ``samples`` times to form the batch.
        samples: Number of images to generate.
        steps: Passed to ``pipe.generate`` as ``num_inference_steps``.
        scale: Passed to ``pipe.generate`` as ``guidance_scale``.
        seed: Seed for the ``torch.Generator`` created on ``device``.

    Yields:
        Every item produced by ``pipe.generate``, unchanged.

    Raises:
        ValueError: If the prompt contains an entry of ``word_list`` as a
            whole word.
    """
    for blocked_word in word_list:
        # Entries are literal words, not patterns: escape them so characters
        # such as "." or "$" cannot change (or break) the regular expression.
        if re.search(rf"\b{re.escape(blocked_word)}\b", prompt):
            # The visible notebook never imports gradio, so the former
            # `gr.Error` would itself fail with NameError; use a builtin.
            raise ValueError(f"""Unsafe content found. Please try again with different prompts. 
            filter: {blocked_word}
            prompt: {prompt}
""")

    generator = torch.Generator(device=device).manual_seed(seed)

    with torch.autocast(pipe.device.type):
        yield from pipe.generate(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )


def replace_unsafe_images(output):
    """Return one PIL image per generated image, masking flagged ones.

    Args:
        output: Pipeline result exposing ``images`` and a parallel
            ``nsfw_content_detected`` sequence of flags.

    Returns:
        A list in the same order as ``output.images``. Flagged entries are
        replaced by the ``unsafe.png`` placeholder; the others are converted
        with ``pipe.numpy_to_pil``.
    """
    placeholder = None  # opened lazily, at most once per call
    images = []
    for image, is_unsafe in zip(output.images, output.nsfw_content_detected):
        if is_unsafe:
            if placeholder is None:
                # Only touch the file when it is actually needed, so a missing
                # unsafe.png cannot break generations where nothing is flagged.
                placeholder = Image.open(r"unsafe.png")
            images.append(placeholder)
        else:
            images.append(pipe.numpy_to_pil(image)[0])
    return images


# Built once from the pipeline; on_submit calls it on each intermediate latent
# and passes the result to jpeg_bytes to fill the preview widgets.
approximate_decoder = ApproximateDecoder.for_pipeline(pipe)

In [3]:
# Example inputs, one row each. The five columns appear to follow infer()'s
# parameter order (prompt, samples, steps, scale, seed) -- TODO confirm; nothing
# in the visible notebook consumes this list.
examples = [
    ['A high tech solarpunk utopia in the Amazon rainforest', 2, 20, 7.5, 1024],
    ['A pikachu fine dining with a view to the Eiffel Tower', 2, 20, 7, 1024],
    ['A mecha robot in a favela in expressionist style', 2, 20, 7, 1024],
    ['an insect robot preparing a delicious meal', 2, 20, 7, 1024],
    ["A small cabin on top of a snowy mountain in the style of Disney, artstation", 2, 20, 7, 1024],
]

In [4]:
import urllib.parse
# Inline SVG markup for the icon, kept as one string so it can be embedded
# directly in a data URI.
icon_src = ('<svg viewBox="0 0 115 115" fill="none" xmlns="http://www.w3.org/2000/svg">'
    '<path fill="#fff" d="M0 0h23v23H0zm0 69h23v23H0z"/><path fill="#AEAEAE" d="M23 0h23v23H23zm0 69h23v23H23z"/>'
    '<path fill="#fff" d="M46 0h23v23H46zm0 69h23v23H46z"/><path fill="#000" d="M69 0h23v23H69zm0 69h23v23H69z"/>'
    '<path fill="#D9D9D9" d="M92 0h23v23H92z"/><path fill="#AEAEAE" d="M92 69h23v23H92z"/><path fill="#fff" d="M115 46h23v23h-23zm0 69h23v23h-23z"/>'
    '<path fill="#D9D9D9" d="M115 69h23v23h-23z"/><path fill="#AEAEAE" d="M92 46h23v23H92zm0 69h23v23H92z"/>'
    '<path fill="#fff" d="M92 69h23v23H92zM69 46h23v23H69zm0 69h23v23H69z"/><path fill="#D9D9D9" d="M69 69h23v23H69z"/>'
    '<path fill="#000" d="M46 46h23v23H46zm0 69h23v23H46zm0-46h23v23H46z"/><path fill="#D9D9D9" d="M23 46h23v23H23z"/>'
    '<path fill="#AEAEAE" d="M23 115h23v23H23z"/><path fill="#000" d="M23 69h23v23H23z"/></svg>')
# Percent-encode the markup into a data URI. The media type for SVG is
# "image/svg+xml"; the previous "image+svg/xml" is not a valid type, so the
# URI would not be treated as an SVG image.
icon = "data:image/svg+xml," + urllib.parse.quote(icon_src, safe=" \"=")

In [6]:
import secrets
import ipywidgets

# Preview widgets are laid out at PREVIEW_ZOOM times their native size.
DEFAULT_SIZE = 512
PREVIEW_ZOOM = 3

# One eighth of DEFAULT_SIZE (64); used as the native width and height of the
# preview image widgets.
_latent_size = DEFAULT_SIZE // 8

def _preview_widget():
    """Create an empty JPEG image widget for one in-progress preview.

    The image is declared ``_latent_size`` pixels square and laid out at
    ``PREVIEW_ZOOM`` times that size, with its content centred.
    """
    side = f"{PREVIEW_ZOOM * _latent_size}px"
    return ipywidgets.Image(
        format='jpeg',
        width=_latent_size,
        height=_latent_size,
        layout=ipywidgets.Layout(
            width=side,
            height=side,
            object_position='center center',
        ),
    )
    
def on_submit(event=None):
    """Run a generation with the current widget values and update the UI.

    Registered as the handler for both the button click and the text field
    submit; ``event`` is accepted but not used.

    The gallery is first filled with one blank preview widget per requested
    image and the progress bar is set to its maximum. Each intermediate state
    then moves the progress bar to ``result.timestep`` and refreshes the
    previews; the final result resets the bar to its minimum and replaces the
    gallery with 512x512 JPEG image widgets.
    """
    gallery_container.children = [_preview_widget() for _ in range(samples.value)]
    progress.value = progress.max

    outputs = infer(text.value, samples.value, steps.value, scale.value, seed.value)
    for result in outputs:
        if not isinstance(result, PipelineIntermediateState):
            # Final output: swap the previews for the finished images.
            progress.value = progress.min
            gallery_container.children = [
                ipywidgets.Image(
                    value=jpeg_bytes(image),
                    format='jpeg',
                    width=512, height=512
                )
                for image in replace_unsafe_images(result)
            ]
            continue

        # Intermediate step: update progress and redraw every preview.
        progress.value = result.timestep
        for preview, latents in zip(gallery_container.children, result.latents):
            preview.value = jpeg_bytes(approximate_decoder(latents))
            

# Prompt field and the button that starts a generation; both call on_submit.
text = ipywidgets.Text(placeholder="Enter your prompt")
text.on_submit(on_submit)

btn = ipywidgets.Button(description="Generate image")
btn.on_click(on_submit)

# Holds the preview / result image widgets in a wrapping, centred flex row.
gallery_container = ipywidgets.Box(
    layout=ipywidgets.Layout(
        align_content="space-around",
        align_items="center",
        flex_flow="row wrap",
        justify_content="space-around",
    )
)

# Largest seed offered by the UI: 2**31 - 1. The parentheses are required
# because `-` binds tighter than `<<`: `1 << 31 - 1` evaluates to 1 << 30.
MAX_SEED = (1 << 31) - 1
# Generation settings read by on_submit; the seed starts at a random value.
samples = ipywidgets.IntSlider(value=2, min=1, max=4, description="Images")
steps = ipywidgets.IntSlider(value=16, min=1, max=50, description="Steps")
scale = ipywidgets.FloatSlider(
    value=7.5, min=0, max=50, step=0.1, description="Guidance Scale"
)
seed = ipywidgets.BoundedIntText(
    value=secrets.randbelow(MAX_SEED), min=0, max=MAX_SEED, description="Seed"
)

# Full-width progress bar spanning 0 .. the scheduler's num_train_timesteps.
progress = ipywidgets.IntProgress(
    value=0,
    min=0,
    max=pipe.scheduler.num_train_timesteps,
)
progress.layout.width = "100%"

# Collapsible panel that groups the generation settings.
advanced_options = ipywidgets.Accordion(
    children=[ipywidgets.VBox(children=[samples, steps, scale, seed])],
)
advanced_options.set_title(0, "Advanced Options")

# Whole UI, top to bottom: prompt row, gallery, progress bar, settings panel.
form = ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=[text, btn]),
        gallery_container,
        progress,
        advanced_options,
    ]
)
display(form)

VBox(children=(HBox(children=(Text(value='', placeholder='Enter your prompt'), Button(description='Generate im…

Model by [CompVis][CompVis] and [Stability AI][Stability AI].

### LICENSE

The model is licensed with a [CreativeML Open RAIL-M][license] license.
The authors claim no rights on the outputs you generate,
you are free to use them and are accountable for their use which must not go against the provisions set in this license.
The license forbids you from sharing any content that violates any laws, produces any harm to a person,
disseminates any personal information that would be meant for harm,
spreads misinformation, or targets vulnerable groups.
For the full list of restrictions please [read the license][license].

### Biases and content acknowledgment

Despite how impressive being able to turn text into images is,
be aware that this model may output content that reinforces or exacerbates societal biases,
as well as realistic faces, pornography and violence.
The model was trained on the [LAION-5B dataset][laion-5b],
which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content)
and is meant for research purposes.
You can read more in the [model card][card].

[CompVis]: https://huggingface.co/CompVis
[Stability AI]: https://huggingface.co/stabilityai
[license]: https://huggingface.co/spaces/CompVis/stable-diffusion-license
[laion-5b]: https://laion.ai/blog/laion-5b/
[card]: https://huggingface.co/CompVis/stable-diffusion-v1-4