# Interactive version of `scripts/txt2img.py`

This notebook was tested using Jupyter Lab.

Uses Jupyter widgets (not Google Colab form fields) for compatibility on both Jupyter Lab and Colab, see: https://colab.research.google.com/notebooks/forms.ipynb.

To install widgets for Jupyter Lab, follow the instructions here: https://ipywidgets.readthedocs.io/en/latest/user_install.html#installing-in-jupyterlab-3-0.

## Setup code (run once)

In [1]:
# FIXME hack: the txt2img code expects to run from the base project directory,
# but this notebook lives inside /scripts
import os

# Only the directory *name* is checked, not that its parent really is the
# project root - hence the hack
cwd_name = os.path.basename(os.getcwd())
if cwd_name == 'scripts':
    os.chdir(os.pardir)

In [2]:
# Slightly modified version of: https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py

import argparse, os, sys, glob    
import torch    
import numpy as np    
from omegaconf import OmegaConf    
from PIL import Image    
from tqdm.notebook import tqdm, trange  # NOTE: updated for notebook
from itertools import islice    
from einops import rearrange    
from torchvision.utils import make_grid    
import time    
from pytorch_lightning import seed_everything    
from torch import autocast    
from contextlib import contextmanager, nullcontext    
    
from ldm.util import instantiate_from_config    
from ldm.models.diffusion.ddim import DDIMSampler    
from ldm.models.diffusion.plms import PLMSSampler
from scripts.txt2img import chunk, load_model_from_config

from IPython.display import clear_output


def load_model(opt):
    """Load the model described by ``opt`` (kept separate from inference).

    When ``opt.laion400m`` is set, ``opt.config``, ``opt.ckpt`` and
    ``opt.outdir`` are overwritten in place with the LAION 400M paths before
    anything is loaded.

    Returns the model, moved to CUDA when available and to the CPU otherwise.
    """
    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        laion_overrides = (
            ("config", "configs/latent-diffusion/txt2img-1p4B-eval.yaml"),
            ("ckpt", "models/ldm/text2img-large/model.ckpt"),
            ("outdir", "outputs/txt2img-samples-laion400m"),
        )
        for attr_name, attr_value in laion_overrides:
            setattr(opt, attr_name, attr_value)

    model_config = OmegaConf.load(f"{opt.config}")
    loaded = load_model_from_config(model_config, f"{opt.ckpt}")

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return loaded.to(target_device)


def run_inference(opt, model):
    """Sample images for the prompt(s) in ``opt`` with an already-loaded model.

    This is the sampling part of ``scripts/txt2img.py`` with the model loading
    split out, slightly modified to optionally display the image grid inline.

    Args:
        opt: options object exposing the txt2img settings as attributes
            (``prompt``, ``outdir``, ``n_samples``, ``scale``, ...).
        model: the model to sample from (``run`` passes the result of
            ``load_model``).

    Side effects:
        Individual samples are written to ``<outdir>/samples`` unless
        ``opt.skip_save`` is set. Unless ``opt.skip_grid`` is set, a grid
        image is written to ``<outdir>`` and, when ``opt.display_inline`` is
        set, shown in the notebook output.
    """
    seed_everything(opt.seed)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    # n_rows <= 0 means "use the batch size"
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        # a single batch that repeats the prompt batch_size times
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    # continue numbering after whatever is already on disk
    base_count = len(os.listdir(sample_path))
    # -1 discounts the "samples" sub-directory created above
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    if opt.fixed_code:
        # Resolve the device here, using the same rule as load_model: no
        # `device` name exists in this scope otherwise, so this branch used to
        # fail with a NameError.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    precision_scope = autocast if opt.precision=="autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                all_samples = list()
                for _ in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        # the unconditional (empty prompt) conditioning is only
                        # requested when the guidance scale is not 1.0
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=opt.n_samples,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=opt.scale,
                                                         unconditional_conditioning=uc,
                                                         eta=opt.ddim_eta,
                                                         x_T=start_code)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        # map from [-1, 1] to [0, 1], clamping anything outside
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}.png"))
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    grid_image = Image.fromarray(grid.astype(np.uint8))
                    grid_image.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                    # display
                    if opt.display_inline:
                        # imported explicitly so this does not depend on the
                        # name IPython injects into interactive sessions
                        from IPython.display import display

                        clear_output()
                        display(grid_image)

    print(f"Your samples have been saved to: \n{outpath} \n"
          f" \nEnjoy.")


def run(opt):
    """Run inference for ``opt``, (re)loading the model only when it changed.

    The loaded model is cached in module globals together with the config and
    checkpoint paths it was built from; it is reloaded only when either path
    differs from the cached one.

    Args:
        opt: options object exposing the txt2img settings as attributes. It
            must allow attribute assignment, because the LAION 400M fallback
            overwrites ``config``, ``ckpt`` and ``outdir`` on it.
    """

    print(f"Creating image ({opt.H},{opt.W}) from prompt:\n\"{opt.prompt}\"\n")

    # FIXME global hack
    global last_config
    global last_ckpt
    global model

    if opt.laion400m:
        # Apply the LAION 400M fallback *before* comparing with the cached
        # paths. load_model applies the same override, but only once it is
        # called; comparing the untouched values meant that enabling the
        # option after a model was already loaded neither reloaded the model
        # nor switched the output directory.
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    if (opt.config != last_config) or (opt.ckpt != last_ckpt):
        model = load_model(opt)
        # FIXME global hack
        last_config = opt.config
        last_ckpt = opt.ckpt

    run_inference(opt, model)


# FIXME global hack: config/checkpoint paths of the currently loaded model.
# Empty strings mean that nothing has been loaded yet.
last_config, last_ckpt = "", ""

In [3]:
# Code to turn kwargs into Jupyter widgets
import ipywidgets as widgets
from collections import OrderedDict


def get_widget_extractor(widget_dict):
    """Wrap a mapping of widgets so ``opt.name`` yields ``widget_dict[name].value``.

    This lets widget values be read with the attribute syntax used for argparse
    results, which keeps the diff against the argparse code small.

    Args:
        widget_dict: mapping from option name to an object with a ``value``
            attribute.

    Returns:
        An ``OrderedDict`` subclass instance holding the same items. Attribute
        access on it returns the current ``value`` of the matching widget;
        unknown names raise ``AttributeError``.
    """
    class WidgetDict(OrderedDict):
        def __getattr__(self, val):
            # Only called when regular attribute lookup fails, so attributes
            # assigned on the instance take precedence over the widgets.
            try:
                widget = self[val]
            except KeyError as err:
                # The attribute protocol (hasattr, getattr with a default, ...)
                # relies on AttributeError; a KeyError would leak through it.
                raise AttributeError(val) from err
            return widget.value
    return WidgetDict(widget_dict)


# Allows long widget descriptions
style = {'description_width': 'initial'}
# Force widget width to max
layout = widgets.Layout(width='100%')


def _option_widget(widget_cls, description, value, **extra):
    """Create one enabled option widget with the shared layout and style."""
    return widget_cls(
        layout=layout, style=style,
        description=description,
        value=value,
        disabled=False,
        **extra
    )


def _text(description, value):
    """Free-text option."""
    return _option_widget(widgets.Text, description, value)


def _int(description, value):
    """Integer option."""
    return _option_widget(widgets.IntText, description, value)


def _float(description, value):
    """Floating point option."""
    return _option_widget(widgets.FloatText, description, value)


def _flag(description, value):
    """Boolean option rendered as a non-indented checkbox."""
    return _option_widget(widgets.Checkbox, description, value, indent=False)


# args from argparse converted to widgets:
# https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py#L48-L177
# The insertion order below is the order in which the options are rendered.
widget_opt = OrderedDict([
    # script default prompt: "a painting of a virus monster playing guitar"
    # (the README default is used instead)
    ('prompt', _text('the prompt to render',
                     "a photograph of an astronaut riding a horse")),
    ('outdir', _text('dir to write results to', "outputs/txt2img-samples")),
    ('skip_grid', _flag(
        'do not save a grid, only individual samples. Helpful when evaluating lots of samples',
        False)),
    ('skip_save', _flag('do not save individual samples. For speed measurements.', False)),
    ('ddim_steps', _int('number of ddim sampling steps', 50)),
    ('plms', _flag('use plms sampling', True)),
    ('laion400m', _flag('uses the LAION400M model', False)),
    ('fixed_code', _flag('if enabled, uses the same starting code across samples', False)),
    ('ddim_eta', _float('ddim eta (eta=0.0 corresponds to deterministic sampling', 0.0)),
    ('n_iter', _int('sample this often', 2)),
    ('H', _int('image height, in pixel space', 512)),
    ('W', _int('image width, in pixel space', 512)),
    ('C', _int('latent channels', 4)),
    ('f', _int('downsampling factor', 8)),
    ('n_samples', _int(
        'how many samples to produce for each given prompt. A.k.a. batch size', 3)),
    ('n_rows', _int('rows in the grid (default: n_samples)', 0)),
    ('scale', _float(
        'unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
        7.5)),
    ('from_file', _text('if specified, load prompts from this file', None)),
    ('config', _text('path to config which constructs model',
                     "configs/stable-diffusion/v1-inference.yaml")),
    ('ckpt', _text('path to checkpoint of model',
                   "models/ldm/stable-diffusion-v1/model.ckpt")),
    ('seed', _int('the seed (for reproducible sampling)', 42)),
    ('precision', _option_widget(
        widgets.Combobox, 'evaluate at this precision', "autocast",
        options=["full", "autocast"])),
    # Extra option for the notebook
    ('display_inline', _flag(
        'display output images inline (in addition to saving them)', True)),
])

# Output area that receives everything a run prints or displays
run_button_out = widgets.Output()


def on_run_button_click(b):
    """Button callback: clear the previous output, then run with the current widget values."""
    with run_button_out:
        clear_output()
        current_opt = get_widget_extractor(widget_opt)
        run(current_opt)


# Button that runs the generation with the current settings.
# Alternatively, you can just run the following in a new cell:
# run(get_widget_extractor(widget_opt))
run_button = widgets.Button(
    description='CLICK TO DREAM',
    tooltip='Click to run (settings will update automatically)',
    icon='check',
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    disabled=False
)
run_button.on_click(on_run_button_click)

# Package into box and render
#primary_options = ['prompt', 'outdir']  # options to put up top
#secondary_options = [k for k in widget_opt.keys() if k not in primary_options]  # rest, ordered by insertion

# Options that decide which model is loaded; they are rendered last
load_options = ['config', 'ckpt']
# Everything else, in insertion order
inference_options = [name for name in widget_opt if name not in load_options]

# make sure we didn't miss any options
rendered_options = inference_options + load_options
assert all(name in rendered_options for name in widget_opt)

# Package into box for rendering: options first, then the run button and its output
option_widgets = [widget_opt[name] for name in rendered_options]
gui = widgets.VBox(option_widgets + [run_button, run_button_out])

## Interactive loop

Run the next cell once to display the GUI, change the options there, then click the run button - no need to re-run/display the GUI cell (the GUI will automatically update the variables)

You may get a warning (e.g. "Some weights of the model ...") the first time you run the cell when the model is first loaded.

In [4]:
# Render the widget panel built above (options, run button and output area)
display(gui)

VBox(children=(Text(value='a photograph of an astronaut riding a horse', description='the prompt to render', l…