# Stable Diffusion

## Environment

In [None]:
import os
import math
from datetime import datetime
from PIL import Image

# Colab form parameters: when USE_DRIVE is enabled, Google Drive is mounted at DRIVE_DIR.
USE_DRIVE = False #@param {type:'boolean'}
DRIVE_DIR = '/content/drive' #@param {type:'string'}
if USE_DRIVE:
    # google.colab is only imported when a Drive mount is actually requested.
    from google.colab import drive
    drive.mount(DRIVE_DIR)

%pip install -q torch diffusers accelerate transformers compel
%pip install -q omegaconf ipywidgets controlnet-aux mediapipe

def save_image(img, dir: str, tmp: str = '') -> str:
  """Save ``img`` as a timestamped PNG inside ``dir`` and return its path.

  Args:
    img: Object exposing ``save(path)`` (a PIL image in this notebook).
    dir: Output directory; it is created (with parents) when missing.
    tmp: Optional suffix; when non-empty the file is named
      ``<timestamp>-<tmp>.png``, otherwise ``<timestamp>.png``.

  Returns:
    The path of the written file.
  """
  # exist_ok=True replaces the separate exists() check, which could race with
  # another writer creating the directory in between.
  os.makedirs(dir, exist_ok=True)
  img_time = datetime.now().strftime('%Y%m%d%H%M%S')
  img_name = f'{img_time}-{tmp}' if tmp != '' else img_time
  img_path = os.path.join(dir, f'{img_name}.png')
  img.save(img_path)
  return img_path


def get_model_list(dir: str) -> dict:
  """Map checkpoint names to paths for the model files found in ``dir``.

  Only entries directly inside ``dir`` are inspected. An entry is kept when
  its extension (compared case-insensitively) is a known checkpoint
  extension. Entries are visited in sorted order so the resulting dict, and
  any dropdown built from it, has a stable order.

  Args:
    dir: Directory to scan; it must exist.

  Returns:
    Dict of ``file name without extension -> full path``. When two files
    share a stem, the one sorted last wins.
  """
  allowed_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
  items = {}
  for file in sorted(os.listdir(dir)):
    file_name, extension = os.path.splitext(file)
    # Lower-casing accepts files such as "model.SAFETENSORS".
    if extension.lower() in allowed_extensions:
      items[file_name] = os.path.join(dir, file)
  return items


def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB grid.

    The cell size is taken from the first image; every image is pasted at
    its cell's top-left corner without resizing.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or ``len(imgs) != rows * cols``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not imgs:
        raise ValueError('image_grid requires at least one image')
    if len(imgs) != rows * cols:
        raise ValueError(
            f'expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}'
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # i % cols is the column index, i // cols the row index.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def resize_and_crop_image(image, target_width=576, target_height=1024):
    """Scale ``image`` so it covers the target size, then center-crop to it.

    The aspect ratio is preserved. The image is scaled by the larger of the
    two width/height ratios, so both resized dimensions are at least as
    large as the target, and the surplus is trimmed equally from both sides.

    The previous implementation scaled by the smaller ratio, which made the
    resized image smaller than the target along one axis. The crop box then
    extended outside the image and the result was padded rather than cropped.

    Args:
        image: Object exposing ``size``, ``resize((w, h))`` and
            ``crop((left, top, right, bottom))`` (a PIL image in this notebook).
        target_width: Width of the returned image in pixels.
        target_height: Height of the returned image in pixels.

    Returns:
        The result of cropping the resized image to
        ``target_width`` x ``target_height``.
    """
    original_width, original_height = image.size

    # "Cover" scaling: pick the factor that satisfies the more demanding axis.
    scale = max(target_width / original_width, target_height / original_height)
    # max() guards against rounding leaving a dimension one pixel short.
    new_width = max(target_width, round(original_width * scale))
    new_height = max(target_height, round(original_height * scale))

    resized_image = image.resize((new_width, new_height))

    # Center the crop window; both offsets are >= 0 by construction.
    offset_x = (new_width - target_width) // 2
    offset_y = (new_height - target_height) // 2

    return resized_image.crop(
        (offset_x, offset_y, offset_x + target_width, offset_y + target_height)
    )


def batch_resize_and_crop(images, target_width = 576, target_height = 1024):
    """Apply ``resize_and_crop_image`` to every item of ``images``.

    Returns a new list with one result per input image, in input order.
    """
    processed = []
    for source in images:
        processed.append(resize_and_crop_image(source, target_width, target_height))
    return processed

In [5]:
#@title Define classes, methods
import os
import torch
from compel import Compel
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetInpaintPipeline,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from controlnet_aux.processor import Processor

# Short ControlNet model name -> repository id.
# The keys are offered in the UI "Model" dropdown; the values are what gets
# passed to ControlNetModel.from_pretrained when a model is loaded.
CONTROLNET_REPOS = {
  'canny': 'lllyasviel/control_v11p_sd15_canny',
  'depth': 'lllyasviel/control_v11f1p_sd15_depth',
  'inpaint': 'lllyasviel/control_v11p_sd15_inpaint',
  'ip2p': 'lllyasviel/control_v11e_sd15_ip2p',
  'lineart': 'lllyasviel/control_v11p_sd15_lineart',
  'lineart_anime': 'lllyasviel/control_v11p_sd15s2_lineart_anime',
  'mlsd': 'lllyasviel/control_v11p_sd15_mlsd',
  'normalbae': 'lllyasviel/control_v11p_sd15_normalbae',
  'openpose': 'lllyasviel/control_v11p_sd15_openpose',
  'scribble': 'lllyasviel/control_v11p_sd15_scribble',
  'segmentation': 'lllyasviel/control_v11p_sd15_seg',
  'shuffle': 'lllyasviel/control_v11e_sd15_shuffle',
  'softedge': 'lllyasviel/control_v11p_sd15_softedge',
}

# Processor ids handed to controlnet_aux's Processor(...).
# Used as the option list of the preprocessor dropdowns in the UI.
AUX_PROCESSOR = [
    "canny",
    "depth_leres",
    "depth_leres++",
    "depth_midas",
    "depth_zoe",
    "lineart_anime",
    "lineart_coarse",
    "lineart_realistic",
    "mediapipe_face",
    "mlsd",
    "normal_bae",
    "normal_midas",
    "openpose",
    "openpose_face",
    "openpose_faceonly",
    "openpose_full",
    "openpose_hand",
    "scribble_hed",
    "scribble_pidinet",
    "shuffle",
    "softedge_hed",
    "softedge_hedsafe",
    "softedge_pidinet",
    "softedge_pidsafe",
    "dwpose",
]


class AnnotatorService:
    """Runs controlnet_aux preprocessors over control images.

    Loaded processors are kept in ``self.models`` so that a processor that is
    requested again is not constructed a second time.
    """

    def __init__(self):
        # processor id -> loaded Processor; the 'none' pass-through maps to None.
        self.models = {}

    def process(self, images: dict):
        """Preprocess every control image with the processor named by its key.

        Args:
            images: Mapping of processor id -> image path. The special id
                ``'none'`` means "use the image as is".

        Returns:
            List of RGB PIL images, one per entry of ``images`` in iteration
            order.
        """
        # Keep only the processors needed for this request, reusing the ones
        # that are already loaded. Previously any change in the key set or
        # order discarded and reloaded every processor.
        loaded = {}
        for name in images:
            if name == 'none':
                loaded[name] = None
            elif self.models.get(name) is not None:
                loaded[name] = self.models[name]
            else:
                loaded[name] = Processor(name)
        self.models = loaded

        result = []
        for name, path in images.items():
            # convert() forces the pixel data to be read, so the file can be
            # closed right away. RGB matches how the rest of the notebook
            # opens input images.
            with Image.open(path) as source:
                new_image = source.convert('RGB')
            model = self.models[name]
            if model is not None:
                new_image = model(new_image, to_pil=True)
            result.append(new_image)
        return result


class StableDiffusionService:
    """Owns the diffusers pipeline and swaps its parts on demand.

    ``ldm``, ``vae``, ``controlnets`` and ``task`` record what is currently
    loaded into ``pipe`` so that repeated requests can skip reloading.
    """

    def __init__(self):
        self.task = ''
        self.pipe = None
        self.compel = None
        self.ldm = ''
        self.vae = ''
        self.controlnets = []
        self.cfg = {'torch_dtype': torch.float16}
        # VAE that came with the loaded checkpoint, kept so that selecting
        # 'none' can restore it after a custom VAE was applied.
        self._base_vae = None

    def load_ldm(self, ldm: str):
        """Load the base checkpoint ``ldm`` unless it is already loaded.

        Args:
            ldm: Path of a single-file checkpoint, or any other identifier,
                which is passed to ``StableDiffusionPipeline.from_pretrained``.
        """
        if self.pipe is not None and self.ldm == ldm:
            print('LDM: from cache')
            return

        # NOTE(review): token=False is forwarded to the loader as in the
        # original code; presumably it disables use of a stored Hub token.
        cfg = {'torch_dtype': torch.float16, 'token': False}

        if os.path.isfile(ldm):
            self.pipe = StableDiffusionPipeline.from_single_file(ldm, **cfg)
        else:
            self.pipe = StableDiffusionPipeline.from_pretrained(ldm, **cfg)

        self.compel = Compel(self.pipe.tokenizer, self.pipe.text_encoder)
        self.pipe.safety_checker = None
        self.pipe.to('cuda')
        self.ldm = ldm
        self.task = 'txt2img'

        # The fresh pipeline is a plain txt2img pipeline with its own VAE and
        # no ControlNet. Reset the bookkeeping so the next load_vae /
        # load_controlnet call does not take a stale "already loaded" shortcut.
        self._base_vae = self.pipe.vae
        self.vae = ''
        self.controlnets = []

    def load_vae(self, vae: str):
        """Select the VAE used by the pipeline.

        Args:
            vae: ``'none'`` to use the checkpoint's own VAE, a single-file
                path, or any other identifier, which is passed to
                ``AutoencoderKL.from_pretrained``.
        """
        if not self.vae and vae == 'none':
            return

        if self.vae == vae:
            print('VAE: from cache')
            return

        if vae == 'none':
            # Restore the checkpoint's VAE. Deleting ``pipe.vae`` (as was done
            # before) leaves the pipeline unable to decode, and the stale
            # ``self.vae`` made later calls misbehave.
            if self._base_vae is not None:
                self.pipe.vae = self._base_vae
            self.vae = ''
            print('VAE: set None')
            return

        if os.path.isfile(vae):
            self.pipe.vae = AutoencoderKL.from_single_file(vae, **self.cfg)
        else:
            self.pipe.vae = AutoencoderKL.from_pretrained(vae, **self.cfg)

        self.vae = vae
        self.pipe.vae.to('cuda')

    def load_controlnet(self, task: str, controlnets: list):
        """Rebuild ``pipe`` for ``task`` with the given ControlNet repositories.

        The existing pipeline components are reused; only the ControlNet
        models are loaded when the requested list differs from the current
        one.

        Args:
            task: ``'txt2img'``, ``'img2img'`` or ``'inpaint'``.
            controlnets: Repository ids to load; an empty list removes any
                ControlNet from the pipeline.

        Raises:
            ValueError: If ``task`` is not one of the supported tasks.
        """
        if self.task == task and self.controlnets == controlnets:
            return

        if task not in ('txt2img', 'img2img', 'inpaint'):
            # Previously an unknown task silently kept the old pipeline while
            # recording the new task name.
            raise ValueError(f'Unknown task: {task!r}')

        components = self.pipe.components

        if self.controlnets != controlnets:
            if controlnets:
                models = [
                    ControlNetModel.from_pretrained(name, torch_dtype=torch.float16)
                    for name in controlnets
                ]
                components['controlnet'] = MultiControlNetModel(models)
                print(f'Controlnets: {",".join(controlnets)} have been loaded')
            else:
                components.pop('controlnet', None)

        if 'controlnet' in components:
            pipelines = {
                'txt2img': StableDiffusionControlNetPipeline,
                'img2img': StableDiffusionControlNetImg2ImgPipeline,
                'inpaint': StableDiffusionControlNetInpaintPipeline,
            }
        else:
            pipelines = {
                'txt2img': StableDiffusionPipeline,
                'img2img': StableDiffusionImg2ImgPipeline,
                'inpaint': StableDiffusionInpaintPipeline,
            }
        self.pipe = pipelines[task](**components)

        self.task = task
        # Copy so later mutation of the caller's list cannot corrupt the cache key.
        self.controlnets = list(controlnets)

    def set_scheduler(self, name: str):
        """Replace the pipeline scheduler with the sampler called ``name``.

        Every sampler is built from the current scheduler's config. Unknown
        names fall back to ``EulerAncestralDiscreteScheduler``.
        """
        # Built inside the method so the scheduler classes are only resolved
        # when a scheduler is actually selected.
        samplers = {
            "DPM++ 2M": (DPMSolverMultistepScheduler, {}),
            "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {'use_karras_sigmas': True}),
            "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {'algorithm_type': "sde-dpmsolver++"}),
            "DPM++ 2M SDE Karras": (
                DPMSolverMultistepScheduler,
                {'use_karras_sigmas': True, 'algorithm_type': "sde-dpmsolver++"},
            ),
            "DPM++ SDE": (DPMSolverSinglestepScheduler, {}),
            "DPM++ SDE Karras": (DPMSolverSinglestepScheduler, {'use_karras_sigmas': True}),
            "DPM2": (KDPM2DiscreteScheduler, {}),
            "DPM2 Karras": (KDPM2DiscreteScheduler, {'use_karras_sigmas': True}),
            "Euler": (EulerDiscreteScheduler, {}),
            "Euler a": (EulerAncestralDiscreteScheduler, {}),
            "Heun": (HeunDiscreteScheduler, {}),
        }
        scheduler_cls, options = samplers.get(name, (EulerAncestralDiscreteScheduler, {}))
        self.pipe.scheduler = scheduler_cls.from_config(self.pipe.scheduler.config, **options)

    def __call__(self, **kwargs):
        """Run the current pipeline with ``kwargs`` and return its ``images``."""
        return self.pipe(**kwargs).images


# Module-level singletons used by the UI callbacks in the later cells.
app = StableDiffusionService()
annotator = AnnotatorService()

In [3]:
#@title Define user interface
from ipywidgets import widgets
from IPython.display import display

class ControlnetUserInterface:
    """Widget group describing one ControlNet unit.

    Exposes an enable checkbox, the control image path, preprocessor and
    model dropdowns (each prefixed with a ``'none'`` option) and three
    0.0-1.0 sliders for weight and guidance start/stop.
    """

    def __init__(
        self,
        title: str,
        preprocessor: list = AUX_PROCESSOR,
        processor: list = list(CONTROLNET_REPOS.keys())
    ):
        self.title = title
        self.element_layout = widgets.Layout(width='initial')
        shared = {'layout': self.element_layout}

        self.enabled = widgets.Checkbox(description='Enabled', **shared)
        self.preprocessor = widgets.Dropdown(
            description='Preprocessor',
            options=['none'] + preprocessor,
            **shared,
        )
        self.processor = widgets.Dropdown(
            description='Model',
            options=['none'] + processor,
            **shared,
        )
        self.image = widgets.Text(
            description='Image',
            placeholder='/content/image.jpg',
            **shared,
        )
        self.weight = self._unit_slider('Weight', 1.0)
        self.start = self._unit_slider('Start Threshold', 0.0)
        self.end = self._unit_slider('Stop Threshold', 1.0)

    def _unit_slider(self, description: str, value: float):
        """Build a 0.0-1.0 slider (step 0.1) sharing this unit's layout."""
        return widgets.FloatSlider(
            description=description,
            value=value,
            min=0.0,
            max=1.0,
            step=0.1,
            layout=self.element_layout,
        )

    def widget(self):
        """Return the unit's controls stacked vertically under its title."""
        children = [
            widgets.Label(self.title),
            self.enabled,
            self.image,
            self.preprocessor,
            self.processor,
            self.weight,
            self.start,
            self.end,
        ]
        return widgets.VBox(children)


class StableDiffusionUserInterface:
    """Main generation form: checkpoint/VAE pickers, sampler settings, prompts
    and three ControlNet units."""

    def __init__(self, ldm_items: list, vae_items: list):
        """Create all widgets.

        Args:
            ldm_items: Options for the checkpoint dropdown.
            vae_items: Options for the VAE dropdown.
        """
        self.element_layout = widgets.Layout(width='initial')
        self.ldm = widgets.Dropdown(
            options=ldm_items,
            layout=self.element_layout,
        )
        self.ldm_path = widgets.Text(
            placeholder='runwayml/stable-diffusion-v1-5',
            layout=self.element_layout,
        )
        self.vae = widgets.Dropdown(
            options=vae_items,
            layout=self.element_layout,
        )
        self.vae_path = widgets.Text(
            placeholder='stabilityai/sd-vae-ft-ema',
            layout=self.element_layout,
        )
        self.scheduler = widgets.Dropdown(
            description='Scheduler',
            options=[
                'DPM++ 2M',
                'DPM++ 2M Karras',
                'DPM++ 2M SDE',
                'DPM++ 2M SDE Karras',
                'DPM++ SDE',
                'DPM++ SDE Karras',
                'DPM2',
                'DPM2 Karras',
                'Euler',
                'Euler a',
                'Heun',
            ],
            layout=self.element_layout,
        )
        self.steps = widgets.IntText(
            description='Steps',
            value=30,
            layout=self.element_layout,
        )
        self.cfg = widgets.FloatText(
            description='CFG',
            value=7.5,
            step=0.1,
            layout=self.element_layout,
        )
        self.width = widgets.IntText(
            description='Width',
            value=576,
            layout=self.element_layout,
        )
        self.height = widgets.IntText(
            description='Height',
            value=1024,
            layout=self.element_layout,
        )
        self.seed = widgets.IntText(
            description='Seed',
            value=-1,
            layout=self.element_layout,
        )
        self.batch_size = widgets.IntText(
            description='Batch Size',
            value=1,
            layout=self.element_layout,
        )
        self.task = widgets.Dropdown(
            description='Task',
            options=['txt2img', 'img2img', 'inpaint'],
            layout=self.element_layout,
        )
        self.prompt = widgets.Textarea(
            description='Prompt',
            rows=5,
            value='best quality, highly detailed',
            layout=self.element_layout,
        )
        self.negative_prompt = widgets.Textarea(
            description='Negative Prompt',
            rows=5,
            value='low quality, watermark, logo, blurry, monochrome',
            layout=self.element_layout,
        )
        # The three widgets below start hidden; on_task_change reveals them
        # for the tasks that need them.
        self.input_image = widgets.Text(
            description='Input Image',
            placeholder='/content/example.jpg',
            layout=widgets.Layout(width='initial', display='none'),
        )
        self.input_mask = widgets.Text(
            description='Input Mask',
            placeholder='/content/example.jpg',
            layout=widgets.Layout(width='initial', display='none'),
        )
        self.denoising_strength = widgets.FloatSlider(
            description='Denoising Strength',
            value=0.75,
            min=0.0,
            max=1.0,
            step=0.05,
            layout=widgets.Layout(width='initial', display='none'),
        )
        self.generate_button = widgets.Button(
            description='Generate',
            button_style='primary',
        )
        self.clear_button = widgets.Button(
            description='Clear Output',
        )

        self.controlnets = [
            ControlnetUserInterface('Controlnet 1'),
            ControlnetUserInterface('Controlnet 2'),
            ControlnetUserInterface('Controlnet 3'),
        ]
        self.task.observe(self.on_task_change, names='value')

    def on_task_change(self, change):
        """Show or hide the image, mask and strength inputs for the new task.

        ``change['new']`` is the newly selected task value.
        """
        self.input_image.layout.display = 'none' if change['new'] == 'txt2img' else ''
        self.denoising_strength.layout.display = 'none' if change['new'] == 'txt2img' else ''
        self.input_mask.layout.display = 'none' if change['new'] != 'inpaint' else ''

    def get_inputs(self) -> dict:
        """Collect the current widget values into one dict.

        ControlNet units contribute only when they are enabled and have a
        model selected. An enabled unit whose model is ``'none'`` is ignored,
        because it cannot be resolved to a ControlNet repository. A unit
        without its own image falls back to the main input image.

        Returns:
            Dict of generation settings. ``control_image`` maps preprocessor
            id -> image path; ``control_processor`` and the
            ``controlnet_conditioning_scale`` / ``control_guidance_*`` lists
            hold one entry per contributing unit, in unit order.

        Raises:
            ValueError: If two contributing units select the same
                preprocessor. ``control_image`` is keyed by preprocessor, so
                the second unit would overwrite the first and leave fewer
                control images than ControlNet models.
        """
        control_image = {}
        control_processor = []
        control_weight = []
        control_start = []
        control_end = []

        for control in self.controlnets:
            if not control.enabled.value or control.processor.value == 'none':
                continue

            preprocessor = control.preprocessor.value
            if preprocessor in control_image:
                raise ValueError(
                    f"{control.title}: preprocessor '{preprocessor}' is already used by "
                    'another enabled ControlNet unit; each enabled unit needs a distinct '
                    'preprocessor'
                )

            control_image[preprocessor] = control.image.value or self.input_image.value
            control_processor.append(control.processor.value)
            control_weight.append(control.weight.value)
            control_start.append(control.start.value)
            control_end.append(control.end.value)

        return {
            'task': self.task.value,
            'ldm': self.ldm.value,
            'ldm_path': self.ldm_path.value,
            'vae': self.vae.value,
            'vae_path': self.vae_path.value,
            'prompt': self.prompt.value,
            'negative_prompt': self.negative_prompt.value,
            'guidance_scale': self.cfg.value,
            'num_inference_steps': self.steps.value,
            'num_images_per_prompt': self.batch_size.value,
            'width': self.width.value,
            'height': self.height.value,
            'seed': self.seed.value,
            'image': self.input_image.value,
            'mask_image': self.input_mask.value,
            'control_image': control_image,
            'control_processor': control_processor,
            'controlnet_conditioning_scale': control_weight,
            'control_guidance_start': control_start,
            'control_guidance_end': control_end,
            'scheduler': self.scheduler.value,
            'strength': self.denoising_strength.value,
        }

    def widget(self):
        """Assemble and return the complete form.

        Layout: a header row with the checkpoint and VAE pickers, then a 2x3
        grid with sampler settings and prompts in the first row and one
        ControlNet unit per column in the second row.
        """
        buttons = widgets.HBox(
            [self.generate_button, self.clear_button],
            layout=widgets.Layout(
                width='100%',
                display='flex',
                justify_content='center',
            )
        )

        header = widgets.HBox([
            widgets.VBox([
                widgets.Label('Checkpoint'),
                self.ldm,
                self.ldm_path,
            ], layout=widgets.Layout(width='inherit')),
            widgets.VBox([
                widgets.Label('VAE'),
                self.vae,
                self.vae_path,
            ], layout=widgets.Layout(width='inherit')),
        ], layout=widgets.Layout(width='100%', margin='0px 0px 16px 0px')
        )

        body = widgets.GridspecLayout(2, 3, grid_gap='8px')
        body[0, 0] = widgets.VBox([
            self.scheduler,
            self.steps,
            self.cfg,
            self.width,
            self.height,
            self.seed,
            self.batch_size,
        ], layout=widgets.Layout(width='100%'))

        body[0, 1:] = widgets.VBox([
            self.task,
            self.input_image,
            self.input_mask,
            self.prompt,
            self.negative_prompt,
            self.denoising_strength,
            buttons,
        ])

        body[1, 0] = self.controlnets[0].widget()
        body[1, 1] = self.controlnets[1].widget()
        body[1, 2] = self.controlnets[2].widget()

        layout = widgets.VBox([
            header,
            body,
        ])

        return layout


In [None]:
#@title Download (Optional)
from torch.hub import download_url_to_file

# Colab form parameters: where to store the file, what to fetch, and the
# file name to store it under.
DOWNLOAD_DIR = '/content/drive/MyDrive/sd/vae' #@param {type:'string'}
DOWNLOAD_URL = 'https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors' #@param {type:'string'}
DOWNLOAD_OUT = 'vae-ft-mse-840000-ema-pruned.safetensors' #@param {type:'string'}

# Only touch the filesystem and network when every field is filled in.
# Creating the directory before this check crashed on an empty DOWNLOAD_DIR.
if DOWNLOAD_DIR and DOWNLOAD_URL and DOWNLOAD_OUT:
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    download_url_to_file(DOWNLOAD_URL, f'{DOWNLOAD_DIR}/{DOWNLOAD_OUT}')

## Generate

In [None]:
# @title Generate { display-mode: "form" }
import random

# Colab form parameters: output folder and the folders scanned for
# checkpoint / VAE files.
OUT_DIR = '/content/out' #@param {type:'string'}
LDM_DIR = '/content/drive/MyDrive/sd/ldm' #@param {type:'string'}
VAE_DIR = '/content/drive/MyDrive/sd/vae' #@param {type:'string'}

# A missing folder simply yields no entries.
ldm_checkpoints = get_model_list(LDM_DIR) if os.path.exists(LDM_DIR) else {}
vae_checkpoints = get_model_list(VAE_DIR) if os.path.exists(VAE_DIR) else {}

# The checkpoint dropdown shows a 'none' placeholder when nothing was found;
# the VAE dropdown always starts with 'none'.
ui = StableDiffusionUserInterface(
    [*ldm_checkpoints] or ['none'],
    ['none', *vae_checkpoints],
)
applog = widgets.Output()

def on_click_generate_button(b):
    """Handle the Generate button.

    Reads the form, loads the checkpoint, VAE, ControlNets and scheduler,
    runs the pipeline, saves every generated image to ``OUT_DIR`` and shows
    a grid of the results (plus the control images, if any). All output goes
    to ``applog``.

    Args:
        b: The clicked button (unused).
    """
    applog.clear_output()
    with applog:
        inputs: dict = ui.get_inputs()

        task = inputs.pop('task', 'txt2img')
        seed = inputs.pop('seed', -1)
        ldm = inputs.pop('ldm')
        ldm_path = inputs.pop('ldm_path')
        vae = inputs.pop('vae')
        vae_path = inputs.pop('vae_path')
        prompt = inputs.pop('prompt')
        negative_prompt = inputs.pop('negative_prompt')
        scheduler = inputs.pop('scheduler')
        control_processor = inputs.pop('control_processor', [])

        # An explicit path/repo wins over the dropdown. Using .get() means the
        # 'none' placeholder resolves to nothing instead of raising KeyError.
        ldm_source = ldm_path or ldm_checkpoints.get(ldm)
        if not ldm_source:
            print('require LDM checkpoint')
            return

        # Every contributing ControlNet unit needs a known model name.
        unknown_models = [name for name in control_processor if name not in CONTROLNET_REPOS]
        if unknown_models:
            print(f'require ControlNet model (got: {", ".join(unknown_models)})')
            return
        control_repos = [CONTROLNET_REPOS[name] for name in control_processor]

        app.load_ldm(ldm_source)
        app.load_vae(vae_path or vae_checkpoints.get(vae, 'none'))
        app.load_controlnet(task, control_repos)
        app.set_scheduler(scheduler)
        # Newly loaded ControlNet models are moved to the GPU here.
        app.pipe.to('cuda')

        inputs['prompt_embeds'] = app.compel(prompt)
        inputs['negative_prompt_embeds'] = app.compel(negative_prompt)

        # -1 means "pick a random seed"; the value is printed so the run can
        # be reproduced.
        calculate_seed = random.randint(0, 999999999) if seed == -1 else seed
        generator = torch.Generator(device="cuda").manual_seed(calculate_seed)
        inputs['generator'] = generator

        if inputs['control_image']:
            inputs['control_image'] = annotator.process(inputs['control_image'])
            inputs['control_image'] = batch_resize_and_crop(
                inputs['control_image'],
                inputs['width'],
                inputs['height'],
            )

        # In txt2img the control images are passed as ``image``.
        if task == 'txt2img' and control_processor:
            inputs['image'] = inputs['control_image']

        if task in ['img2img', 'inpaint']:
            inputs['image'] = [Image.open(inputs['image']).convert('RGB')]

        if task == 'inpaint':
            mask_path = inputs['mask_image']
            inputs['mask_image'] = [Image.open(mask_path).convert('RGB')]

        images = app(**inputs)
        if len(images) > 0:
            print(f"Prompt: {prompt}")
            print(f"Negative Prompt: {negative_prompt}")
            print(f"Seed: {calculate_seed}")
            for index, img in enumerate(images):
                img_path = save_image(img, OUT_DIR, f'{calculate_seed}-{index}')
                print(f'Saved: {img_path}')

            if inputs['control_image']:
                images = images + inputs['control_image']

            # Four columns when the count divides evenly, otherwise one row.
            total = len(images)
            row = 1 if total % 4 != 0 else total // 4
            col = total if row == 1 else 4
            grid = image_grid(images, row, col)
            display(grid)
        else:
            print('Empty!')

def on_click_clear_button(b):
    """Clear the generation log output; ``b`` (the clicked button) is unused."""
    applog.clear_output()

# Empty each button's callback list before registering, so each button ends
# up with exactly one handler. NOTE(review): _click_handlers is a private
# ipywidgets attribute.
ui.generate_button._click_handlers.callbacks = []
ui.generate_button.on_click(on_click_generate_button)
ui.clear_button._click_handlers.callbacks = []
ui.clear_button.on_click(on_click_clear_button)

# Render the form followed by the log output area.
display(ui.widget())
display(applog)

In [None]:
#@title Controlnet Annotators {display-mode:"form"}

from PIL import Image
from datetime import datetime
from ipywidgets import widgets
from IPython.display import display

# Colab form parameter: folder for annotator results, created if missing.
AUX_OUT_DIR = '/content/aux' #@param {type:'string'}
os.makedirs(AUX_OUT_DIR, exist_ok=True)

def aux_process(processor, input):
    """Run the controlnet_aux processor named ``processor`` on ``input``.

    A new ``Processor`` is constructed on every call and invoked with
    ``to_pil=True``; its result is returned unchanged.
    """
    return Processor(processor)(input, to_pil=True)

# Output area that receives the messages printed by the annotator callback.
aux_log = widgets.Output()

# Path of the image to annotate.
aux_input = widgets.Text(
  description='Input',
  placeholder='/content/input.jpg'
)

# Which controlnet_aux processor to run.
aux_processor = widgets.Dropdown(
  description='Processor',
  options=AUX_PROCESSOR,
)

aux_button = widgets.Button(
  description='Process',
  button_style='primary',
)

# Vertical form: processor, input path, button.
aux_layout = widgets.VBox([
  aux_processor,
  aux_input,
  aux_button,
])

def aux_on_click(b):
  """Handle the Process button: annotate the input image and save the result.

  The image at the path in ``aux_input`` is opened as RGB, run through the
  selected processor and written to ``AUX_OUT_DIR`` as a timestamped PNG.
  Messages are written to ``aux_log``.

  Args:
    b: The clicked button (unused).
  """
  aux_log.clear_output()
  with aux_log:
    source_path = aux_input.value
    if not os.path.exists(source_path):
      print(f'Input: {source_path} not found')
      return

    source = Image.open(source_path).convert('RGB')
    result = aux_process(aux_processor.value, source)
    output = f'{AUX_OUT_DIR}/{datetime.now().strftime("%Y%m%d-%H%M%S")}.png'
    # Image.save() returns None, so its return value must not gate the
    # message; a failed save raises instead.
    result.save(output)
    print(f'Saved: {output}')

# Empty the button's callback list before registering, so it ends up with
# exactly one handler. NOTE(review): _click_handlers is a private ipywidgets
# attribute.
aux_button._click_handlers.callbacks = []
aux_button.on_click(aux_on_click)
# Render the form followed by its log output area.
display(aux_layout)
display(aux_log)