# GUI for Stable Diffusion

# Stable Diffusion
Written by Jasmine Sandhu (Acknowledgements: Jim Bednar, Maxime Liquet, Philipp Rudiger)<br>
Created: Jan, 2023<br>
Last updated: Jan, 2023

## Stable Diffusion, Diffusers library

[Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion#:~:text=Stable%20Diffusion%20is%20a%20deep,guided%20by%20a%20text%20prompt) is a deep learning, text-to-image model released in 2022. It is primarily used to generate detailed images conditioned on text descriptions. 

This example uses the [Diffusers library](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) with checkpoints from the runwayml and CompVis repositories on Hugging Face. See also [Diffusers on GitHub](https://github.com/huggingface/diffusers#stable-diffusion-is-fully-compatible-with-diffusers) and the blog post [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion).

### Performance: GPU

The example assumes it will run on a GPU. It can be modified to run on a CPU but image generation will take on the order of minutes as opposed to seconds.
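Running on a CPU mostly changes where the pipeline lives and which precision it uses. Below is a minimal sketch that assumes half precision (`float16`) is kept for the GPU and `float32` is used on the CPU, where PyTorch's half-precision support has historically been limited. The helper name `pick_device_and_dtype` is hypothetical. It returns plain strings so it runs without PyTorch. In the notebook, you would pass it `torch.cuda.is_available()` and convert the dtype name with `getattr(torch, dtype)`.

```python
def pick_device_and_dtype(cuda_available: bool, gpu_id: int = 0) -> tuple[str, str]:
    """Hypothetical helper: choose a torch device string and a dtype name.

    On a GPU, half precision roughly halves the memory needed for the weights.
    On a CPU, fall back to full precision (assumption: many float16 ops are
    not supported there).
    """
    if cuda_available:
        return f"cuda:{gpu_id}", "float16"
    return "cpu", "float32"


# Usage in the notebook (requires torch and diffusers):
#   device, dtype = pick_device_and_dtype(torch.cuda.is_available(), gpu_id=1)
#   pipe = StableDiffusionPipeline.from_pretrained(model, torch_dtype=getattr(torch, dtype))
#   pipe.to(device)
print(pick_device_and_dtype(True, gpu_id=1))
print(pick_device_and_dtype(False))
```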


### Limitations

The models were trained on images with a resolution of 512x512. The diffusers pipeline, and therefore the UI, can create images at other resolutions; however, image quality degrades as the resolution deviates from the one used to train the model.
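Stable Diffusion v1 denoises a latent image that is 8 times smaller than the output in each dimension. For this reason, the diffusers pipeline requires `height` and `width` to be divisible by 8. The UI below is stricter and only offers multiples of 64 from 448 to 1024 (`_size_range`).

The following sketch snaps an arbitrary request onto that grid and computes the resulting latent shape. It assumes the 4-channel latent space of the v1 models, and the helper names are hypothetical:

```python
# Sizes offered by the UI: 448, 512, ..., 1024 (same expression as `_size_range` below)
SIZE_RANGE = tuple(448 + i * 2**6 for i in range(10))


def snap_size(pixels: int, sizes: tuple = SIZE_RANGE) -> int:
    """Hypothetical helper: return the allowed size closest to `pixels`."""
    return min(sizes, key=lambda s: abs(s - pixels))


def latent_shape(height: int, width: int, channels: int = 4, factor: int = 8) -> tuple:
    """Latent tensor shape for one image (assumes SD v1: 4 channels, 8x downsampling)."""
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    return (channels, height // factor, width // factor)


print(snap_size(500), snap_size(2000))  # → 512 1024
print(latent_shape(512, 512))            # → (4, 64, 64), the training resolution
```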


### Seed

Stable Diffusion starts from random Gaussian noise (in a compressed latent space) and removes part of that noise at each inference step. The seed determines this starting noise, and therefore the generated output. By default, this application randomizes the seed, so you can explore different images for the same prompt. Fixing the seed recreates the same image for a given prompt, model, set of options and resolution. Changing the resolution changes the shape of the starting noise, so it also changes the image.
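The role of the seed can be illustrated without a GPU. In the notebook, the seed is passed to a `torch.Generator` via `manual_seed`, and the pipeline draws the initial latent noise from that generator.

The sketch below uses Python's `random.Random` as a stand-in for that generator, only so the example runs with the standard library. It shows two things:
- The same seed and resolution give identical starting noise.
- A different resolution changes the amount and layout of the noise.

```python
import random


def initial_noise(seed: int, height: int, width: int, channels: int = 4) -> list:
    """Draw Gaussian starting noise for a latent of shape (channels, height/8, width/8).

    `random.Random` stands in for the seeded `torch.Generator` used by the pipeline;
    the values differ from what PyTorch would draw, but the principle is the same.
    """
    rng = random.Random(seed)
    n = channels * (height // 8) * (width // 8)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


a = initial_noise(42, 512, 512)
b = initial_noise(42, 512, 512)
c = initial_noise(43, 512, 512)
d = initial_noise(42, 576, 512)

print(a == b)          # same seed, same resolution: identical noise
print(a == c)          # different seed: different noise
print(len(a), len(d))  # different resolution: different number of noise values
```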

In [None]:
import time
from contextlib import contextmanager

import torch
import random
from diffusers import StableDiffusionPipeline

from bokeh.models.formatters import PrintfTickFormatter
import panel as pn
import param
from panel.layout.base import ListLike
from panel.reactive import ReactiveHTML
from panel.viewable import Viewer, Viewable

pn.extension()

In [None]:
# create a context manager to measure execution time and print it to the console
@contextmanager
def exec_time(description="Task"):
    st = time.perf_counter()
    yield 
    print(f"{description}: {time.perf_counter() - st:.2f} sec")


The `init_model` function loads a pretrained model from the default Hugging Face cache location (`local_files_only=True`). If a model has not been downloaded yet, loading raises an `OSError`. The model-loading cell then retries with `local_files_only=False`, which downloads the model first. On subsequent restarts of the app, the models are loaded from the local cache. They can also be downloaded separately as follows:

```python
pipe, cache = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", return_cached_folder=True, local_files_only=False)
pipe, cache = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", return_cached_folder=True, local_files_only=False)
print(cache)  # to see the default cache location
```
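`from_pretrained` stores each repository in its own folder, named after the repository id. The sketch below checks whether a model appears to be cached before the app starts. It assumes:
- the default `huggingface_hub` cache layout (`~/.cache/huggingface/hub`, or `$HF_HOME/hub` when `HF_HOME` is set);
- the `models--<org>--<name>` folder naming.

Other environment variables can also relocate the cache, which this sketch ignores. The helper names are hypothetical. The `print(cache)` call above remains the reliable way to find the location on your machine.

```python
import os
from pathlib import Path


def hub_cache_dir() -> Path:
    """Assumed default cache root: $HF_HOME/hub if set, else ~/.cache/huggingface/hub."""
    if "HF_HOME" in os.environ:
        return Path(os.environ["HF_HOME"]) / "hub"
    return Path.home() / ".cache" / "huggingface" / "hub"


def repo_folder_name(repo_id: str) -> str:
    """Assumed folder naming for cached model repos, e.g. 'models--CompVis--stable-diffusion-v1-4'."""
    return "models--" + repo_id.replace("/", "--")


def looks_cached(repo_id: str) -> bool:
    """True if a cache folder for `repo_id` exists (does not verify that every file is complete)."""
    return (hub_cache_dir() / repo_folder_name(repo_id)).is_dir()


for repo in ["runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4"]:
    print(repo, "->", repo_folder_name(repo), "cached:", looks_cached(repo))
```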

In addition to caching the pretrained models on disk, we initialize the diffusers pipelines once and store them in `pn.state.cache`. This way, each new visitor to the page does not require creating, and later destroying, a new diffusers pipeline.

The first page load takes roughly 10 extra seconds and allocates the GPU memory needed to hold the pipelines. Subsequent visitors get the pipelines from Panel's cache. From then on, the only additional memory needed is what it takes to generate an image from the text prompt.
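The caching logic in the model-loading cell follows a get-or-create pattern: build an expensive object once, store it under a key, and hand the stored object to every later caller. Here is a minimal sketch of that pattern. A plain `dict` stands in for `pn.state.cache`, and a hypothetical `load_pipeline` factory takes the place of `init_model`:

```python
from typing import Any, Callable, Dict


def get_or_create(cache: Dict[str, Any], key: str, factory: Callable[[], Any]) -> Any:
    """Return cache[key], calling `factory` only the first time the key is requested."""
    if key not in cache:
        cache[key] = factory()
    return cache[key]


calls = []


def load_pipeline():
    # Hypothetical stand-in for init_model(): records each expensive load
    calls.append("load")
    return object()


cache = {}  # plays the role of pn.state.cache, which outlives individual sessions
first_visitor = get_or_create(cache, "pipelines", load_pipeline)
second_visitor = get_or_create(cache, "pipelines", load_pipeline)

print(first_visitor is second_visitor, len(calls))  # → True 1
```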
Below is example output of `nvidia-smi` on a machine with two Quadro RTX 8000 GPUs, after both models have loaded. The Python process serving the app uses 5925 MiB on GPU 1.

```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 8000     Off  | 00000000:15:00.0 Off |                  Off |
| 33%   33C    P8    24W / 260W |     48MiB / 49152MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 8000     Off  | 00000000:2D:00.0 Off |                  Off |
| 33%   40C    P8    29W / 260W |   5933MiB / 49152MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2024      G   /usr/lib/xorg/Xorg                 23MiB |
|    0   N/A  N/A      2545      G   /usr/bin/gnome-shell               20MiB |
|    1   N/A  N/A      2024      G   /usr/lib/xorg/Xorg                  4MiB |
|    1   N/A  N/A   2263594      C   .../diffusers/bin/python3.11     5925MiB |
+-----------------------------------------------------------------------------+
```
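A large part of the 5925 MiB on GPU 1 is likely the `float16` weights of the two pipelines. A rough way to reason about it: weight memory is the number of parameters times the bytes per parameter (2 for `float16`, 4 for `float32`).

The sketch below computes that estimate. It ignores the CUDA context and the activations needed during generation, so a measured figure will be higher. The parameter count in the example call is an illustrative placeholder, not a measured value. For real numbers, count the parameters of a loaded pipeline as shown in the comment.

```python
def weight_memory_mib(param_counts, bytes_per_param: int = 2) -> float:
    """Estimate memory for model weights in MiB (float16 by default).

    Only weights are counted: activations, the CUDA context and allocator
    overhead come on top of this.
    """
    return sum(param_counts) * bytes_per_param / 2**20


# With a loaded pipeline, the counts could be gathered like this (requires torch/diffusers):
#   counts = [sum(p.numel() for p in m.parameters())
#             for m in (pipe.unet, pipe.text_encoder, pipe.vae)]
#   print(f"{weight_memory_mib(counts):.0f} MiB per pipeline in float16")

# Placeholder count to show the arithmetic: 1 billion parameters in total
print(f"{weight_memory_mib([1_000_000_000]):.0f} MiB in float16")
print(f"{weight_memory_mib([1_000_000_000], bytes_per_param=4):.0f} MiB in float32")
```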

In [None]:
# initialize models and define function for image generation
# use only downloaded models
random_int_range = 1, int(1e6)
def init_model(model, gpu_id=1, torch_dtype=None, local_files_only=True):
    print(f"Init model: {model}")
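    # only pass torch_dtype when one is given; otherwise the default (float32) is used
    kwargs = {'torch_dtype': torch_dtype} if torch_dtype else {}
    pipe = StableDiffusionPipeline.from_pretrained(model, local_files_only=local_files_only, **kwargs)

    # gpu_id selects the GPU that holds the pipeline. Use nvidia-smi to pick one that is not
    # busy with other work (in the output above, GPU 0 also runs the display server);
    # this makes the app a little more responsive.
    # On machines with fewer GPUs, fall back to the last available device.
    if torch.cuda.is_available():
        gpu_id = min(gpu_id, torch.cuda.device_count() - 1)
        pipe.to(f"cuda:{gpu_id}")
    return pipe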


if 'pipelines' in pn.state.cache:
    print(f"load from cache")
    pipelines = pn.state.cache['pipelines']
    pseudo_rand_gen = pn.state.cache['pseudo_rand_gen']
else:
    models = ['runwayml/stable-diffusion-v1-5', 
              'CompVis/stable-diffusion-v1-4'
             ]
    
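    with exec_time("Load models"):
        pipelines = dict()
        # half precision halves the memory of the weights on GPU; on CPU use float32,
        # since many float16 operations may not be supported there
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        for m in models:
            try:
                # load from the local Hugging Face cache only
                pipelines[m] = init_model(m, torch_dtype=dtype)
            except OSError:
                # not cached yet: download the model from the Hugging Face Hub
                pipelines[m] = init_model(m, torch_dtype=dtype, local_files_only=False)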
            
    if torch.cuda.is_available():
        pseudo_rand_gen = torch.Generator(device='cuda')
    else:
        pseudo_rand_gen = torch.Generator()

    pn.state.cache['pipelines'] = pipelines
    pn.state.cache['pseudo_rand_gen'] = pseudo_rand_gen
    print(f"Save to cache")

default_model = next(iter(pipelines))
    
def generate_image(
    prompt,
    negative_prompt=None,
    model=default_model,
    height=512,
    width=512,
    guidance_scale=7.5,
    num_steps=30,
    seed=None,
):
    pipe = pipelines[model]
    
    if not seed or seed < random_int_range[0]:
        seed = random.randint(*random_int_range)
    
    generator = pseudo_rand_gen.manual_seed(seed)
    res = pipe(prompt=prompt,
               negative_prompt=negative_prompt,
               guidance_scale=guidance_scale,
               height=height,
               width=width,
               num_inference_steps=num_steps,
               generator=generator,
              )
    return res.images[0], seed

The Panel widgets in the next code cell control image generation. When rendered with a template, the sidebar should ideally start out collapsed, with only the `Prompt` text box visible.

A user writes a prompt and hits Enter, which triggers the callback that invokes the image generation function. Opening the sidebar reveals more options. A user can set them and then click `Generate Image`, or hit Enter in the prompt box, to create an image with those options. If the prompt has not changed, hitting Enter does not generate a new image; use the `Generate Image` button to create new images with the same prompt.

Each option is described below.


__Prompt__: Enter the text you want to turn into an image. Some examples:

  1. Wildflowers on a mountain side
  1. A dream of a distant planet, with multiple moons
  1. Valley of flowers in the Himalayas
  
__Negative Prompt__: Describes what the model should steer away from. For instance, with the first example prompt above, adding `yellow` as a negative prompt should remove the yellow flowers.

__Pretrained Model__: The pretrained models, downloaded from Hugging Face, that are used for inference.

__Height, Width__: Height and width of the image in pixels. The UI offers multiples of 64 from 448 to 1024; 512 matches the training resolution.

__Guidance Scale__: Also known as the CFG (classifier-free guidance) scale. Values between 7 and 8.5 are typical; the UI allows 5 to 10. As you increase this value, the model follows the prompt more closely, at the expense of image quality and diversity.

__# of steps__: The number of denoising steps taken by the model. More steps give a more refined image, but the image takes longer to generate.

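The random seed used to create the initial noise is set randomly for each image. It is not exposed as a widget. However, it is saved together with the other settings in the page URL and in the gallery history. An earlier image can therefore be regenerated by reopening its URL or clicking its thumbnail.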
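The Guidance Scale and Negative Prompt options both act through classifier-free guidance. At every denoising step, the model makes two noise predictions:
- one conditioned on the prompt;
- one "unconditional" prediction, which uses an empty prompt or, if set, the negative prompt.

In the diffusers Stable Diffusion pipeline, these are combined as `noise_uncond + guidance_scale * (noise_text - noise_uncond)`. Guidance is skipped entirely when `guidance_scale` is 1 or less. The sketch below applies the same formula element-wise to plain Python lists standing in for tensors; `apply_cfg` is a hypothetical name.

```python
def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance on flat lists of noise predictions.

    noise_uncond: prediction for the empty (or negative) prompt
    noise_text:   prediction for the user's prompt
    """
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


uncond = [0.0, 1.0, -1.0]
text = [0.5, 1.0, 0.0]

print(apply_cfg(uncond, text, 1.0))  # → [0.5, 1.0, 0.0]: just the prompt-conditioned prediction
print(apply_cfg(uncond, text, 7.5))  # → [3.75, 1.0, 6.5]: the prompt's direction is amplified
```

Raising the scale pushes the prediction further along the difference between the prompt and the unconditional (or negative) prediction. This explains both why higher values follow the prompt more closely and why a negative prompt pushes the image away from what it describes.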

In [None]:
class Gallery(ListLike, ReactiveHTML):
    
    objects = param.List(item_type=Viewable)
   
    current = param.Integer(default=None)
    
    margin = param.Integer(0)

    _template = """
    <div id="gallery" style="display: flex; flex-direction: row;">
    {% for img in objects %}
      <div id="img" name="{{ img.name }}" onclick=${script('click')}>${img}</div>
    {% endfor %}
    </div>
    """
    
    _scripts = {
        'click': """
          const id = event.target.parentNode.parentNode.parentNode.id;
          data.current = Number(id.split('-')[1]);
          """
    }

class StableDiffusionUI(Viewer):
    
    prompt = param.String(label='Prompt')
    
    neg_prompt = param.String(label='Negative Prompt')
    
    model = param.Selector(objects=list(pipelines), default=default_model)
    
    _size_range = tuple(448 + i*2**6 for i in range(10))
    width = param.Selector(_size_range, default=_size_range[1])
    
    height = param.Selector(_size_range, default=_size_range[1])
    
    guidance_scale = param.Number(bounds=(5, 10), step=0.1, default=7.5)
    
    num_steps = param.Integer(label='# of steps', bounds=(10, 75), default=30)
    
    gallery = param.ClassSelector(class_=Gallery, default=Gallery(min_height=100), precedence=-1)
    
    seed = param.Integer(
        default=random.randint(*random_int_range), bounds=random_int_range, step=10,
        precedence=-1)
    
    generate = param.Event(precedence=1)
    
    def __init__(self, **params):
        self.history = []
        super().__init__(name=params.pop('name', 'Stable Diffusion with Panel UI'), **params)
        self.gallery.param.watch(self._restore_history, 'current')
        self._restore = False
        self._on_load()

    @contextmanager
    def _toggle(self, attr, value):
        # set a bool attribute to `value` inside the context and always restore
        # the original state afterwards, even if the code inside raises an exception
        init_state = getattr(self, attr)
        setattr(self, attr, value)
        try:
            yield
        finally:
            setattr(self, attr, init_state)
        

    def _restore_history(self, event):
        if event.new is None:
            return
        self.gallery.current = None
        self.param.update(self.history[event.new])
        print(f"before with - self._restore: {self._restore}")
        with self._toggle('_restore', value=True):
            print(f"before trigger - self._restore: {self._restore}")
            self.param.trigger('generate')
        print(f"after with - self._restore: {self._restore}")
 
    @property
    def _current_state(self):
        return {'prompt': self.prompt,
                'negative_prompt': self.neg_prompt,
                'model': self.model,
                'height': self.height,
                'width': self.width,
                'guidance_scale': self.guidance_scale,
                'num_steps': self.num_steps,
                'seed': self.seed}
    
#     def _on_load(self):
#         self._generate(restore=pn.state.location.query_params)

    
#     @param.depends('generate', watch=True)
#     def _generate(self, restore=None):
#         if restore:
#             self.param.update(restore)
#         else:
#             self.seed = random.randint(*self.param.seed.bounds)

#         state = self._current_state
#         image, _ = generate_image(**state)
#         pn.state.location.update_query(**state)
    
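    def _on_load(self):
        if pn.state.location.query_params:
            # the URL stores `negative_prompt` (the argument name generate_image expects);
            # map it back to the `neg_prompt` parameter before updating
            params = dict(pn.state.location.query_params)
            if 'negative_prompt' in params:
                params['neg_prompt'] = params.pop('negative_prompt')
            # reuse _restore flag for now to decide whether to generate a new seed
            self.param.update(params)
            print(f"pn.state.location.query_params: {pn.state.location.query_params}")
            print(f"before with - self._restore: {self._restore}")

            with self._toggle('_restore', value=True):
                print(f"before trigger - self._restore: {self._restore}")
                self.param.trigger('generate')
            print(f"after with - self._restore: {self._restore}")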

    @param.depends('generate')
    def image(self):
        if not self.prompt:
            return pn.pane.PNG(style={'border': '1px solid black'}, height=self.height, width=self.width)
        
        if not self._restore:
            print(f"In image - self.seed {self.seed}")
            self.seed = random.randint(*self.param.seed.bounds)

        current_state = self._current_state
        print(f"{current_state}")
        image, image_seed = generate_image(**current_state)

        if not self._restore:
            self.gallery.append(pn.pane.PNG(image.resize((100, 100))))
            self.history.append({
                p: v for p, v in self.param.values().items() if p not in ('name', 'gallery', 'generate')
            })

        pn.state.location.update_query(**self._current_state)
        return pn.pane.PNG(image, style={'border': '1px solid black'})

    def _sidebar_widgets(self):
        return [

            self.param.model,
            pn.Param(self.param.height, widgets={'height': pn.widgets.DiscreteSlider}),
            pn.Param(self.param.width, widgets={'width': pn.widgets.DiscreteSlider}),
            pn.Param(self.param.guidance_scale, 
                     widgets={'guidance_scale': 
                              {'formatter': PrintfTickFormatter(format='%.1f')}}),
            self.param.num_steps,
        ]

    def _main_widgets(self):
        return [
            pn.Row(
                pn.Column(self.param.prompt, self.param.neg_prompt, sizing_mode='stretch_width'),
                pn.Param(self.param.generate, 
                         widgets={'generate': {'button_type': 'success', 'height': 110, 'width': 30, 'name': 'Generate Image'}}),
            ),
            pn.Row(pn.param.ParamMethod(self.image, loading_indicator=True),
                   pn.Column(self.gallery))
        ]
    
    def __panel__(self):
        # Discrete slider for width, height: https://huggingface.co/blog/stable_diffusion
        return pn.Row(
            pn.Column(*self._sidebar_widgets()),
            pn.Column(*self._main_widgets(), sizing_mode='stretch_width')
        )

    
sdui = StableDiffusionUI()

sdui
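The `image` method writes the current settings, including the seed, to the page URL with `pn.state.location.update_query`. When the page is opened, `_on_load` reads them back, so a URL is enough to regenerate an image.

Below is a standard-library sketch of that round trip. It assumes the values travel as plain query-string text and are converted back to their types on load. The helper names and the `TYPES` mapping are hypothetical, not Panel API. `keep_blank_values=True` keeps an empty negative prompt from being dropped.

```python
from urllib.parse import parse_qsl, urlencode

# Hypothetical mapping of setting names to the types they should be restored as
TYPES = {"prompt": str, "negative_prompt": str, "model": str, "height": int,
         "width": int, "guidance_scale": float, "num_steps": int, "seed": int}


def state_to_query(state: dict) -> str:
    """Encode the generation settings as a URL query string."""
    return urlencode(state)


def query_to_state(query: str) -> dict:
    """Decode a query string, converting each value back to its expected type."""
    return {k: TYPES.get(k, str)(v) for k, v in parse_qsl(query, keep_blank_values=True)}


state = {"prompt": "Wildflowers on a mountain side", "negative_prompt": "yellow",
         "model": "runwayml/stable-diffusion-v1-5", "height": 512, "width": 512,
         "guidance_scale": 7.5, "num_steps": 30, "seed": 123456}

query = state_to_query(state)
print(query)
print(query_to_state(query) == state)  # → True: all settings, including the seed, survive
```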

### Use a template

Use a template to get a clean look and feel.

TODO: Start out with the sidebar collapsed.

In [None]:
# logo and description for the template sidebar
logo = """<a href="http://panel.pyviz.org">
           <img src="https://panel.pyviz.org/_static/logo_stacked.png"
            width=150 height=127 align="left" style="margin: 20px"></a>"""

desc = pn.pane.HTML("""
    The <a href="http://panel.pyviz.org">Panel</a> library from <a href="https://holoviz.org/">HoloViz</a> 
    lets you make widget-controlled apps. Here you can use the
    <a href="https://huggingface.co/docs/diffusers/index">diffusers</a> library to
    generate images from pretrained diffusion models. Panel is used to create the UI for the pipeline.""", width=250)

template = pn.template.MaterialTemplate(
    title=sdui.name,
)

template.sidebar.append(logo)
template.sidebar.append(desc.clone(width=300, margin=(20, 5)))
template.sidebar.append(pn.Column(*sdui._sidebar_widgets()))

template.main.append(pn.Column(*sdui._main_widgets(), sizing_mode='stretch_width'))

template.servable();