enable loading of INC quantized stable diffusion model (#305)
* add inc stable diffusion pipeline

* enable loading of quantized stable diffusion model

* add inc loading stable diffusion

* add model

* update model

* test

* fix inc config loading
echarlaix committed May 2, 2023
1 parent caae4be commit c9512a3
Showing 13 changed files with 198 additions and 67 deletions.
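Taken together, the changes below let a Stable Diffusion pipeline whose UNet was quantized with Intel Neural Compressor (INC) be saved and reloaded end to end. A minimal sketch of the new flow, mirroring the updated text-to-image example — the model ID and output directory are illustrative, and dynamic quantization is chosen so no calibration data is needed:

```python
import os

from diffusers import StableDiffusionPipeline
from neural_compressor.config import PostTrainingQuantConfig

from optimum.intel import INCQuantizer, INCStableDiffusionPipeline
from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME

model_id = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint
output_dir = "sd-int8"  # illustrative output directory

# Save the full pipeline first; the quantized weights then overwrite the unet subfolder.
pipeline = StableDiffusionPipeline.from_pretrained(model_id)
pipeline.save_pretrained(output_dir)

quantizer = INCQuantizer.from_pretrained(pipeline.unet)
quantizer.quantize(
    quantization_config=PostTrainingQuantConfig(approach="dynamic"),
    save_directory=os.path.join(output_dir, "unet"),
    file_name=DIFFUSION_WEIGHTS_NAME,
)

# The new pipeline class detects the quantized UNet and restores it automatically.
int8_pipeline = INCStableDiffusionPipeline.from_pretrained(output_dir)
```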
2 changes: 1 addition & 1 deletion .github/workflows/test_inc.yml
@@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install .[neural-compressor,ipex,tests]
+pip install .[neural-compressor,ipex,diffusers,tests]
- name: Test with Pytest
run: |
pytest tests/neural_compressor/
@@ -32,8 +32,8 @@
from pytorch_fid import fid_score
from torch.utils.data import Dataset

-from optimum.intel.neural_compressor import INCQuantizer
-from optimum.intel.neural_compressor.utils import load_quantized_model
+from optimum.intel import INCQuantizer, INCStableDiffusionPipeline
+from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME


os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -254,31 +254,30 @@ def eval_func(model):
)

quantization_config = PostTrainingQuantConfig(approach=args.quantization_approach)
+pipeline.save_pretrained(args.output_dir)
quantizer = INCQuantizer.from_pretrained(pipeline.unet, calibration_fn=calibration_func)

quantizer.quantize(
quantization_config=quantization_config,
-save_directory=args.output_dir,
+save_directory=os.path.join(args.output_dir, "unet"),
calibration_dataset=CalibDataset() if args.quantization_approach == "static" else None,
remove_unused_columns=False,
+file_name=DIFFUSION_WEIGHTS_NAME,
)

if args.apply_quantization and args.verify_loading:
-loaded_model = load_quantized_model(args.output_dir, model=getattr(pipeline, "unet"))
+int8_pipeline = INCStableDiffusionPipeline.from_pretrained(args.output_dir)
result_optimized_model = eval_func(quantizer._quantized_model)
-result_loaded_model = eval_func(loaded_model)
+result_loaded_model = eval_func(int8_pipeline.unet)
if result_loaded_model != result_optimized_model:
logger.error("The quantized model was not successfully loaded.")
else:
logger.info("The quantized model was successfully loaded.")

if args.benchmark and args.int8:
print("====int8 inference====")
-loaded_model = load_quantized_model(args.output_dir, model=getattr(pipeline, "unet"))
-loaded_model.eval()
-setattr(pipeline, "unet", loaded_model)
+int8_pipeline = INCStableDiffusionPipeline.from_pretrained(args.output_dir)
generator = torch.Generator("cpu").manual_seed(args.seed)
-benchmark(pipeline, generator)
+benchmark(int8_pipeline, generator)


def _mp_fn(index):
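For the verification and benchmark branches above, the reloaded pipeline behaves like a stock StableDiffusionPipeline. A usage sketch, assuming the quantized pipeline was saved to the illustrative "sd-int8" directory:

```python
import torch

from optimum.intel import INCStableDiffusionPipeline

# "sd-int8" is the illustrative save directory from the sketch above.
int8_pipeline = INCStableDiffusionPipeline.from_pretrained("sd-int8")
generator = torch.Generator("cpu").manual_seed(42)
image = int8_pipeline("a photo of an astronaut riding a horse", generator=generator).images[0]
image.save("astronaut.png")
```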
21 changes: 3 additions & 18 deletions examples/neural_compressor/textual-inversion/text2images.py
@@ -3,10 +3,9 @@
import os

import torch
-from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
-from neural_compressor.utils.pytorch import load
from PIL import Image
-from transformers import CLIPTextModel, CLIPTokenizer

+from optimum.intel import INCStableDiffusionPipeline


def parse_args():
@@ -86,23 +85,9 @@ def generate_images(


args = parse_args()
-# Load models and create wrapper for stable diffusion
-tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
-pipeline = StableDiffusionPipeline.from_pretrained(
-    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
-)
+pipeline = INCStableDiffusionPipeline.from_pretrained(args.output_dir).to(torch.device("cpu"))
pipeline.safety_checker = lambda images, clip_input: (images, False)
-if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
-    unet = load(args.pretrained_model_name_or_path, model=unet)
-    unet.eval()
-    setattr(pipeline, "unet", unet)
-else:
-    unet = unet.to(torch.device("cuda", args.cuda_id))
-pipeline = pipeline.to(unet.device)
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
20 changes: 9 additions & 11 deletions examples/neural_compressor/textual-inversion/textual_inversion.py
@@ -21,14 +21,16 @@
from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
from neural_compressor.training import prepare_compression
from neural_compressor.utils import logger
-from neural_compressor.utils.pytorch import load
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

+from optimum.intel import INCStableDiffusionPipeline
+from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@@ -949,7 +951,10 @@ def attention_fetcher(x):
compression_manager.callbacks.on_train_end()

# Save the resulting model and its corresponding configuration in the given directory
-model.save(args.output_dir)
+state_dict = model.state_dict()
+if hasattr(model, "q_config"):
+    state_dict["best_configure"] = model.q_config
+torch.save(state_dict, os.path.join(args.output_dir, "unet", DIFFUSION_WEIGHTS_NAME))

logger.info(f"Optimized model saved to: {args.output_dir}.")

@@ -994,15 +999,8 @@ def attention_fetcher(x):
accelerator.end_training()

if args.do_quantization and args.verify_loading:
-# Load the model obtained after Intel Neural Compressor quantization
-loaded_model = load(args.output_dir, model=unet)
-loaded_model.eval()
-
-setattr(pipeline, "unet", loaded_model)
-if args.do_quantization:
-    pipeline = pipeline.to(torch.device("cpu"))
-
-loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+int8_pipeline = INCStableDiffusionPipeline.from_pretrained(args.output_dir).to(torch.device("cpu"))
+loaded_model_images = generate_images(int8_pipeline, prompt=prompt, seed=args.seed)
if loaded_model_images != optimized_model_images:
logger.info("The quantized model was not successfully loaded.")
else:
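The textual-inversion script now writes the compressed UNet by hand instead of calling model.save, storing the INC quantization recipe under the best_configure key so INCStableDiffusionPipeline can later detect and restore it. A standalone sketch of that saving convention (the helper name is hypothetical):

```python
import os

import torch

from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME


def save_quantized_unet(model: torch.nn.Module, output_dir: str) -> None:
    """Hypothetical helper mirroring the manual save performed in the script."""
    state_dict = model.state_dict()
    # INC attaches its quantization recipe to the model as `q_config`; saving it
    # under "best_configure" is what marks the checkpoint as quantized.
    if hasattr(model, "q_config"):
        state_dict["best_configure"] = model.q_config
    os.makedirs(os.path.join(output_dir, "unet"), exist_ok=True)
    torch.save(state_dict, os.path.join(output_dir, "unet", DIFFUSION_WEIGHTS_NAME))
```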
17 changes: 16 additions & 1 deletion optimum/intel/__init__.py
@@ -52,7 +52,6 @@
else:
_import_structure["openvino"].extend(["OVConfig", "OVQuantizer", "OVTrainer", "OVTrainingArguments"])


try:
if not (is_openvino_available() and is_diffusers_available()):
raise OptionalDependencyNotAvailable()
@@ -109,6 +108,13 @@
"INCSeq2SeqTrainer",
"INCTrainer",
]
+try:
+    if not (is_neural_compressor_available() and is_diffusers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    _import_structure["utils.dummy_neural_compressor_and_diffusers_objects"] = ["INCStableDiffusionPipeline"]
+else:
+    _import_structure["neural_compressor"].append("INCStableDiffusionPipeline")


if TYPE_CHECKING:
@@ -175,6 +181,15 @@
INCSeq2SeqTrainer,
INCTrainer,
)

+    try:
+        if not (is_neural_compressor_available() and is_diffusers_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_neural_compressor_and_diffusers_objects import INCStableDiffusionPipeline
+    else:
+        from .neural_compressor import INCStableDiffusionPipeline

else:
import sys

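Both new try/except blocks follow the transformers-style optional-dependency pattern: INCStableDiffusionPipeline is always importable, but resolves to a dummy object that raises only when the required extras are missing. A simplified, self-contained sketch of the idea (an approximation, not the exact dummy module):

```python
import importlib.util


class OptionalDependencyNotAvailable(BaseException):
    """Raised while building the import structure when an optional backend is missing."""


def is_diffusers_available() -> bool:
    return importlib.util.find_spec("diffusers") is not None


try:
    if not is_diffusers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Dummy stand-in: importable everywhere, fails loudly only when used.
    class INCStableDiffusionPipeline:
        def __init__(self, *args, **kwargs):
            raise ImportError("INCStableDiffusionPipeline requires the diffusers and neural-compressor packages.")
else:
    from optimum.intel.neural_compressor import INCStableDiffusionPipeline
```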
5 changes: 5 additions & 0 deletions optimum/intel/neural_compressor/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from ..utils.import_utils import is_diffusers_available
from .configuration import INCConfig
from .quantization import (
INCModel,
@@ -28,3 +29,7 @@
)
from .trainer import INCTrainer
from .trainer_seq2seq import INCSeq2SeqTrainer


+if is_diffusers_available():
+    from .modeling_diffusion import INCStableDiffusionPipeline
55 changes: 55 additions & 0 deletions optimum/intel/neural_compressor/modeling_diffusion.py
@@ -0,0 +1,55 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from diffusers import StableDiffusionPipeline
from neural_compressor.utils.pytorch import load

from ..utils.constant import DIFFUSION_WEIGHTS_NAME, WEIGHTS_NAME
from ..utils.import_utils import _torch_version, is_torch_version
from .configuration import INCConfig


class INCStableDiffusionPipeline(StableDiffusionPipeline):
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        model = super(INCStableDiffusionPipeline, cls).from_pretrained(*args, low_cpu_mem_usage=False, **kwargs)
        # Check every pipeline component that may have been quantized with INC.
        components = set(model.config.keys()).intersection({"vae", "text_encoder", "unet"})
        for name in components:
            component = getattr(model, name, None)
            name_or_path = ""
            if hasattr(component, "_internal_dict"):
                name_or_path = component._internal_dict["_name_or_path"]
            elif hasattr(component, "name_or_path"):
                name_or_path = component.name_or_path
            if os.path.isdir(name_or_path):
                folder_contents = os.listdir(name_or_path)
                file_name = DIFFUSION_WEIGHTS_NAME if DIFFUSION_WEIGHTS_NAME in folder_contents else WEIGHTS_NAME
                state_dict_path = os.path.join(name_or_path, file_name)
                # A component is treated as quantized when both its weights and an INC config are present.
                if os.path.exists(state_dict_path) and INCConfig.CONFIG_NAME in folder_contents:
                    msg = None
                    inc_config = INCConfig.from_pretrained(name_or_path)
                    if not is_torch_version("==", inc_config.torch_version):
                        msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
                    state_dict = torch.load(state_dict_path, map_location="cpu")
                    # `best_configure` holds the quantization recipe saved by INC; its presence
                    # marks the checkpoint as quantized and triggers the INC restore below.
                    if "best_configure" in state_dict and state_dict["best_configure"] is not None:
                        try:
                            load(state_dict_path, component)
                        except Exception as e:
                            # Surface a torch version mismatch alongside the loading error.
                            if msg is not None:
                                e.args += (msg,)
                            raise
        return model
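from_pretrained decides per component whether to run the INC restore: the component folder must hold a weights file plus the INC config, and the state dict must carry a best_configure entry. The on-disk layout it scans therefore looks roughly like this (a sketch; the inc_config.json file name is an assumption about INCConfig.CONFIG_NAME):

```python
# sd-int8/                             illustrative save directory
# ├── model_index.json
# ├── unet/
# │   ├── config.json
# │   ├── inc_config.json              # INC config (assumed file name), records e.g. the torch version
# │   └── diffusion_pytorch_model.bin  # state dict containing a "best_configure" entry
# ├── text_encoder/ ...                # restored the same way if quantized
# ├── vae/ ...
# └── tokenizer/ ...
```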
15 changes: 7 additions & 8 deletions optimum/intel/neural_compressor/quantization.py
@@ -20,12 +20,13 @@
from enum import Enum
from itertools import chain
from pathlib import Path
-from typing import TYPE_CHECKING, Callable, ClassVar, Dict, Optional, Union
+from typing import Callable, ClassVar, Dict, Optional, Union

import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi, hf_hub_download
from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _propagate_qconfig
+from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.experimental.export import torch_to_int8_onnx
from neural_compressor.model.torch_model import IPEXModel, PyTorchModel
from neural_compressor.quantization import fit
@@ -57,19 +58,15 @@
from optimum.exporters.onnx import OnnxConfig
from optimum.quantization_base import OptimumQuantizer

-from ..utils.constant import _TASK_ALIASES
+from ..utils.constant import _TASK_ALIASES, MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
from ..utils.import_utils import (
_neural_compressor_version,
_torch_version,
is_neural_compressor_version,
is_torch_version,
)
from .configuration import INCConfig
-from .utils import MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, INCDataLoader, _cfgs_to_fx_cfgs


-if TYPE_CHECKING:
-    from neural_compressor.config import PostTrainingQuantConfig
+from .utils import INCDataLoader, _cfgs_to_fx_cfgs


logger = logging.getLogger(__name__)
@@ -141,6 +138,7 @@ def quantize(
batch_size: int = 8,
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
+file_name: str = None,
**kwargs,
):
"""
@@ -163,7 +161,7 @@
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
save_onnx_model = kwargs.pop("save_onnx_model", False)
-output_path = save_directory.joinpath(WEIGHTS_NAME)
+output_path = save_directory.joinpath(file_name or WEIGHTS_NAME)
calibration_dataloader = None

if INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC:
@@ -228,6 +226,7 @@ def _save_pretrained(model: Union[PyTorchModel, IPEXModel], output_path: str):
logger.info(f"Model weights saved to {output_path}")
return
state_dict = model._model.state_dict()

if hasattr(model, "q_config"):
state_dict["best_configure"] = model.q_config
torch.save(state_dict, output_path)
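The new file_name argument only changes where quantize writes the weights; omitting it keeps the previous pytorch_model.bin default, so existing callers are unaffected. A sketch of the resolution logic (a standalone rewrite of the path handling above):

```python
from pathlib import Path

WEIGHTS_NAME = "pytorch_model.bin"  # default kept for backward compatibility


def resolve_output_path(save_directory: str, file_name: str = None) -> Path:
    """Standalone mirror of the output-path resolution inside INCQuantizer.quantize."""
    directory = Path(save_directory)
    directory.mkdir(parents=True, exist_ok=True)
    # `file_name` lets diffusers models be saved as e.g. diffusion_pytorch_model.bin.
    return directory.joinpath(file_name or WEIGHTS_NAME)
```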
3 changes: 1 addition & 2 deletions optimum/intel/neural_compressor/trainer.py
@@ -63,10 +63,9 @@

from optimum.exporters import TasksManager

-from ..utils.constant import _TASK_ALIASES
+from ..utils.constant import _TASK_ALIASES, MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, TRAINING_ARGS_NAME
from ..utils.import_utils import is_neural_compressor_version
from .configuration import INCConfig
-from .utils import MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, TRAINING_ARGS_NAME


if is_apex_available():
8 changes: 4 additions & 4 deletions optimum/intel/neural_compressor/utils.py
@@ -14,6 +14,7 @@

import logging
import os
+import warnings
from collections import UserDict
from typing import Dict

@@ -22,15 +23,13 @@
from packaging import version
from torch.utils.data import DataLoader

+from ..utils.constant import WEIGHTS_NAME


logger = logging.getLogger(__name__)


CONFIG_NAME = "best_configure.yaml"
-WEIGHTS_NAME = "pytorch_model.bin"
-TRAINING_ARGS_NAME = "training_args.bin"
-ONNX_WEIGHTS_NAME = "model.onnx"
-MIN_QDQ_ONNX_OPSET = 14

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_less_than_1_13 = parsed_torch_version_base < version.parse("1.13.0")
@@ -102,6 +101,7 @@ def load_quantized_model(checkpoint_dir_or_file: str, model: torch.nn.Module, **
model (`torch.nn.Module`):
The original FP32 model.
"""
+warnings.warn("This function has been deprecated and will be removed in optimum-intel v1.9.")
if os.path.isdir(checkpoint_dir_or_file):
checkpoint_dir_or_file = os.path.join(
os.path.abspath(os.path.expanduser(checkpoint_dir_or_file)), WEIGHTS_NAME
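With load_quantized_model deprecated, the supported path for diffusion models is the new pipeline class. A migration sketch (directory name illustrative):

```python
import torch

from optimum.intel import INCStableDiffusionPipeline

# Before (deprecated): rebuild the pipeline, then patch in the quantized UNet, e.g.
#   from optimum.intel.neural_compressor.utils import load_quantized_model
#   pipeline.unet = load_quantized_model("sd-int8/unet", model=pipeline.unet)
# After: a single call restores every quantized component.
pipeline = INCStableDiffusionPipeline.from_pretrained("sd-int8").to(torch.device("cpu"))
```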
7 changes: 7 additions & 0 deletions optimum/intel/utils/constant.py
@@ -32,3 +32,10 @@
"summarization": "seq2seq-lm",
"translation": "seq2seq-lm",
}


+WEIGHTS_NAME = "pytorch_model.bin"
+DIFFUSION_WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+TRAINING_ARGS_NAME = "training_args.bin"
+ONNX_WEIGHTS_NAME = "model.onnx"
+MIN_QDQ_ONNX_OPSET = 14