diff --git a/.gitignore b/.gitignore index 09734267ff5..b7460b9a634 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ __pycache__ /repositories /venv /tmp +/cache +/footprints /model.ckpt /models/**/* /GFPGANv1.3.pth diff --git a/README.md b/README.md index dc1aeae97d7..73088e0afb9 100644 --- a/README.md +++ b/README.md @@ -185,4 +185,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- Olive - https://github.com/microsoft/Olive - (You) diff --git a/configs/olive_optimize/config_safety_checker.json b/configs/olive_optimize/config_safety_checker.json new file mode 100644 index 00000000000..b92fedab626 --- /dev/null +++ b/configs/olive_optimize/config_safety_checker.json @@ -0,0 +1,94 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "runwayml/stable-diffusion-v1-5", + "model_loader": "safety_checker_load", + "model_script": "modules/sd_olive_scripts.py", + "io_config": { + "input_names": [ "clip_input", "images" ], + "output_names": [ "out_images", "has_nsfw_concepts" ], + "dynamic_axes": { + "clip_input": { "0": "batch", "1": "channels", "2": "height", "3": "width" }, + "images": { "0": "batch", "1": "height", "2": "width", "3": "channels" } + } + }, + "dummy_inputs_func": "safety_checker_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": ["gpu"] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "modules/sd_olive_scripts.py", + "dataloader_func": "safety_checker_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 14 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "unet", + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false + }, + "force_fp32_ops": ["RandomNormalLike"] + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "joint", + "search_algorithm": "exhaustive" + }, + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "safety_checker", + "output_dir": "footprints", + "execution_providers": ["DmlExecutionProvider"] + } +} diff --git a/configs/olive_optimize/config_text_encoder.json b/configs/olive_optimize/config_text_encoder.json new file mode 100644 index 00000000000..a836f5d58ed --- /dev/null +++ b/configs/olive_optimize/config_text_encoder.json @@ -0,0 +1,91 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": 
"runwayml/stable-diffusion-v1-5", + "model_loader": "text_encoder_load", + "model_script": "modules/sd_olive_scripts.py", + "io_config": { + "input_names": [ "input_ids" ], + "output_names": [ "last_hidden_state", "pooler_output" ], + "dynamic_axes": { "input_ids": { "0": "batch", "1": "sequence" } } + }, + "dummy_inputs_func": "text_encoder_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": ["gpu"] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "modules/sd_olive_scripts.py", + "dataloader_func": "text_encoder_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 14 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "clip", + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false + }, + "force_fp32_ops": ["RandomNormalLike"] + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "joint", + "search_algorithm": "exhaustive" + }, + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "text_encoder", + "output_dir": "footprints", + "execution_providers": ["DmlExecutionProvider"] + } +} diff --git a/configs/olive_optimize/config_unet.json b/configs/olive_optimize/config_unet.json new file mode 100644 index 00000000000..28e9359787f --- /dev/null +++ b/configs/olive_optimize/config_unet.json @@ -0,0 +1,98 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "runwayml/stable-diffusion-v1-5", + "model_loader": "unet_load", + "model_script": "modules/sd_olive_scripts.py", + "io_config": { + "input_names": [ "sample", "timestep", "encoder_hidden_states", "return_dict" ], + "output_names": [ "out_sample" ], + "dynamic_axes": { + "sample": {"0": "unet_sample_batch", "1": "unet_sample_channels", "2": "unet_sample_height", "3": "unet_sample_width"}, + "timestep": {"0": "unet_time_batch"}, + "encoder_hidden_states": {"0": "unet_hidden_batch", "1": "unet_hidden_sequence"} + } + }, + "dummy_inputs_func": "unet_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": ["gpu"] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "modules/sd_olive_scripts.py", + "dataloader_func": "unet_data_loader", + "batch_size": 2 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 14, + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "external_data_name": "weights.pb" + } + }, + "optimize": { + "type": 
"OrtTransformersOptimization", + "config": { + "model_type": "unet", + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false + }, + "force_fp32_ops": ["RandomNormalLike"] + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "joint", + "search_algorithm": "exhaustive" + }, + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "unet", + "output_dir": "footprints", + "execution_providers": ["DmlExecutionProvider"] + } +} diff --git a/configs/olive_optimize/config_vae_decoder.json b/configs/olive_optimize/config_vae_decoder.json new file mode 100644 index 00000000000..cac8daee244 --- /dev/null +++ b/configs/olive_optimize/config_vae_decoder.json @@ -0,0 +1,91 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "runwayml/stable-diffusion-v1-5", + "model_loader": "vae_decoder_load", + "model_script": "modules/sd_olive_scripts.py", + "io_config": { + "input_names": [ "latent_sample", "return_dict" ], + "output_names": [ "sample" ], + "dynamic_axes": { "latent_sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } } + }, + "dummy_inputs_func": "vae_decoder_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": ["gpu"] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "modules/sd_olive_scripts.py", + "dataloader_func": "vae_decoder_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 14 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "vae", + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false + }, + "force_fp32_ops": ["RandomNormalLike"] + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "joint", + "search_algorithm": "exhaustive" + }, + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "vae_decoder", + "output_dir": "footprints", + "execution_providers": ["DmlExecutionProvider"] + } +} diff --git 
a/configs/olive_optimize/config_vae_encoder.json b/configs/olive_optimize/config_vae_encoder.json new file mode 100644 index 00000000000..cec7b0fb5c1 --- /dev/null +++ b/configs/olive_optimize/config_vae_encoder.json @@ -0,0 +1,91 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "runwayml/stable-diffusion-v1-5", + "model_loader": "vae_encoder_load", + "model_script": "modules/sd_olive_scripts.py", + "io_config": { + "input_names": [ "sample", "return_dict" ], + "output_names": [ "latent_sample" ], + "dynamic_axes": { "sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } } + }, + "dummy_inputs_func": "vae_encoder_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": ["gpu"] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "modules/sd_olive_scripts.py", + "dataloader_func": "vae_encoder_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 14 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "vae", + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false + }, + "force_fp32_ops": ["RandomNormalLike"] + } + } + }, + "engine": { + "search_strategy": { + "execution_order": "joint", + "search_algorithm": "exhaustive" + }, + "evaluator": "common_evaluator", + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "vae_encoder", + "output_dir": "footprints", + "execution_providers": ["DmlExecutionProvider"] + } +} diff --git a/html/licenses.html b/html/licenses.html index ef6f2c0a42b..424426146f9 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -687,4 +687,30 @@
+MIT License + +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE\ No newline at end of file diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 641b3101b0d..68255e78030 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -91,6 +91,7 @@ parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") +parser.add_argument("--olive", action='store_true', help="[Beta] Enable Olive optimization support.", default=False) parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--backend", type=str, help="Select the backend to be used. Default: 'auto'", choices=["cuda", "rocm", "directml", "auto"], default="auto") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 2616a410ef3..31244ab707f 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -232,6 +232,7 @@ def prepare_environment(): print("ROCm was found. Automatically changed backend to 'rocm'. 
You can manually select which backend will be used through '--backend' argument.") else: args.backend = 'directml' + cmd_args.parser.set_defaults(backend=args.backend) if args.backend == 'cuda': torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") if args.backend == 'rocm': @@ -240,6 +241,7 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 torch-directml") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") + requirements_file_olive = os.environ.get('REQS_FILE_OLIVE', "requirements_olive.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip") @@ -310,6 +312,15 @@ requirements_file = os.path.join(script_path, requirements_file) run_pip(f"install -r \"{requirements_file}\"", "requirements") + if args.olive: + if args.backend != 'directml': + print(f"Olive optimization requires the DirectML backend, but the current backend is '{args.backend}'. Try again with '--backend directml'.") + exit(0) + print("WARNING! Olive optimization does not support torch 2.0, so some packages will be downgraded and version mismatches between packages may occur. (It is strongly recommended to create a separate virtual environment for running Olive.)") + if not is_installed("olive-ai"): + run_pip("install olive-ai[directml]", "Olive") + run_pip(f"install -r \"{requirements_file_olive}\"", "requirements for Olive") + run_extensions_installers(settings_file=args.ui_settings_file) if args.update_check: diff --git a/modules/paths.py b/modules/paths.py index 5f6474c032a..d2c4aab4f76 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,5 +1,9 @@ import os import sys +try: + import olive.workflows +except Exception: + pass from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401 import modules.safe # noqa: F401 diff --git a/modules/processing.py b/modules/processing.py index 29a3743f578..ccf9e7f335a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -597,6 +597,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() + if cmd_opts.olive: + return p.process() + for k, v in p.override_settings.items(): setattr(opts, k, v) diff --git a/modules/sd_models.py b/modules/sd_models.py index b1afbaa7ffb..573881423e6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ import re import safetensors.torch from omegaconf import OmegaConf -from os import mkdir +from os import mkdir, listdir from urllib import request import ldm.modules.midas as midas @@ -19,7 +19,7 @@ from modules.timer import Timer import tomesd -model_dir = "Stable-diffusion" +model_dir = "ONNX-Olive" if shared.cmd_opts.olive else "Stable-diffusion" model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} @@ -45,9 +45,9 @@ def __init__(self, filename): self.name = name self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - self.hash = model_hash(filename) + self.hash = model_hash(filename) if not shared.cmd_opts.olive else None - self.sha256 = hashes.sha256_from_cache(self.filename, 
f"checkpoint/{name}") + self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}") if not shared.cmd_opts.olive else None self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' @@ -69,6 +69,9 @@ def register(self): checkpoint_alisases[id] = self def calculate_shorthash(self): + if shared.cmd_opts.olive: + return + self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") if self.sha256 is None: return @@ -121,7 +124,10 @@ def list_models(): else: model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors" - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + if shared.cmd_opts.olive: + model_list = [f for f in listdir(model_path) if os.path.isdir(os.path.join(model_path, f))] + else: + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) if os.path.exists(cmd_ckpt): checkpoint_info = CheckpointInfo(cmd_ckpt) @@ -447,6 +453,12 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): gc.collect() devices.torch_gc() + if shared.cmd_opts.olive: + from modules.sd_olive import OliveOptimizedModel + model_data.set_sd_model(OliveOptimizedModel(checkpoint_info.name)) + print(f"Model {model_data.sd_model.path} loaded.") + return model_data.sd_model + do_inpainting_hijack() timer = Timer() @@ -520,6 +532,10 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = model_data.sd_model + if shared.cmd_opts.olive: + model_data.set_sd_model(sd_model) + return sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -573,6 +589,9 @@ def unload_model_weights(sd_model=None, info=None): from modules import devices, sd_hijack timer = Timer() + if shared.cmd_opts.olive: + return sd_model + if model_data.sd_model: model_data.sd_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(model_data.sd_model) diff --git a/modules/sd_olive.py b/modules/sd_olive.py new file mode 100644 index 00000000000..9ebba1d8320 --- /dev/null +++ b/modules/sd_olive.py @@ -0,0 +1,469 @@ +import json +import torch +import numpy as np +import shutil +import inspect +import onnxruntime as ort +from pathlib import Path +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE + +from olive.model import ONNXModel +from olive.workflows import run as olive_run + +from modules import shared, images, devices +from modules.paths_internal import sd_configs_path, models_path +from modules.sd_olive_scripts import get_base_model_name +from modules.processing import Processed, get_fixed_seed + +def __call__( + self, + p, + prompt = None, + height = 512, + width = 512, + num_inference_steps = 50, + guidance_scale = 7.5, + negative_prompt = None, + num_images_per_prompt = 1, + eta = 0.0, + generator = None, + latents = None, + prompt_embeds = None, + negative_prompt_embeds = None, + output_type = "pil", + return_dict: bool = True, 
+ callback = None, + callback_steps: int = 1, +): + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + if shared.state.job_count == -1: + shared.state.job_count = p.n_iter + + if shared.state.skipped: + shared.state.skipped = False + + p.prompts = p.all_prompts[i * p.batch_size:(i + 1) * p.batch_size] + + if p.n_iter > 1: + shared.state.job = f"Batch {i+1} out of {p.n_iter}" + + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not 
None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) +OnnxStableDiffusionPipeline.__call__ = __call__ + +def optimize(model_id: str, unoptimized_dir: str, optimized_dir: str, olive_safety_checker, olive_text_encoder, olive_unet, olive_vae_decoder, olive_vae_encoder, use_fp16): + unoptimized_dir = Path(models_path) / "ONNX" / unoptimized_dir + optimized_dir = Path(models_path) / "ONNX-Olive" / optimized_dir + + shutil.rmtree("footprints", ignore_errors=True) + shutil.rmtree(unoptimized_dir, ignore_errors=True) + shutil.rmtree(optimized_dir, ignore_errors=True) + + base_model_id = get_base_model_name(model_id) + + pipeline = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32) + + submodels = [] + model_info = {} + + if olive_safety_checker: + submodels += ["safety_checker"] + if olive_text_encoder: + submodels += ["text_encoder"] + if olive_unet: + submodels += ["unet"] + if olive_vae_decoder: + submodels += ["vae_decoder"] + if olive_vae_encoder: + submodels += ["vae_encoder"] + + for submodel_name in submodels: + print(f"\nOptimizing {submodel_name}") + + olive_config = None + with open(Path(sd_configs_path) / "olive_optimize" / f"config_{submodel_name}.json", "r") as fin: + olive_config = json.load(fin) + + olive_config["passes"]["optimize"]["config"]["float16"] = use_fp16 + if submodel_name in ("unet", "text_encoder"): + olive_config["input_model"]["config"]["model_path"] = model_id + else: + olive_config["input_model"]["config"]["model_path"] = base_model_id + + olive_run(olive_config) + + footprints_file_path = ( + Path("footprints") / f"{submodel_name}_gpu-dml_footprints.json" + ) + with footprints_file_path.open("r") as footprint_file: + footprints = json.load(footprint_file) + + conversion_footprint = None + optimizer_footprint = None + for _, footprint in footprints.items(): + if footprint["from_pass"] == "OnnxConversion": + conversion_footprint = footprint + elif footprint["from_pass"] == "OrtTransformersOptimization": + optimizer_footprint = footprint + + assert conversion_footprint and optimizer_footprint + + unoptimized_olive_model = ONNXModel(**conversion_footprint["model_config"]["config"]) + optimized_olive_model = ONNXModel(**optimizer_footprint["model_config"]["config"]) + + model_info[submodel_name] = { + "unoptimized": { + "path": Path(unoptimized_olive_model.model_path), + }, + "optimized": { + "path": 
Path(optimized_olive_model.model_path), + }, + } + + print(f"Optimized {submodel_name}") + + print("\nCreating ONNX pipeline...") + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(model_info["vae_encoder"]["unoptimized"]["path"].parent), + vae_decoder=OnnxRuntimeModel.from_pretrained(model_info["vae_decoder"]["unoptimized"]["path"].parent), + text_encoder=OnnxRuntimeModel.from_pretrained(model_info["text_encoder"]["unoptimized"]["path"].parent), + tokenizer=pipeline.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(model_info["unet"]["unoptimized"]["path"].parent), + scheduler=pipeline.scheduler, + safety_checker=None if pipeline.feature_extractor is None else OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent), + feature_extractor=pipeline.feature_extractor, + requires_safety_checker=True, + ) + + print("Saving unoptimized models...") + onnx_pipeline.save_pretrained(unoptimized_dir) + + print("Copying optimized models...") + shutil.copytree(unoptimized_dir, optimized_dir, ignore=shutil.ignore_patterns("weights.pb")) + for submodel_name in submodels: + try: + src_path = model_info[submodel_name]["optimized"]["path"] + dst_path = optimized_dir / submodel_name / "model.onnx" + shutil.copyfile(src_path, dst_path) + except: + pass + + shared.refresh_checkpoints() + print(f"Optimization complete.") + + return f'Saved as {unoptimized_dir}', '' + +class OliveOptimizedModel: + def __init__(self, path: str): + self.path = path + self.sd_model_hash = None + +class OliveOptimizedProcessingTxt2Img: + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params = None, overlay_images = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = ''): + self.sd_model: Path = Path(models_path) / "ONNX-Olive" / sd_model.path + self.outpath_samples: str = outpath_samples + self.outpath_grids: str = outpath_grids + self.prompt: str = prompt + self.prompt_for_display: str = None + self.negative_prompt: str = (negative_prompt or "") + self.styles: list = styles or [] + self.seed: int = seed + self.subseed: int = subseed + self.subseed_strength: float = subseed_strength + self.seed_resize_from_h: int = seed_resize_from_h + self.seed_resize_from_w: int = seed_resize_from_w + self.sampler_name: str = sampler_name + self.batch_size: int = batch_size + self.n_iter: int = n_iter + self.steps: int = steps + self.cfg_scale: float = cfg_scale + 
self.width: int = width + self.height: int = height + self.restore_faces: bool = restore_faces + self.tiling: bool = tiling + self.do_not_save_samples: bool = do_not_save_samples + self.do_not_save_grid: bool = do_not_save_grid + self.extra_generation_params: dict = extra_generation_params or {} + self.overlay_images = overlay_images + self.eta = eta + self.do_not_reload_embeddings = do_not_reload_embeddings + self.paste_to = None + self.color_corrections = None + self.denoising_strength: float = denoising_strength + self.sampler_noise_scheduler_override = None + self.ddim_discretize = ddim_discretize or shared.opts.ddim_discretize + self.s_min_uncond = s_min_uncond or shared.opts.s_min_uncond + self.s_churn = s_churn or shared.opts.s_churn + self.s_tmin = s_tmin or shared.opts.s_tmin + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option + self.s_noise = s_noise or shared.opts.s_noise + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} + self.override_settings_restore_afterwards = override_settings_restore_afterwards + self.is_using_inpainting_conditioning = False + self.disable_extra_networks = False + self.token_merging_ratio = 0 + self.token_merging_ratio_hr = 0 + + if not seed_enable_extras: + self.subseed = -1 + self.subseed_strength = 0 + self.seed_resize_from_h = 0 + self.seed_resize_from_w = 0 + + self.scripts = None + self.script_args = script_args + self.all_prompts = None + self.all_negative_prompts = None + self.all_seeds = None + self.all_subseeds = None + self.iteration = 0 + self.is_hr_pass = False + self.sampler = None + + self.prompts = None + self.negative_prompts = None + self.seeds = None + self.subseeds = None + + self.step_multiplier = 1 + self.cached_uc = [None, None] + self.cached_c = [None, None] + self.uc = None + self.c = None + + if type(prompt) == list: + self.all_prompts = self.prompt + else: + self.all_prompts = self.batch_size * self.n_iter * [self.prompt] + + if type(self.negative_prompt) == list: + self.all_negative_prompts = self.negative_prompt + else: + self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt] + + self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] + self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] + + self.extra_generation_params: dict = {} + self.override_settings = {k: v for k, v in ({}).items() if k not in shared.restricted_opts} + self.override_settings_restore_afterwards = False + + self.enable_hr = enable_hr + self.denoising_strength = denoising_strength + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + self.hr_second_pass_steps = hr_second_pass_steps + self.hr_resize_x = hr_resize_x + self.hr_resize_y = hr_resize_y + self.hr_upscale_to_x = hr_resize_x + self.hr_upscale_to_y = hr_resize_y + self.hr_sampler_name = hr_sampler_name + self.hr_prompt = hr_prompt + self.hr_negative_prompt = hr_negative_prompt + self.all_hr_prompts = None + self.all_hr_negative_prompts = None + if firstphase_width != 0 or firstphase_height != 0: + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + self.width = firstphase_width + self.height = firstphase_height + + self.truncate_x = 0 + self.truncate_y = 0 + self.applied_old_hires_behavior_to = None + + self.hr_prompts = None + self.hr_negative_prompts = None + self.hr_extra_network_data = None + + self.hr_c = None + 
self.hr_uc = None + + self.sess_options = ort.SessionOptions() + self.sess_options.enable_mem_pattern = False + self.sess_options.add_free_dimension_override_by_name("unet_sample_batch", self.batch_size * 2) + self.sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4) + self.sess_options.add_free_dimension_override_by_name("unet_sample_height", 64) + self.sess_options.add_free_dimension_override_by_name("unet_sample_width", 64) + self.sess_options.add_free_dimension_override_by_name("unet_time_batch", 1) + self.sess_options.add_free_dimension_override_by_name("unet_hidden_batch", self.batch_size * 2) + self.sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + self.pipeline = OnnxStableDiffusionPipeline.from_pretrained( + self.sd_model, provider="DmlExecutionProvider", sess_options=self.sess_options + ) + + def process(self) -> Processed: + if type(self.prompt) == list: + assert(len(self.prompt) > 0) + else: + assert self.prompt is not None + + devices.torch_gc() + + seed = get_fixed_seed(self.seed) + subseed = get_fixed_seed(self.subseed) + + if type(seed) == list: + self.all_seeds = seed + else: + self.all_seeds = [int(seed) + (x if self.subseed_strength == 0 else 0) for x in range(len(self.all_prompts))] + + if type(subseed) == list: + self.all_subseeds = subseed + else: + self.all_subseeds = [int(subseed) + x for x in range(len(self.all_prompts))] + + output_images = [] + + result = self.pipeline(self, [self.prompt] * self.batch_size, num_inference_steps=self.steps) + + for n in range(self.n_iter): + image = result.images[n] + images.save_image(image, self.outpath_samples, "") + output_images.append(image) + + devices.torch_gc() + shared.state.nextjob() + + index_of_first_image = 0 + unwanted_grid_because_of_img_count = len(output_images) < 2 and shared.opts.grid_only_if_multiple + if (shared.opts.return_grid or shared.opts.grid_save) and not unwanted_grid_because_of_img_count: + grid = images.image_grid(output_images, self.batch_size) + + if shared.opts.return_grid: + output_images.insert(0, grid) + index_of_first_image = 1 + + if shared.opts.grid_save: + images.save_image(grid, self.outpath_grids, "grid", self.all_seeds[0], self.all_prompts[0], shared.opts.grid_format, short_filename=not shared.opts.grid_extended_filename, grid=True) + + devices.torch_gc() + + return Processed( + self, + images_list=output_images, + seed=self.all_seeds[0], + info="", + comments="", + subseed=self.all_subseeds[0], + index_of_first_image=index_of_first_image, + infotexts=[], + ) + + def close(self): + return diff --git a/modules/sd_olive_scripts.py b/modules/sd_olive_scripts.py new file mode 100644 index 00000000000..fae36d9b265 --- /dev/null +++ b/modules/sd_olive_scripts.py @@ -0,0 +1,237 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import torch +from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from huggingface_hub import model_info +from transformers.models.clip.modeling_clip import CLIPTextModel + + +# Helper latency-only dataloader that creates random tensors with no label +class RandomDataLoader: + def __init__(self, create_inputs_func, batchsize, torch_dtype): + self.create_input_func = create_inputs_func + self.batchsize = batchsize + self.torch_dtype = torch_dtype + + def __getitem__(self, idx): + label = None + return self.create_input_func(self.batchsize, self.torch_dtype), label + + +def get_base_model_name(model_name): + return model_info(model_name).cardData.get("base_model", model_name) + + +def is_lora_model(model_name): + # TODO: might be a better way to detect (e.g. presence of LORA weights file) + return model_name != get_base_model_name(model_name) + + +# Merges LoRA weights into the layers of a base model +def merge_lora_weights(base_model, lora_model_id, submodel_name="unet", scale=1.0): + from collections import defaultdict + from functools import reduce + + from diffusers.loaders import LORA_WEIGHT_NAME + from diffusers.models.attention_processor import LoRAAttnProcessor + from diffusers.utils import DIFFUSERS_CACHE + from diffusers.utils.hub_utils import _get_model_file + + # Load LoRA weights + model_file = _get_model_file( + lora_model_id, + weights_name=LORA_WEIGHT_NAME, + cache_dir=DIFFUSERS_CACHE, + force_download=False, + resume_download=False, + proxies=None, + local_files_only=False, + use_auth_token=None, + revision=None, + subfolder=None, + user_agent={ + "file_type": "attn_procs_weights", + "framework": "pytorch", + }, + ) + lora_state_dict = torch.load(model_file, map_location="cpu") + + # All keys in the LoRA state dictionary should have 'lora' somewhere in the string. + keys = list(lora_state_dict.keys()) + assert all("lora" in k for k in keys) + + if all(key.startswith(submodel_name) for key in keys): + # New format (https://github.com/huggingface/diffusers/pull/2918) supports LoRA weights in both the + # unet and text encoder where keys are prefixed with 'unet' or 'text_encoder', respectively. + submodel_state_dict = {k: v for k, v in lora_state_dict.items() if k.startswith(submodel_name)} + else: + # Old format. Keys will not have any prefix. This only applies to unet, so exit early if this is + # optimizing the text encoder. 
+ if submodel_name != "unet": + return + submodel_state_dict = lora_state_dict + + # Group LoRA weights into attention processors + attn_processors = {} + lora_grouped_dict = defaultdict(dict) + for key, value in submodel_state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + # Merge LoRA attention processor weights into existing Q/K/V/Out weights + for name, proc in attn_processors.items(): + attention_name = name[: -len(".processor")] + attention = reduce(getattr, attention_name.split(sep="."), base_model) + attention.to_q.weight.data += scale * torch.mm(proc.to_q_lora.up.weight, proc.to_q_lora.down.weight) + attention.to_k.weight.data += scale * torch.mm(proc.to_k_lora.up.weight, proc.to_k_lora.down.weight) + attention.to_v.weight.data += scale * torch.mm(proc.to_v_lora.up.weight, proc.to_v_lora.down.weight) + attention.to_out[0].weight.data += scale * torch.mm(proc.to_out_lora.up.weight, proc.to_out_lora.down.weight) + + +# ----------------------------------------------------------------------------- +# TEXT ENCODER +# ----------------------------------------------------------------------------- + + +def text_encoder_inputs(batchsize, torch_dtype): + return torch.zeros((batchsize, 77), dtype=torch_dtype) + + +def text_encoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "text_encoder") + return model + + +def text_encoder_conversion_inputs(model): + return text_encoder_inputs(1, torch.int32) + + +def text_encoder_data_loader(data_dir, batchsize): + return RandomDataLoader(text_encoder_inputs, batchsize, torch.int32) + + +# ----------------------------------------------------------------------------- +# UNET +# ----------------------------------------------------------------------------- + + +def unet_inputs(batchsize, torch_dtype): + return { + "sample": torch.rand((batchsize, 4, 64, 64), dtype=torch_dtype), + "timestep": torch.rand((batchsize,), dtype=torch_dtype), + "encoder_hidden_states": torch.rand((batchsize, 77, 768), dtype=torch_dtype), + "return_dict": False, + } + + +def unet_load(model_name): + base_model_id = get_base_model_name(model_name) + model = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "unet") + return model + + +def unet_conversion_inputs(model): + return tuple(unet_inputs(1, torch.float32).values()) + + +def unet_data_loader(data_dir, batchsize): + return RandomDataLoader(unet_inputs, batchsize, torch.float16) + + +# ----------------------------------------------------------------------------- +# VAE ENCODER +# ----------------------------------------------------------------------------- + + +def vae_encoder_inputs(batchsize, torch_dtype): + return { + "sample": torch.rand((batchsize, 3, 512, 512), dtype=torch_dtype), + "return_dict": False, + } + + +def 
vae_encoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae") + model.forward = lambda sample, return_dict: model.encode(sample, return_dict)[0].sample() + return model + + +def vae_encoder_conversion_inputs(model): + return tuple(vae_encoder_inputs(1, torch.float32).values()) + + +def vae_encoder_data_loader(data_dir, batchsize): + return RandomDataLoader(vae_encoder_inputs, batchsize, torch.float16) + + +# ----------------------------------------------------------------------------- +# VAE DECODER +# ----------------------------------------------------------------------------- + + +def vae_decoder_inputs(batchsize, torch_dtype): + return { + "latent_sample": torch.rand((batchsize, 4, 64, 64), dtype=torch_dtype), + "return_dict": False, + } + + +def vae_decoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae") + model.forward = model.decode + return model + + +def vae_decoder_conversion_inputs(model): + return tuple(vae_decoder_inputs(1, torch.float32).values()) + + +def vae_decoder_data_loader(data_dir, batchsize): + return RandomDataLoader(vae_decoder_inputs, batchsize, torch.float16) + + +# ----------------------------------------------------------------------------- +# SAFETY CHECKER +# ----------------------------------------------------------------------------- + + +def safety_checker_inputs(batchsize, torch_dtype): + return { + "clip_input": torch.rand((batchsize, 3, 224, 224), dtype=torch_dtype), + "images": torch.rand((batchsize, 512, 512, 3), dtype=torch_dtype), + } + + +def safety_checker_load(model_name): + base_model_id = get_base_model_name(model_name) + model = StableDiffusionSafetyChecker.from_pretrained(base_model_id, subfolder="safety_checker") + model.forward = model.forward_onnx + return model + + +def safety_checker_conversion_inputs(model): + return tuple(safety_checker_inputs(1, torch.float32).values()) + + +def safety_checker_data_loader(data_dir, batchsize): + return RandomDataLoader(safety_checker_inputs, batchsize, torch.float16) diff --git a/modules/txt2img.py b/modules/txt2img.py index 2e7d202d7b0..d934a4ab6a5 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -10,40 +10,54 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): override_settings = create_override_settings_dict(override_settings_texts) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, - outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, - prompt=prompt, - styles=prompt_styles, - negative_prompt=negative_prompt, - seed=seed, - subseed=subseed, - subseed_strength=subseed_strength, - seed_resize_from_h=seed_resize_from_h, - seed_resize_from_w=seed_resize_from_w, - seed_enable_extras=seed_enable_extras, - sampler_name=sd_samplers.samplers[sampler_index].name, - batch_size=batch_size, - 
n_iter=n_iter, - steps=steps, - cfg_scale=cfg_scale, - width=width, - height=height, - restore_faces=restore_faces, - tiling=tiling, - enable_hr=enable_hr, - denoising_strength=denoising_strength if enable_hr else None, - hr_scale=hr_scale, - hr_upscaler=hr_upscaler, - hr_second_pass_steps=hr_second_pass_steps, - hr_resize_x=hr_resize_x, - hr_resize_y=hr_resize_y, - hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, - hr_prompt=hr_prompt, - hr_negative_prompt=hr_negative_prompt, - override_settings=override_settings, - ) + if cmd_opts.olive: + from modules.sd_olive import OliveOptimizedProcessingTxt2Img + p = OliveOptimizedProcessingTxt2Img( + sd_model=shared.sd_model, + outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, + outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, + prompt=prompt, + seed=seed, + subseed=subseed, + batch_size=batch_size, + n_iter=n_iter, + steps=steps, + ) + else: + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, + outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, + prompt=prompt, + styles=prompt_styles, + negative_prompt=negative_prompt, + seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, + seed_enable_extras=seed_enable_extras, + sampler_name=sd_samplers.samplers[sampler_index].name, + batch_size=batch_size, + n_iter=n_iter, + steps=steps, + cfg_scale=cfg_scale, + width=width, + height=height, + restore_faces=restore_faces, + tiling=tiling, + enable_hr=enable_hr, + denoising_strength=denoising_strength if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, + hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, + hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, + hr_prompt=hr_prompt, + hr_negative_prompt=hr_negative_prompt, + override_settings=override_settings, + ) p.scripts = modules.scripts.scripts_txt2img p.script_args = args diff --git a/modules/ui.py b/modules/ui.py index e62182daa8a..50f079699ae 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -15,7 +15,8 @@ from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path, data_path +from modules.paths import script_path, data_path, models_path +from modules.sd_olive import optimize from modules.shared import opts, cmd_opts @@ -1142,6 +1143,53 @@ def update_interp_description(value): with gr.Group(elem_id="modelmerger_results_panel"): modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) + with gr.Blocks(analytics_enabled=False) as olive_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="olive_tabs"): + with gr.Tab(label="Optimize ONNX using Olive"): + olive_model_id = gr.Textbox(label='Model ID', value="runwayml/stable-diffusion-v1-5", elem_id="olive_model_id", info="The huggingface identifier of the model to download and optimize.") + olive_source_dir = gr.Textbox(label='Onnx model folder', value="models/ONNX/runwayml/stable-diffusion-v1-5", elem_id="olive_source_dir") + olive_dir = 
gr.Textbox(label='Output folder', value="models/ONNX-Olive/runwayml/stable-diffusion-v1-5", elem_id="olive_dir") + + with gr.Column(elem_id="olive_width"): + min_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum width", value=512, elem_id="olive_min_width") + max_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum width", value=512, elem_id="olive_max_width") + + with gr.Column(elem_id="olive_height"): + min_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum height", value=512, elem_id="olive_min_height") + max_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum height", value=512, elem_id="olive_max_height") + + with gr.Column(elem_id="olive_batch_size"): + min_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Minimum batch size", value=1, elem_id="olive_min_bs") + max_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Maximum batch size", value=1, elem_id="olive_max_bs") + + with gr.Column(elem_id="olive_token_count"): + min_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Minimum prompt token count", value=75, elem_id="olive_min_token_count") + max_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Maximum prompt token count", value=75, elem_id="olive_max_token_count") + + with FormGroup(elem_id="olive_submodels", elem_classes="checkboxes-row", variant="compact"): + olive_safety_checker = gr.Checkbox(label='Safety Checker', value=True, elem_id="olive_safety_checker") + olive_text_encoder = gr.Checkbox(label='Text Encoder', value=True, elem_id="olive_text_encoder") + olive_unet = gr.Checkbox(label='UNet', value=True, elem_id="olive_unet") + olive_vae_decoder = gr.Checkbox(label='VAE Decoder', value=True, elem_id="olive_vae_decoder") + olive_vae_encoder = gr.Checkbox(label='VAE Encoder', value=True, elem_id="olive_vae_encoder") + + with FormRow(elem_classes="checkboxes-row", variant="compact"): + use_fp16 = gr.Checkbox(label='Use half floats', value=True, elem_id="olive_fp16") + + button_export_olive = gr.Button(value="Optimize ONNX model using Olive", variant='primary', elem_id="olive_optimize_from_onnx") + + with gr.Column(variant='panel'): + olive_result = gr.Label(elem_id="olive_result", value="", show_label=False) + olive_info = gr.HTML(elem_id="olive_info", value="") + + button_export_olive.click( + wrap_gradio_gpu_call(optimize, extra_outputs=[""]), + inputs=[olive_model_id, olive_source_dir, olive_dir, olive_safety_checker, olive_text_encoder, olive_unet, olive_vae_decoder, olive_vae_encoder, use_fp16], + outputs=[olive_result, olive_info], + ) + with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="
See wiki for detailed explanation.
") @@ -1659,6 +1707,9 @@ def reload_scripts(): (train_interface, "Train", "train"), ] + if cmd_opts.olive: + interfaces += [(olive_interface, "Olive", "olive")] + interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] diff --git a/requirements_olive.txt b/requirements_olive.txt new file mode 100644 index 00000000000..6205f1f97a8 --- /dev/null +++ b/requirements_olive.txt @@ -0,0 +1,8 @@ +diffusers +transformers +onnx +accelerate +torch==1.13.1 +torchvision==0.14.1 +onnxruntime-directml>=1.15.0 +protobuf==3.20.3 \ No newline at end of file diff --git a/user_script.py b/user_script.py new file mode 100644 index 00000000000..22462fbae98 --- /dev/null +++ b/user_script.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import torch +from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from huggingface_hub import model_info +from transformers.models.clip.modeling_clip import CLIPTextModel + + +# Helper latency-only dataloader that creates random tensors with no label +class RandomDataLoader: + def __init__(self, create_inputs_func, batchsize, torch_dtype): + self.create_input_func = create_inputs_func + self.batchsize = batchsize + self.torch_dtype = torch_dtype + + def __getitem__(self, idx): + label = None + return self.create_input_func(self.batchsize, self.torch_dtype), label + + +def get_base_model_name(model_name): + return model_info(model_name).cardData.get("base_model", model_name) + + +def is_lora_model(model_name): + # TODO: might be a better way to detect (e.g. presence of LORA weights file) + return model_name != get_base_model_name(model_name) + + +# Merges LoRA weights into the layers of a base model +def merge_lora_weights(base_model, lora_model_id, submodel_name="unet", scale=1.0): + from collections import defaultdict + from functools import reduce + + from diffusers.loaders import LORA_WEIGHT_NAME + from diffusers.models.attention_processor import LoRAAttnProcessor + from diffusers.utils import DIFFUSERS_CACHE + from diffusers.utils.hub_utils import _get_model_file + + # Load LoRA weights + model_file = _get_model_file( + lora_model_id, + weights_name=LORA_WEIGHT_NAME, + cache_dir=DIFFUSERS_CACHE, + force_download=False, + resume_download=False, + proxies=None, + local_files_only=False, + use_auth_token=None, + revision=None, + subfolder=None, + user_agent={ + "file_type": "attn_procs_weights", + "framework": "pytorch", + }, + ) + lora_state_dict = torch.load(model_file, map_location="cpu") + + # All keys in the LoRA state dictionary should have 'lora' somewhere in the string. + keys = list(lora_state_dict.keys()) + assert all("lora" in k for k in keys) + + if all(key.startswith(submodel_name) for key in keys): + # New format (https://github.com/huggingface/diffusers/pull/2918) supports LoRA weights in both the + # unet and text encoder where keys are prefixed with 'unet' or 'text_encoder', respectively. + submodel_state_dict = {k: v for k, v in lora_state_dict.items() if k.startswith(submodel_name)} + else: + # Old format. Keys will not have any prefix. This only applies to unet, so exit early if this is + # optimizing the text encoder. 
+ if submodel_name != "unet": + return + submodel_state_dict = lora_state_dict + + # Group LoRA weights into attention processors + attn_processors = {} + lora_grouped_dict = defaultdict(dict) + for key, value in submodel_state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + # Merge LoRA attention processor weights into existing Q/K/V/Out weights + for name, proc in attn_processors.items(): + attention_name = name[: -len(".processor")] + attention = reduce(getattr, attention_name.split(sep="."), base_model) + attention.to_q.weight.data += scale * torch.mm(proc.to_q_lora.up.weight, proc.to_q_lora.down.weight) + attention.to_k.weight.data += scale * torch.mm(proc.to_k_lora.up.weight, proc.to_k_lora.down.weight) + attention.to_v.weight.data += scale * torch.mm(proc.to_v_lora.up.weight, proc.to_v_lora.down.weight) + attention.to_out[0].weight.data += scale * torch.mm(proc.to_out_lora.up.weight, proc.to_out_lora.down.weight) + + +# ----------------------------------------------------------------------------- +# UNET +# ----------------------------------------------------------------------------- + + +def unet_inputs(batchsize, torch_dtype): + return { + "sample": torch.rand((batchsize, 4, 64, 64), dtype=torch_dtype), + "timestep": torch.rand((batchsize,), dtype=torch_dtype), + "encoder_hidden_states": torch.rand((batchsize, 77, 768), dtype=torch_dtype), + "return_dict": False, + } + + +def unet_load(model_name): + base_model_id = get_base_model_name(model_name) + model = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "unet") + return model + + +def unet_conversion_inputs(model): + return tuple(unet_inputs(1, torch.float32).values()) + + +def unet_data_loader(data_dir, batchsize): + return RandomDataLoader(unet_inputs, batchsize, torch.float16) +
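
For reference, a minimal standalone sketch of how the optimized output of this change can be loaded, mirroring the session setup used by OliveOptimizedProcessingTxt2Img above. The model directory, prompt, step count, and output filename are illustrative assumptions, not values defined by this diff.

import onnxruntime as ort
from diffusers import OnnxStableDiffusionPipeline

# Pin the UNet's named free dimensions to the shapes used at generation time:
# batch * 2 samples for classifier-free guidance, 4x64x64 latents, 77-token prompts.
sess_options = ort.SessionOptions()
sess_options.enable_mem_pattern = False
sess_options.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
sess_options.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)

# Assumed output folder written by the "Optimize ONNX using Olive" tab.
pipeline = OnnxStableDiffusionPipeline.from_pretrained(
    "models/ONNX-Olive/runwayml/stable-diffusion-v1-5",
    provider="DmlExecutionProvider",
    sess_options=sess_options,
)
image = pipeline("a photo of an astronaut riding a horse", num_inference_steps=20).images[0]
image.save("olive_test.png")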